线段树

发布时间 2023-08-28 21:09:10作者: AlanJoker

下饭写错合集:

  • if(l>R||r<L) return 0; 写成 if(l>=R||r<=L) return 0;
  • t1[p] 没有初始化为 \(1\)
  • 忘记建树 QAQ。

  • 线段树是解决 RMQ 问题中的一种常用的数据结构,树状数组能实现的功能是线段树能实现功能的子集。

  • 线段树可以在 \(O(\log n)\) 内实现。

  • 单点/区间修改(加,乘,赋值,开根号……)。

  • 单点/区间查询(求和,求积,求max,求min,求gcd)。

  • \(d_x\) 维护的区间是 \([l,r]\),左右儿子维护的区间是 \([l,mid]\)\([mid+1,r]\),其中 $mid=\left \lfloor \frac {l+r}{2} \right \rfloor $。

建树:

#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
void build(int p,int l,int r){
	if(l==r){ // 点 p 代表区间 [l,l] 
		s[p]=a[l]; // 初始化 
		return;
	}
	build(ls(p),l,mid), build(rs(p),mid+1,r);
	s[p]=s[ls(p)]+s[rs(p)]; // push_up(p) 
}
  • push_up(p) 代表信息上传至 \(p\),即将 \(s_{ls(p)},s_{rs(p)}\) 的信息合并到 \(s_p\) 中。
  • 对于数组空间,每个点都会访问到自己的左右儿子,且线段树的空间具有扩充性,\(2^n+1\) 个结点的线段树实际上会扩充到 \(2^{n+1}\) 个。
  • 所以线段树一般开 4 倍空间,即 int a[N],s[N * 4];
  • 调用 \(build(1,1,n)\) 就代表根结点 \(1\) 维护区间 \([1,n]\) 向下递归建树,每一个叶子结点代表 \(a\) 中一点。

区间查询:

  • 设要找的区间是 \([L,R]\),我们递归区间 \([l,r]\),向 \(L \leq l\leq r \leq R\) 的目标下递归。
  • 仍然每次递归 \([l,mid]\)\([mid+1,r]\),同时减掉 \([L,R]\) 外的区间。
  • 一旦满足 \([l,r]\) 包含于 \([L,R]\),直接 return。
int ask(int p,int l,int r,int L,int R){ // 区间查询 
	if(r<L||R<l) return 0; // [l,r]在[L,R]外
	if(L<=l&r<=R) return s[p]; // [l,r]在[L,R]内
   	push_down(p,l,r); // 下移懒标记
	return ask(ls(p),l,mid,L,R)+ask(rs(p),mid+1,r,L,R);//递归合并答案 
}
ask(1,1,n,l,r);

单点修改:

  • \([l,r]\) 二分查找要修改的位置 \(x\),得到点 \(p\)
  • 由于 \(p\) 一定是叶子结点,直接修改信息即可。
void add(int p,int l,int r,int x,int y){
	if(l==r) s[p]+=y;
	x<=mid ? add(ls(p),l,mid,x,y) : add(rs(p),mid+1,r,x,y);
	s[p]=s[ls(p)]+s[rs(p)];// push_up(p)
}

x<=mid ? …… 这个不能忘记

区间修改:

  • 懒惰标记:就是通过延迟对节点信息的更改,从而减少可能不必要的操作次数。每次执行修改时,我们通过打标记的方法表明该节点对应的区间在某一次操作中被更改,但不更新该节点的子节点的信息。实质性的修改则在下一次访问带有标记的节点时才进行。(摘自 OI Wiki)
  • \(t[4 * N]\) 是懒标记数组。
int t[4*N];
void maketag(int p,int l,int r,int x){ // 打懒标记 
	s[p]+=(r-l+1)*x, t[p]+=x;
}
void push_down(int p,int l,int r){
	if(!t[p]) return; // 没有懒标记就返回
	maketag(ls(p),l,mid,t[p]); maketag(rs(p),mid+1,r,t[p]);
	t[p]=0; // 清除父亲的懒标记 
}
void update(int p,int l,int r,int L,int R,int x){
	if(r<L||R<l) return;
	if(L<=l&&r<=R){
		maketag(p,l,r,x); return;
	}
	push_down(p,l,r);
	update(ls(p),l,mid,L,R,x); update(rs(p),mid+1,r,L,R,x);
	push_up(p);
}
  • 结合 addpush_down 就可以实现区间修改的代码。

