点分治

发布时间 2023-10-10 15:00:28作者: _Libra

应用

处理大规模的树上路径信息问题。

前提

  • 将整棵树的计算分而治之——递归其子树进行计算。

  • 对于以 \(rt\) 为根的子树,树上路径可分为两类:

    1. 经过根节点;
    2. 不经过根节点;

    对于第一种情况,又可以分为两类:

    1. 根节点为路径端点;
    2. 根节点不为端点;
      对于第二种情况可以转化成两条以根节点为端点的路径拼接而成。
  • 如果树呈链状,时间复杂度会变为 \(O(n)\)。为避免此情况,我们引入重心。写一个 get_rt 函数。

找重心的代码:

点击查看代码

void get_rt(int now,int fa){
  //maxp 即 max_part:以 rt 为根的树内,子树最大的 size
	maxp[now]=0;sz[now]=1;
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		sz[now]+=sz[to];
		maxp[now]=max(maxp[now],sz[to]);
	}
	maxp[now]=max(maxp[now],sum-sz[now]);//sum-sz[now]:另一半的 size
	if(maxp[now]<maxp[rt]) rt=now;

}

基本思路

  • 找该树内的重心 \(rt\)
  • 求出子树内各点到 \(rt\) 的距离;
  • 计算子树贡献;
    对于不同的题,这一步一般不一样,这一步也是最关键的一步!
  • 递归子树,重复上述步骤。

其他内容不再赘述,可以自行上网搜索。

例题

P3806 【模板】点分治 1

模板题

关于 calc 函数:
因为本题是判断距离为 \(k\) 的点对是否存在,只需判断在当前这棵树的其他子树内是否存在 \(dis[k-dis[x]]\) 即可(\(x\) 为该子树内的任意一点)

点击查看代码

#include<bits/stdc++.h>
using namespace std;

int n,m,ask[110],rt,maxp[10010],sz[10010],q[100100],sum,rem[100100],dis[100100];
bool ok[110],jud[100001000],vis[10010];

struct P{
	int to,val;
};

vector<P> G[10010];

void get_rt(int now,int fa){
	maxp[now]=0;sz[now]=1;
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		sz[now]+=sz[to];
		maxp[now]=max(maxp[now],sz[to]);
	}
	maxp[now]=max(maxp[now],sum-sz[now]);
	if(maxp[now]<maxp[rt]) rt=now;
}

void get_dis(int now,int fa){
	rem[++rem[0]]=dis[now];
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		dis[to]=dis[now]+nx.val;
		get_dis(to,now);
	}
}

void calc(int now,int w){
	int p=0;
	dis[now]=w;//get_dis(now,now);
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(vis[to]) continue;
		rem[0]=0;dis[to]=nx.val;
		get_dis(to,now);
		for(int j=1;j<=rem[0];j++){
			for(int k=1;k<=m;k++){
				if(ask[k]>=rem[j]){
					ok[k]|=jud[ask[k]-rem[j]];
				}
			}
		}
		for(int j=1;j<=rem[0];j++){
			jud[rem[j]]=1;q[++p]=rem[j];
		}
	}
	for(int i=1;i<=p;i++) jud[q[i]]=0;
}

void solve(int now){
	jud[0]=vis[now]=1;
//	rem[0]=0;
	calc(now,0);
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(vis[to]) continue;
		sum=sz[to];maxp[rt=0]=1e8;
		get_rt(to,now);solve(rt);
	}
}

int main(){
	cin>>n>>m;
	for(int i=1;i<n;i++){
		int u,v,w;cin>>u>>v>>w;
		G[u].push_back({v,w});
		G[v].push_back({u,w});
	}
	for(int i=1;i<=m;i++) cin>>ask[i];
	sum=maxp[rt]=n;
	get_rt(1,0);
	solve(rt);
	for(int i=1;i<=m;i++){
		if(ok[i]) cout<<"AYE"<<endl;
		else cout<<"NAY"<<endl;
	}
	
	return 0;
}

P2634 [国家集训队] 聪聪可可

\(dis\mod 3\);

关于 calc 函数:
对于距离为 \(3\) 的倍数的点对,一定是有 两条 \(dis_0\) 或一条 \(dis_1\)\(dis_2\) 拼凑而成。
所以每个子树的贡献为 \(cnt[0]\times cnt[0]+cnt[1]\times cnt[2]\times 2\)
在计算贡献时,会存在不经过 \(rt\) 的点,所以要减去子树的贡献。(容斥)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long

