线段树学习笔记

发布时间 2023-07-20 21:22:48作者: ZnHF

什么是线段树

线段树是一种分治思想的二叉树结构,用于在区间上进行信息维护与统计,与按照二进制进行区间划分的树状数组相比,线段树是一种更为通用的数据结构:

  1. 线段树的每一个节点都代表一个区间。
  2. 线段树有唯一的根节点,代表的区间是整个统计的范围。
  3. 线段树的每一个叶子节点都代表一个长度为 \(1\) 的元区间 \([x,x]\)
  4. 定义对于每个内部节点 \([l,r]\) ,定义 \(mid = \lfloor (l + r)/2\rfloor\),它的左子节点为 \([l,mid]\),右子节点为 \([mid+1,r]\)

除去线段树的最后一层,整棵线段树一定是一棵完全二叉树,深度为 \(O(\log N)\)。因此,可以按照与二叉堆类似的“父子二倍”节点编号方法:

  1. 根节点编号为 \(1\)
  2. 编号为 \(x\) 的节点的左子节点编号为 \(x\times2\),右子节点编号为 \(x\times2+1\)
    在理想情况下,\(N\) 个叶节点的满二叉树有 \(2\times N-1\) 个节点。应为在上述储存方式下,最后还有一层产生了空余,所以保存线段树的数组长度不应少于 \(4\times N\)

线段树的建树

给定一个长度为 \(N\) 的序列 \(A\),我们可以在区间 \([1,N]\) 上建立一棵线段树,每个叶节点 \([i,i]\) 储存 \(A[i]\) 的值。以区间最大值为例,代码如下:

int a[100005];
struct node{
    int l,r,date;
    #define l(x) t[x].l;
    #define r(x) t[x].r;
    #define date(x) t[x].date;
}t[100005*4];
void build(int p,int l,int r){
    l(p)=l;
    r(p)=r;
    if(l==r){
        date(p)=a[l];
        return;
    }
    int mid=(l+r)/2;
    build(p*2,l,mid);
    build(p*2+!,mid+1,r);
    date(p)=max(date(p*2),date(p*2+1));
}

线段树的单点修改

\(A[x]\) 的值修改为 \(v\)。在线段树中,根节点是执行各种操作的入口。我们需要从根节点开始,递归找到代表区间 \([x,x]\) 的叶节点,然后从下往上更新 \([x,x]\) 以及它所有的祖先节点上保留的信息。时间复杂度为 \(O(\log n)\)

void change(int p,int x,int v){
    if(l(p)==r(p)){
        date(p)=v;
        return;
    }
    int mid=(l(p)+r(p))/2;
    if(x<=mid) change(p*2,x,v);
    else change(p*2+1,x,v);
    date(p)=max(date(p*2),date(p*2+1));
}

线段树的区间查询

查询序列 \(A\) 在区间 \([l,r]\) 上的最大值。从根节点开始,递归执行下列过程:

  1. \([l,r]\) 完全覆盖了当前节点代表的区间,立即回溯,并且该节点的 \(date\) 值为候选项。
  2. 若左子节点与 \([l,r]\) 有重叠部分,递归访问左子节点。
  3. 若右子节点与 \([l,r]\) 有重叠部分,递归访问右子节点。
int ask(int p,int l,int r){
    if(l<=l(p) && r>=r(p)) return date(p);
    int mid=(l(p)+r(p))/2,v=-(1<<30);//负无穷大
    if(l<=mid) v=max(v,ask(p*2,l,r));
    if(r>mid) v=max(v,ask(p*2+1,l,r));
    return v;
}

线段树的区间修改

在区间修改操作中,如果某个节点被修改区间 \([l,r]\) 完全覆盖,那么以该节点为根的整棵子树都将发生变化,如果逐一进行更新,那么将使一次区间修改的时间复杂度增加至 \(O(n)\),效率不高。
假设我们逐一修改了被查询区间完全覆盖的节点 \(P\) 所代表的区间 \([l],r]\),但是在之后的查询操作中没有用到该区间的子区间作为答案,那么我们对以 \(P\) 为根的子树的修改都是没有意义的。
这启发我们在后续修改操作中,同样也可以在修改区间完全覆盖当前节点代表的区间时立即返回,但是在回溯之前在 \(P\) 上打上标记,表示“该节点曾经被修改,但它的子节点没有被更新”。
在后续的指令中,需要从节点 \(P\) 向下递归,我们再检查 \(P\) 是否被标记。如果有标记,则根据信息更新它的两个子节点并在子节点上打标记,然后清除 \(P\) 上的标记。
区间查询,区间修改的时间复杂度均为 \(O(\log n)\)