P3372 【模板】线段树 1:

区间加询问区间和。

#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int inf=0x3f3f3f3f,maxn=5e5+7;
using namespace std;
int n,m;
ll a[maxn],s[maxn<<2],t[maxn<<2];
void push_up(int p){
	s[p]=s[ls(p)]+s[rs(p)];
}
void build(int p,int l,int r){
	if(l==r){
		s[p]=a[l]; return;
	}
	build(ls(p),l,mid); build(rs(p),mid+1,r);
	push_up(p);
}
void maketag(int p,int l,int r,int k){
	s[p]+=1ll*(r-l+1)*k; t[p]+=k;
}
void push_down(int p,int l,int r){
	if(!t[p]) return;
	maketag(ls(p),l,mid,t[p]); maketag(rs(p),mid+1,r,t[p]);
	t[p]=0;
}
ll ask(int p,int l,int r,int L,int R){
	if(l>R||r<L) return 0;
	if(l>=L&&r<=R) return s[p];
	push_down(p,l,r);
	return ask(ls(p),l,mid,L,R)+ask(rs(p),mid+1,r,L,R);
}
void update(int p,int l,int r,int L,int R,int x){
	if(r<L||R<l) return;
	if(L<=l&&r<=R){
		maketag(p,l,r,x); return;
	}
	push_down(p,l,r);
	update(ls(p),l,mid,L,R,x); update(rs(p),mid+1,r,L,R,x);
	push_up(p);
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%lld",&a[i]);
	}
	build(1,1,n);
	while(m--){
		int op,x,y,k;
		scanf("%d%d%d",&op,&x,&y);
		if(op==2){
			printf("%lld\n",ask(1,1,n,x,y));
		}else{
			scanf("%d",&k);
			update(1,1,n,x,y,k);
		}
	}
	return 0;
}
  • Q:如果不是区间加,而是区间赋值怎么打标记?

直接将 \(maketag\) 操作中的 += 变成 = 即可。如果有赋值为 \(0\) 的情况,要将 \(t\) 数组初始化为 \(-1\),便于判断 if(!t[p]) return。;

P3373 【模板】线段树 2

区间乘、区间加,区间求和并取模。

  • 多的乘法:新增 \(t1_p\) 代表点 \(p\) 对应区间的乘法懒标记,初始化为 \(1\)。对于一次乘法,t1[p]*=x
  • 同时也要新增一个 \(maketag2\) 代表乘法的打标记。
  • 一个点若 \(t_p\)\(t1_p\) 都有标记,规定下传顺序为先乘后加,所以就有了乘法标记对加法的贡献,t[p]*=x
void make_tag2(int p,int x){
	s[p]=1ll*s[p]*x%M;
	t[p]=1ll*t[p]*x%M; // 乘法对加法的贡献
	t1[p]=1ll*t1[p]*x%M;
}

乘法的懒标记记得初始化为 \(1\)