ll ans;
int n,dis[20010],rt,sum,maxp[20010],sz[20010],cnt[3];
bool vis[20010];

struct P{
	int to,val;
};

vector<P> G[20010];

void get_rt(int now,int fa){
	sz[now]=1;maxp[now]=0;
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		sz[now]+=sz[to];
		maxp[now]=max(maxp[now],sz[to]);
	}
	maxp[now]=max(maxp[now],sum-sz[now]);
	if(maxp[now]<maxp[rt]) rt=now;
}

void get_dis(int now,int fa){
	cnt[dis[now]]++;
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		dis[to]=(dis[now]+nx.val)%3;
		get_dis(to,now);
	}
}

ll calc(int now,int val){
	dis[now]=val%3;cnt[0]=cnt[1]=cnt[2]=0;
	get_dis(now,0);
	return cnt[0]*cnt[0]+cnt[1]*cnt[2]*2;
}

void solve(int now){
	vis[now]=1;ans+=calc(now,0);
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(vis[to]) continue;
		ans-=calc(to,nx.val);
		sum=sz[to];maxp[rt=0]=1e8;
		get_rt(to,now);solve(rt);
	}
}

int main(){
	cin>>n;
	for(int i=1;i<n;i++){
		int x,y,w;cin>>x>>y>>w;
		G[x].push_back({y,w%3});
		G[y].push_back({x,w%3});
	}
	maxp[rt]=sum=n;
	get_rt(1,0);
	solve(rt);
	ll g=__gcd(ans,1ll*n*n);
	cout<<(ans)/g<<"/"<<(1ll*n*n)/g<<endl;
	
	return 0;
}
//dis baoliu 0/1/2

P4178 Tree

关于 calc 函数:
\(dis\) 存下来,排序,双指针判断距离和是否小于等于 \(k\),每次贡献为 \(r-l\)
需要减去未经过 \(rt\) 的贡献。(容斥)

点击查看代码
#include<bits/stdc++.h>
using namespace std;


const int maxn=4e4+10;
int n,k,rt,sum,sz[maxn],maxp[maxn],dis[maxn],rem[maxn],ans;
bool vis[maxn];
int cnt=0;

struct P{
	int to,w;
};

vector<P> G[maxn];
map<int,int> mp;

void get_rt(int now,int fa){
//	if(cnt==10) return;
	maxp[now]=0;sz[now]=1;
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(to==fa||vis[to]) continue;
		get_rt(to,now);
		sz[now]+=sz[to];
		maxp[now]=max(maxp[now],sz[to]);
	}
	maxp[now]=max(maxp[now],sum-sz[now]);
	if(maxp[now]<maxp[rt]) rt=now;
}

void get_dis(int now,int fa){
//	if(cnt==10) return;
	rem[++rem[0]]=dis[now];
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to,w=nx.w;
		if(to==fa||vis[to]) continue;
		dis[to]=dis[now]+w;
		get_dis(to,now);
	}
}

int calc(int now,int val){
	int res=0;
	rem[0]=0;
	dis[now]=val;get_dis(now,0);
	sort(rem+1,rem+1+rem[0]);
	int l=1,r=rem[0];
	while(l<=r){
		if(rem[l]+rem[r]<=k) res+=r-l,l++;
		else r--;
	}
	return res;
}

void solve(int now){
	vis[now]=1;ans+=calc(now,0);
	for(int i=0;i<G[now].size();i++){
		P nx=G[now][i];
		int to=nx.to;
		if(vis[to]) continue;
	//	if(cnt==10) return;
	//	cout<<now<<" "<<to<<endl;
	//	cnt++;
		ans-=calc(to,nx.w);
		sum=sz[to];maxp[rt=0]=1e9+10;
		get_rt(to,now);solve(rt);
	}
}

int main(){
	cin>>n;
	for(int i=1;i<n;i++){
		int u,v,w;cin>>u>>v>>w;
		G[u].push_back({v,w});
		G[v].push_back({u,w});
	}
	cin>>k;
	sum=maxp[0]=n;
	get_rt(1,0);solve(rt);
	cout<<ans<<endl;
	
	return 0;
}