什么是线段树
线段树是一种分治思想的二叉树结构,用于在区间上进行信息维护与统计,与按照二进制进行区间划分的树状数组相比,线段树是一种更为通用的数据结构:
- 线段树的每一个节点都代表一个区间。
- 线段树有唯一的根节点,代表的区间是整个统计的范围。
- 线段树的每一个叶子节点都代表一个长度为 \(1\) 的元区间 \([x,x]\)。
- 定义对于每个内部节点 \([l,r]\) ,定义 \(mid = \lfloor (l + r)/2\rfloor\),它的左子节点为 \([l,mid]\),右子节点为 \([mid+1,r]\)。
除去线段树的最后一层,整棵线段树一定是一棵完全二叉树,深度为 \(O(\log N)\)。因此,可以按照与二叉堆类似的“父子二倍”节点编号方法:
- 根节点编号为 \(1\)。
- 编号为 \(x\) 的节点的左子节点编号为 \(x\times2\),右子节点编号为 \(x\times2+1\)。
在理想情况下,\(N\) 个叶节点的满二叉树有 \(2\times N-1\) 个节点。应为在上述储存方式下,最后还有一层产生了空余,所以保存线段树的数组长度不应少于 \(4\times N\)。
线段树的建树
给定一个长度为 \(N\) 的序列 \(A\),我们可以在区间 \([1,N]\) 上建立一棵线段树,每个叶节点 \([i,i]\) 储存 \(A[i]\) 的值。以区间最大值为例,代码如下:
int a[100005];
struct node{
int l,r,date;
#define l(x) t[x].l;
#define r(x) t[x].r;
#define date(x) t[x].date;
}t[100005*4];
void build(int p,int l,int r){
l(p)=l;
r(p)=r;
if(l==r){
date(p)=a[l];
return;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+!,mid+1,r);
date(p)=max(date(p*2),date(p*2+1));
}
线段树的单点修改
将 \(A[x]\) 的值修改为 \(v\)。在线段树中,根节点是执行各种操作的入口。我们需要从根节点开始,递归找到代表区间 \([x,x]\) 的叶节点,然后从下往上更新 \([x,x]\) 以及它所有的祖先节点上保留的信息。时间复杂度为 \(O(\log n)\)。
void change(int p,int x,int v){
if(l(p)==r(p)){
date(p)=v;
return;
}
int mid=(l(p)+r(p))/2;
if(x<=mid) change(p*2,x,v);
else change(p*2+1,x,v);
date(p)=max(date(p*2),date(p*2+1));
}
线段树的区间查询
查询序列 \(A\) 在区间 \([l,r]\) 上的最大值。从根节点开始,递归执行下列过程:
- 若 \([l,r]\) 完全覆盖了当前节点代表的区间,立即回溯,并且该节点的 \(date\) 值为候选项。
- 若左子节点与 \([l,r]\) 有重叠部分,递归访问左子节点。
- 若右子节点与 \([l,r]\) 有重叠部分,递归访问右子节点。
int ask(int p,int l,int r){
if(l<=l(p) && r>=r(p)) return date(p);
int mid=(l(p)+r(p))/2,v=-(1<<30);//负无穷大
if(l<=mid) v=max(v,ask(p*2,l,r));
if(r>mid) v=max(v,ask(p*2+1,l,r));
return v;
}
线段树的区间修改
在区间修改操作中,如果某个节点被修改区间 \([l,r]\) 完全覆盖,那么以该节点为根的整棵子树都将发生变化,如果逐一进行更新,那么将使一次区间修改的时间复杂度增加至 \(O(n)\),效率不高。
假设我们逐一修改了被查询区间完全覆盖的节点 \(P\) 所代表的区间 \([l],r]\),但是在之后的查询操作中没有用到该区间的子区间作为答案,那么我们对以 \(P\) 为根的子树的修改都是没有意义的。
这启发我们在后续修改操作中,同样也可以在修改区间完全覆盖当前节点代表的区间时立即返回,但是在回溯之前在 \(P\) 上打上标记,表示“该节点曾经被修改,但它的子节点没有被更新”。
在后续的指令中,需要从节点 \(P\) 向下递归,我们再检查 \(P\) 是否被标记。如果有标记,则根据信息更新它的两个子节点并在子节点上打标记,然后清除 \(P\) 上的标记。
区间查询,区间修改的时间复杂度均为 \(O(\log n)\)。
模板
以【模板】线段树1为例
#include<bits/stdc++.h>
using namespace std;
struct node{
int l,r;
long long sum,add;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
}tree[100005*4];
int n,m,t1,t2,t3,t4,a[100005];
void build(int p,int l,int r){
l(p)=l;
r(p)=r;
if(l==r){
sum(p)=a[l];
return;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(int p){
if(add(p)){
sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
add(p*2)+=add(p);
add(p*2+1)+=add(p);
add(p)=0;
}
}
void change(int p,int l,int r,int d){
if(l<=l(p) && r>=r(p)){
sum(p)+=(long long)d*(r(p)-l(p)+1);
add(p)+=d;
return;
}
spread(p);
int mid=(l(p)+r(p))/2;
if(l<=mid) change(p*2,l,r,d);
if(r>mid) change(p*2+1,l,r,d);
sum(p)=sum(p*2)+sum(p*2+1);
}
long long ask(int p,int l,int r){
if(l<=l(p) && r>=r(p)) return sum(p);
spread(p);
int mid=(l(p)+r(p))/2;
long long v=0;
if(l<=mid) v+=ask(p*2,l,r);
if(r>mid) v+=ask(p*2+1,l,r);
return v;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
while(m--){
cin>>t1>>t2>>t3;
if(t1==1){
cin>>t4;
change(1,t2,t3,t4);
}
else{
cout<<ask(1,t2,t3)<<endl;
}
}
return 0;
}
一些例题
The Child and Sequence
Sol
设模数为 \(mod\),显然当修改区间 \([l,r]\) 中的最大值小于 \(mod\) 时不应该继续尝试修改。这启示我们在遇到这类在区间上做运算的问题时,应挖掘该运算的特殊性质,减少不必要的递归和修改。
Code
#include<bits/stdc++.h>
using namespace std;
int n,m,a[100005];
struct node{
int l,r;
long long sum,mx;
#define l(x) t[x].l
#define r(x) t[x].r
#define sum(x) t[x].sum
#define mx(x) t[x].mx
}t[100005*4];
void build(int p,int l,int r){
l(p)=l;
r(p)=r;
if(l==r){
sum(p)=a[l];
mx(p)=a[l];
return;
}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
sum(p)=sum(p*2)+sum(p*2+1);
mx(p)=max(mx(p*2),mx(p*2+1));
}
void change(int p,int x,long long v){
if(l(p)==r(p)){
sum(p)=v;
mx(p)=v;
return;
}
int mid=(l(p)+r(p))/2;
if(x<=mid) change(p*2,x,v);
else change(p*2+1,x,v);
mx(p)=max(mx(p*2),mx(p*2+1));
sum(p)=sum(p*2+1)+sum(p*2);
}
void change_mod(int p,int l,int r,long long d){
if(l<=l(p) && r>=r(p) && mx(p)<d) return;
if(l<=l(p) && r>=r(p) && l(p)==r(p)){
mx(p)%=d;
sum(p)%=d;
return;
}
int mid=(l(p)+r(p))/2;
if(l<=mid) change_mod(p*2,l,r,d);
if(r>mid) change_mod(p*2+1,l,r,d);
sum(p)=sum(p*2)+sum(p*2+1);
mx(p)=max(mx(p*2),mx(p*2+1));
}
long long ask(int p,int l,int r){
if(l<=l(p) && r>=r(p)) return sum(p);
int mid=(l(p)+r(p))/2;
long long v=0;
if(l<=mid) v+=ask(p*2,l,r);
if(r>mid) v+=ask(p*2+1,l,r);
return v;
}
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
cin>>a[i];
}
build(1,1,n);
while(m--){
int t1;
cin>>t1;
if(t1==1){
int t2,t3;
cin>>t2>>t3;
cout<<ask(1,t2,t3)<<endl;
}
else if(t1==2){
int t2,t3,t4;
cin>>t2>>t3>>t4;
change_mod(1,t2,t3,t4);
}
else if(t1==3){
int t2,t3;
cin>>t2>>t3;
change(1,t2,t3);
}
}
return 0;
}