#include<bits/stdc++.h>
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
#define ll long long
const int maxn=1e5+10,M=571373;
using namespace std;
int N,Q,A[maxn],s[maxn<<2],t[maxn<<2],t1[maxn<<2];
void push_up(int p){
	s[p]=(s[ls(p)]+s[rs(p)])%M;
}
void build(int p,int l,int r){
	t1[p]=1;
	if(l==r) return s[p]=A[l],void();
	build(ls(p),l,mid); build(rs(p),mid+1,r);
	push_up(p); 
}
void maketag(int p,int l,int r,int x){
	s[p]=(s[p]+(r-l+1)*x)%M; t[p]=(t[p]+x)%M;
}
void maketag2(int p,int x){
	s[p]=1ll*s[p]*x%M; t[p]=1ll*t[p]*x%M; t1[p]=1ll*t1[p]*x%M;
}
void push_down(int p,int l,int r){
	if(!t[p]&&t1[p]==1) return;
	maketag2(ls(p),t1[p]); maketag2(rs(p),t1[p]);
	maketag(ls(p),l,mid,t[p]); maketag(rs(p),mid+1,r,t[p]); 
	t[p]=0; t1[p]=1; 
}
void add(int p,int l,int r,int L,int R,int x){
	if(l>R||r<L) return;
	if(l>=L&&r<=R) return maketag(p,l,r,x);
	push_down(p,l,r);
	add(ls(p),l,mid,L,R,x); add(rs(p),mid+1,r,L,R,x);
	push_up(p);
}
void mul(int p,int l,int r,int L,int R,int x){
	if(l>R||r<L) return;
	if(l>=L&&r<=R) return maketag2(p,x);
	push_down(p,l,r);
	mul(ls(p),l,mid,L,R,x); mul(rs(p),mid+1,r,L,R,x);
	push_up(p);
} 
ll ask(int p,int l,int r,int L,int R){
	if(l>R||r<L) return 0;
	if(l>=L&&r<=R) return s[p];
	push_down(p,l,r);
	return (ask(ls(p),l,mid,L,R)+ask(rs(p),mid+1,r,L,R))%M;
}
int main(){
	scanf("%d%d",&N,&Q); int x; cin>>x;
	for(int i=1;i<=N;i++){
		scanf("%d",&A[i]);
	}
	build(1,1,N);
	while(Q--){
		int op,x,y,k;
		scanf("%d%d%d",&op,&x,&y);
		if(op^3){
			scanf("%d",&k);
			op==1 ? mul(1,1,N,x,y,k) : add(1,1,N,x,y,k);
		}else{
			printf("%lld\n",ask(1,1,N,x,y));
		}	
	}
	return 0;
}

P1471 方差:

区间加,求区间平均值和区间方差。\(O((n+m)\log n)\)

\[s^2=\sum_{i=1}^{n} a_i^2- (\frac {\sum\limits_{i=1}^{n}a}{n})^2 \]

设每个数的和 \(sum1\),每个数的平方的和为 \(sum2\)。当每个数加 \(k\) 时,

\[sum2=sum2+2\times k \times sum1+n\times k^2 \]

#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int inf=0x3f3f3f3f,maxn=1e5+7;
using namespace std;
int n,m;
double t[maxn<<2],sa[maxn<<2],sb[maxn<<2]; //sa维护每一个数 sb维护每一个数的平方 
void push_up(int p){
    sa[p]=sa[ls(p)]+sa[rs(p)];
	sb[p]=sb[ls(p)]+sb[rs(p)];
}
void build(int p,int l,int r){
    if(l==r){
    	cin>>sa[p]; sb[p]=sa[p]*sa[p]; return;
    }
    build(ls(p),l,mid); build(rs(p),mid+1,r);
    push_up(p);
}
void maketag(int p,int l,int r,double k){
    sb[p]+=2*k*sa[p]+(r-l+1)*k*k; // 详见: https://www.luogu.com.cn/blog/DPair2005/solution-p1471 
	sa[p]+=(r-l+1)*k; t[p]+=k;
}
void push_down(int p,int l,int r){
    if(!t[p]) return;
    maketag(ls(p),l,mid,t[p]); maketag(rs(p),mid+1,r,t[p]);
    t[p]=0;
}

double ask_a(int p,int l,int r,int L,int R){
    if(l>R||r<L) return 0;
    if(l>=L&&r<=R) return sa[p];
    push_down(p,l,r);
    return ask_a(ls(p),l,mid,L,R)+ask_a(rs(p),mid+1,r,L,R);
}

double ask_b(int p,int l,int r,int L,int R){
	if(l>R||r<L) return 0;
    if(l>=L&&r<=R) return sb[p];
    push_down(p,l,r);
    return ask_b(ls(p),l,mid,L,R)+ask_b(rs(p),mid+1,r,L,R);
}

