应用
处理大规模的树上路径信息问题。
前提
-
将整棵树的计算分而治之——递归其子树进行计算。
-
对于以 \(rt\) 为根的子树,树上路径可分为两类:
- 经过根节点;
- 不经过根节点;
对于第一种情况,又可以分为两类:
- 根节点为路径端点;
- 根节点不为端点;
对于第二种情况可以转化成两条以根节点为端点的路径拼接而成。
-
如果树呈链状,时间复杂度会变为 \(O(n)\)。为避免此情况,我们引入重心。写一个 get_rt 函数。
找重心的代码:
点击查看代码
void get_rt(int now,int fa){
//maxp 即 max_part:以 rt 为根的树内,子树最大的 size
maxp[now]=0;sz[now]=1;
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
get_rt(to,now);
sz[now]+=sz[to];
maxp[now]=max(maxp[now],sz[to]);
}
maxp[now]=max(maxp[now],sum-sz[now]);//sum-sz[now]:另一半的 size
if(maxp[now]<maxp[rt]) rt=now;
}
基本思路
- 找该树内的重心 \(rt\);
- 求出子树内各点到 \(rt\) 的距离;
- 计算子树贡献;
对于不同的题,这一步一般不一样,这一步也是最关键的一步! - 递归子树,重复上述步骤。
其他内容不再赘述,可以自行上网搜索。
例题
P3806 【模板】点分治 1
模板题
关于 calc 函数:
因为本题是判断距离为 \(k\) 的点对是否存在,只需判断在当前这棵树的其他子树内是否存在 \(dis[k-dis[x]]\) 即可(\(x\) 为该子树内的任意一点)
点击查看代码
#include<bits/stdc++.h>
using namespace std;
int n,m,ask[110],rt,maxp[10010],sz[10010],q[100100],sum,rem[100100],dis[100100];
bool ok[110],jud[100001000],vis[10010];
struct P{
int to,val;
};
vector<P> G[10010];
void get_rt(int now,int fa){
maxp[now]=0;sz[now]=1;
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
get_rt(to,now);
sz[now]+=sz[to];
maxp[now]=max(maxp[now],sz[to]);
}
maxp[now]=max(maxp[now],sum-sz[now]);
if(maxp[now]<maxp[rt]) rt=now;
}
void get_dis(int now,int fa){
rem[++rem[0]]=dis[now];
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
dis[to]=dis[now]+nx.val;
get_dis(to,now);
}
}
void calc(int now,int w){
int p=0;
dis[now]=w;//get_dis(now,now);
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(vis[to]) continue;
rem[0]=0;dis[to]=nx.val;
get_dis(to,now);
for(int j=1;j<=rem[0];j++){
for(int k=1;k<=m;k++){
if(ask[k]>=rem[j]){
ok[k]|=jud[ask[k]-rem[j]];
}
}
}
for(int j=1;j<=rem[0];j++){
jud[rem[j]]=1;q[++p]=rem[j];
}
}
for(int i=1;i<=p;i++) jud[q[i]]=0;
}
void solve(int now){
jud[0]=vis[now]=1;
// rem[0]=0;
calc(now,0);
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(vis[to]) continue;
sum=sz[to];maxp[rt=0]=1e8;
get_rt(to,now);solve(rt);
}
}
int main(){
cin>>n>>m;
for(int i=1;i<n;i++){
int u,v,w;cin>>u>>v>>w;
G[u].push_back({v,w});
G[v].push_back({u,w});
}
for(int i=1;i<=m;i++) cin>>ask[i];
sum=maxp[rt]=n;
get_rt(1,0);
solve(rt);
for(int i=1;i<=m;i++){
if(ok[i]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
return 0;
}
P2634 [国家集训队] 聪聪可可
将 \(dis\mod 3\);
关于 calc 函数:
对于距离为 \(3\) 的倍数的点对,一定是有 两条 \(dis_0\) 或一条 \(dis_1\) 和 \(dis_2\) 拼凑而成。
所以每个子树的贡献为 \(cnt[0]\times cnt[0]+cnt[1]\times cnt[2]\times 2\)。
在计算贡献时,会存在不经过 \(rt\) 的点,所以要减去子树的贡献。(容斥)
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll ans;
int n,dis[20010],rt,sum,maxp[20010],sz[20010],cnt[3];
bool vis[20010];
struct P{
int to,val;
};
vector<P> G[20010];
void get_rt(int now,int fa){
sz[now]=1;maxp[now]=0;
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
get_rt(to,now);
sz[now]+=sz[to];
maxp[now]=max(maxp[now],sz[to]);
}
maxp[now]=max(maxp[now],sum-sz[now]);
if(maxp[now]<maxp[rt]) rt=now;
}
void get_dis(int now,int fa){
cnt[dis[now]]++;
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
dis[to]=(dis[now]+nx.val)%3;
get_dis(to,now);
}
}
ll calc(int now,int val){
dis[now]=val%3;cnt[0]=cnt[1]=cnt[2]=0;
get_dis(now,0);
return cnt[0]*cnt[0]+cnt[1]*cnt[2]*2;
}
void solve(int now){
vis[now]=1;ans+=calc(now,0);
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(vis[to]) continue;
ans-=calc(to,nx.val);
sum=sz[to];maxp[rt=0]=1e8;
get_rt(to,now);solve(rt);
}
}
int main(){
cin>>n;
for(int i=1;i<n;i++){
int x,y,w;cin>>x>>y>>w;
G[x].push_back({y,w%3});
G[y].push_back({x,w%3});
}
maxp[rt]=sum=n;
get_rt(1,0);
solve(rt);
ll g=__gcd(ans,1ll*n*n);
cout<<(ans)/g<<"/"<<(1ll*n*n)/g<<endl;
return 0;
}
//dis baoliu 0/1/2
P4178 Tree
关于 calc 函数:
将 \(dis\) 存下来,排序,双指针判断距离和是否小于等于 \(k\),每次贡献为 \(r-l\)。
需要减去未经过 \(rt\) 的贡献。(容斥)
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=4e4+10;
int n,k,rt,sum,sz[maxn],maxp[maxn],dis[maxn],rem[maxn],ans;
bool vis[maxn];
int cnt=0;
struct P{
int to,w;
};
vector<P> G[maxn];
map<int,int> mp;
void get_rt(int now,int fa){
// if(cnt==10) return;
maxp[now]=0;sz[now]=1;
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(to==fa||vis[to]) continue;
get_rt(to,now);
sz[now]+=sz[to];
maxp[now]=max(maxp[now],sz[to]);
}
maxp[now]=max(maxp[now],sum-sz[now]);
if(maxp[now]<maxp[rt]) rt=now;
}
void get_dis(int now,int fa){
// if(cnt==10) return;
rem[++rem[0]]=dis[now];
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to,w=nx.w;
if(to==fa||vis[to]) continue;
dis[to]=dis[now]+w;
get_dis(to,now);
}
}
int calc(int now,int val){
int res=0;
rem[0]=0;
dis[now]=val;get_dis(now,0);
sort(rem+1,rem+1+rem[0]);
int l=1,r=rem[0];
while(l<=r){
if(rem[l]+rem[r]<=k) res+=r-l,l++;
else r--;
}
return res;
}
void solve(int now){
vis[now]=1;ans+=calc(now,0);
for(int i=0;i<G[now].size();i++){
P nx=G[now][i];
int to=nx.to;
if(vis[to]) continue;
// if(cnt==10) return;
// cout<<now<<" "<<to<<endl;
// cnt++;
ans-=calc(to,nx.w);
sum=sz[to];maxp[rt=0]=1e9+10;
get_rt(to,now);solve(rt);
}
}
int main(){
cin>>n;
for(int i=1;i<n;i++){
int u,v,w;cin>>u>>v>>w;
G[u].push_back({v,w});
G[v].push_back({u,w});
}
cin>>k;
sum=maxp[0]=n;
get_rt(1,0);solve(rt);
cout<<ans<<endl;
return 0;
}