模板

【模板】线段树1为例

#include<bits/stdc++.h>
using namespace std;
struct node{
	int l,r;
	long long sum,add;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define sum(x) tree[x].sum
	#define add(x) tree[x].add
}tree[100005*4];
int n,m,t1,t2,t3,t4,a[100005];
void build(int p,int l,int r){
	l(p)=l;
	r(p)=r;
	if(l==r){
		sum(p)=a[l];
		return;
	}
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(int p){
	if(add(p)){
		sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
		sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
		add(p*2)+=add(p);
		add(p*2+1)+=add(p);
		add(p)=0;
	}
}
void change(int p,int l,int r,int d){
	if(l<=l(p) && r>=r(p)){
		sum(p)+=(long long)d*(r(p)-l(p)+1);
		add(p)+=d;
		return;
	}
	spread(p);
	int mid=(l(p)+r(p))/2;
	if(l<=mid) change(p*2,l,r,d);
	if(r>mid) change(p*2+1,l,r,d);
	sum(p)=sum(p*2)+sum(p*2+1);
}
long long ask(int p,int l,int r){
	if(l<=l(p) && r>=r(p)) return sum(p);
	spread(p);
	int mid=(l(p)+r(p))/2;
	long long v=0;
	if(l<=mid) v+=ask(p*2,l,r);
	if(r>mid) v+=ask(p*2+1,l,r);
	return v;
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>a[i];
	build(1,1,n);
	while(m--){
		cin>>t1>>t2>>t3;
		if(t1==1){
			cin>>t4;
			change(1,t2,t3,t4);
		}
		else{
			cout<<ask(1,t2,t3)<<endl;
		}
	}
	return 0;
}

一些例题

The Child and Sequence

Sol

设模数为 \(mod\),显然当修改区间 \([l,r]\) 中的最大值小于 \(mod\) 时不应该继续尝试修改。这启示我们在遇到这类在区间上做运算的问题时,应挖掘该运算的特殊性质,减少不必要的递归和修改。

Code

#include<bits/stdc++.h>
using namespace std;
int n,m,a[100005];
struct node{
	int l,r;
	long long sum,mx;
	#define l(x) t[x].l
	#define r(x) t[x].r
	#define sum(x) t[x].sum
	#define mx(x) t[x].mx
}t[100005*4];
void build(int p,int l,int r){
	l(p)=l;
	r(p)=r;
	if(l==r){
		sum(p)=a[l];
		mx(p)=a[l];
		return;
	}
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	sum(p)=sum(p*2)+sum(p*2+1);
	mx(p)=max(mx(p*2),mx(p*2+1));
}
void change(int p,int x,long long v){
	if(l(p)==r(p)){
		sum(p)=v;
		mx(p)=v;
		return;
	}
	int mid=(l(p)+r(p))/2;
	if(x<=mid) change(p*2,x,v);
	else change(p*2+1,x,v);
	mx(p)=max(mx(p*2),mx(p*2+1));
	sum(p)=sum(p*2+1)+sum(p*2);
}
void change_mod(int p,int l,int r,long long d){
	if(l<=l(p) && r>=r(p) && mx(p)<d) return;
	if(l<=l(p) && r>=r(p) && l(p)==r(p)){
		mx(p)%=d;
		sum(p)%=d;
		return;
	}
	int mid=(l(p)+r(p))/2;
	if(l<=mid) change_mod(p*2,l,r,d);
	if(r>mid) change_mod(p*2+1,l,r,d);
	sum(p)=sum(p*2)+sum(p*2+1);
	mx(p)=max(mx(p*2),mx(p*2+1));
}
long long ask(int p,int l,int r){
	if(l<=l(p) && r>=r(p)) return sum(p);
	int mid=(l(p)+r(p))/2;
	long long v=0;
	if(l<=mid) v+=ask(p*2,l,r);
	if(r>mid) v+=ask(p*2+1,l,r);
	return v;
}
int main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>a[i];
	}
	build(1,1,n);
	while(m--){
		int t1;
		cin>>t1;
		if(t1==1){
			int t2,t3;
			cin>>t2>>t3;
			cout<<ask(1,t2,t3)<<endl;
		}
		else if(t1==2){
			int t2,t3,t4;
			cin>>t2>>t3>>t4;
			change_mod(1,t2,t3,t4);
		}
		else if(t1==3){
			int t2,t3;
			cin>>t2>>t3;
			change(1,t2,t3);
		}
	}
	return 0;
}