void update(int p,int l,int r,int L,int R,double x){
    if(r<L||R<l) return;
    if(L<=l&&r<=R){
        maketag(p,l,r,x); return;
    }
    push_down(p,l,r);
    update(ls(p),l,mid,L,R,x); update(rs(p),mid+1,r,L,R,x);
    push_up(p);
}
int main(){
    scanf("%d%d",&n,&m);
    build(1,1,n);
    while(m--){
        int op,x,y;
        scanf("%d%d%d",&op,&x,&y);
        if(op==2){ // 平均值 
            printf("%.4lf\n",ask_a(1,1,n,x,y)/(y-x+1));
        }
        else if(op==1){ // 区间加 
        	double z; cin>>z;
        	update(1,1,n,x,y,z);
		}
		else if(op==3){ // 方差  
			double sum1=ask_b(1,1,n,x,y)/(y-x+1),sum2=ask_a(1,1,n,x,y)/(y-x+1); // 方差等于每个数平方的和除以n减去平均数的平方 
			double ans=sum1-sum2*sum2;
			printf("%.4lf\n",ans);
		}
    }
    return 0;
}

P4513 小白逛公园

单点修改,求区间最大子段和。

#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
#define ls(x) (x<<1)
#define rs(x) (x<<1|1)
const int inf=0x3f3f3f3f,maxn=5e5+7;
using namespace std;
int n,m;
struct Node{
	int maxv,maxl,maxr,sumv; //最大子段和 必须包含左端点的最大子段和 必须包含右端点的最大子段和  区间和 
}s[maxn<<2]; 
void push_up(Node &p,const Node &ls,const Node &rs){
	if(ls.maxr<0&&rs.maxl<0) p.maxv=max(ls.maxr,rs.maxl);
	else{
		p.maxv=0;
		if(ls.maxr>0) p.maxv+=ls.maxr;
		if(rs.maxl>0) p.maxv+=rs.maxl;
	}
	p.maxv=max(p.maxv,ls.maxv);
	p.maxv=max(p.maxv,rs.maxv);
	p.maxl=max(ls.maxl,ls.sumv+rs.maxl);
	p.maxr=max(rs.maxr,rs.sumv+ls.maxr);
	p.sumv=ls.sumv+rs.sumv;
}
void build(int p,int l,int r){
	if(l==r){
		scanf("%d",&s[p].maxv);
		s[p].maxl=s[p].maxr=s[p].sumv=s[p].maxv;
		return;
	}
	build(ls(p),l,mid); build(rs(p),mid+1,r);
	push_up(s[p],s[ls(p)],s[rs(p)]);
}
void add(int p,int l,int r,int x,int y){
	if(l==r) {
		s[p].maxl=s[p].maxr=s[p].sumv=s[p].maxv=y;
		return;
	}
	x<=mid ? add(ls(p),l,mid,x,y) : add(rs(p),mid+1,r,x,y);
	push_up(s[p],s[ls(p)],s[rs(p)]);
}
Node ask(int ql,int qr,int rt,int l,int r){
	if(ql<=l&&r<=qr) return s[rt];
	int m=(l+r)>>1;
	if(ql<=m&&m<qr){
	  	Node res;
	  	Node ls=ask(ql,qr,ls(rt),l,m),rs=ask(ql,qr,rs(rt),m+1,r);
	  	push_up(res,ls,rs);
	  	return res;
	}
	else if(ql<=m) return ask(ql,qr,ls(rt),l,m);
	else return ask(ql,qr,rs(rt),m+1,r);
}
int main(){
    scanf("%d%d",&n,&m);
    build(1,1,n);
    while(m--){
    	int op,x,y;
    	scanf("%d%d%d",&op,&x,&y);
    	if(op==1){
    		if(x>y) swap(x,y);
			printf("%d\n",ask(x,y,1,1,n).maxv); 
		}else{
			add(1,1,n,x,y);
		}
	}
    return 0;
}

`