树上启发式合并学习笔记

发布时间 2023-04-06 13:06:18作者: xiezheyuan

前言

树上启发式合并(DSU on tree),是一种启发式算法,多用于解决子树询问问题。

和莫队很像,只要支持在 \(O(T(n))\) 加入 / 删除一个点对答案的贡献,就可以在 \(O(n\log n \cdot T(n))\) 内求出所有节点的子树的答案。

流程

例题

经典例题——树上数颜色:

给出一个以 \(1\) 为根的 \(n\) 个节点有根树。每一个节点 \(i\) 有一个颜色 \(c_i\)。有 \(q\) 个询问,每个询问给出一个 \(a\),你需要求出以 \(a\) 为根的子树中的颜色种类数。

\(1 \leq n,c_i \leq 10^5,1\leq q\leq 10^6,1 \leq a \leq n\)。空间限制 \(20\operatorname{MB}\)

这道题看起来似乎需要复杂的数据结构(如树套树或线段树合并,不过空间复杂度可能会爆炸)。但是也可以使用树上启发式合并。

从暴力开始

首先有一个朴素的暴力。对于每一个询问去暴力找。时间复杂度 \(O(nq)\),显然无法承受。

然后考虑询问最多只有 \(n\) 种,因此可以预先求出来,时间复杂度 \(O(n^2)\)。显然无法承受。

然后引入这道题关键性质——支持 \(O(1)\) 加入 / 删除一个点对答案的贡献。

于是就有一个时间复杂度还是 \(O(n^2)\) 的算法,考虑对于每一个节点加入这个子树的贡献,最后删除贡献。然后就可以方便的递归求解了!

代码:

void add(int i){// 加入节点 i 的贡献
    if(!(bkt[c[i]]++)) tans++;
}

void del(int i){// 删除节点 i 的贡献
    if(!(--bkt[c[i]])) tans--;
}

void addtree(int u,int fa){// 加入以 i 为根的子树的贡献
    add(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) addtree(v,u);
    }
}

void deltree(int u,int fa){// 删除以 i 为根的子树的贡献
    del(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) deltree(v,u);
    }
}

void solve(int u,int fa){// 求解答案
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) solve(v,u);// 递归
    }
    addtree(u,fa);// 加入以 u 为根的贡献
    ans[u]=tans;// 保存答案
    deltree(u,fa);// 删除以 u 为根的贡献,避免由于忘记清空而出现答案紊乱
}

一点小优化

首先我们考虑一下哪些是可以不用清空的。

对对对,就是每一个节点的最后一个子节点所组成的子树。因为这个点反正会被一起清空,不用单独清空了。因此我们可以单独考虑最后一个。

然后考虑最后一个节点怎么选。显然如果子树越大,不用清空的节点越多。因此我们可以把子树最大的子节点(重子节点)当成最后一个节点。

看起来时间复杂度还是 \(O(n^2)\)。真的是这样的吗?

复杂度证明

如果不关心时间复杂度证明,可以跳过这一节。

我们将父节点连向重子节点的边称为重边,将连向非重子节点(即轻子节点)称为轻边。

引理 1:根节点到树上任意节点的轻边数不超过 \(\log n\) 条。

Proof:令根节点到节点 \(i\) 的轻边数为 \(a\)\(i\) 的子树大小为 \(b\)

每过一条轻边,子树大小都将至少减半(如果不减半必定连接的不是轻子节点)。于是 \(b<\dfrac{n}{2^a}\)

又因为 \(b\geq1\),所以 \(n\leq 2^{a}\),所以 \(a\leq \log_{2}n\)。证毕。

推论 1:一个节点更新答案的次数,等于根节点到其的轻边个数 + 1

Proof:每经过一条轻边,都会清空一次答案,再加上自己本身。

于是得出上面的算法复杂度为 \(O(n(\log n +1))=O(n\log n)\)。时间复杂度得到了本质上的提升。

代码实现

void dfs(int u,int fa){// 找每个节点的重子节点
    siz[u]=1;int tmp=0;
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v==fa) continue;
        dfs(v,u);
        siz[u]+=siz[v];
        if(siz[v]>tmp) son[u]=v,tmp=siz[v];
    }
}

void add(int i){// 加入节点 i 的贡献
    if(!(bkt[c[i]]++)) tans++;
}

void del(int i){// 删除节点 i 的贡献
    if(!(--bkt[c[i]])) tans--;
}

void addtree(int u,int fa){// 加入以 i 为根的子树的贡献
    add(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) addtree(v,u);
    }
}

void deltree(int u,int fa){// 删除以 i 为根的子树的贡献
    del(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) deltree(v,u);
    }
}

void solve(int u,int fa,bool flag){// 主程序 flag 表示要不要清空
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa && son[u]!=v) solve(v,u,1); // 如果不是重子节点,直接进去,记得需要清空。
    }
    if(son[u]) solve(son[u],u,0); // 如果有重子节点,就进去,记得不要清空。
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa && son[u]!=v) addtree(v,u); // 添加轻子节点的贡献,重子节点的贡献在上面搞过了,且没有被清空。
    }
    add(u); // 添加这个节点本身的贡献
    ans[u]=tans; // 保存答案
    if(flag) deltree(u,fa); // 如果要清空你就清空
}

例题

CF600E Lomsat gelral

有一棵 \(n\) 个结点的以 \(1\) 号结点为根的有根树

每个结点都有一个颜色,颜色是以编号表示的, \(i\) 号结点的颜色编号为 \(c_i\)

如果一种颜色在以 \(x\) 为根的子树内出现次数最多,称其在以 \(x\) 为根的子树中占主导地位。显然,同一子树中可能有多种颜色占主导地位。

你的任务是对于每一个 \(i\in[1,n]\),求出以 \(i\) 为根的子树中,占主导地位的颜色的编号和。

\(1 \leq c_i \leq n \leq 10^5\)

模板题。只要改一改如何加入和删除即可。注意每一次清空子树时需要连带临时变量一起清空,否则会出现重复贡献影响答案。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 1e5+5;

struct edge{
    int nxt,to;
} g[N<<1];

int head[N],ec,ans[N],tans,c[N],bkt[N],n,siz[N],son[N];

void add(int u,int v){
    g[++ec].nxt=head[u];
    g[ec].to=v;
    head[u]=ec;
}

void dfs(int u,int fa){
    siz[u]=1;int tmp=0;
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v==fa) continue;
        dfs(v,u);
        siz[u]+=siz[v];
        if(siz[v]>tmp) son[u]=v,tmp=siz[v];
    }
}

int flg;

void add(int i){
    if((++bkt[c[i]])>flg){
    	flg=bkt[c[i]];tans=c[i];
	}
	else if(bkt[c[i]]==flg) tans+=c[i];
}

void del(int i){
    --bkt[c[i]];
}

void addtree(int u,int fa){
    add(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) addtree(v,u);
    }
}

void deltree(int u,int fa){
    del(u);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa) deltree(v,u);
    }
}

void solve(int u,int fa,bool flag){
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa && son[u]!=v) solve(v,u,1);
    }
    if(son[u]) solve(son[u],u,0);
    for(int i=head[u];i;i=g[i].nxt){
        int v=g[i].to;
        if(v!=fa && son[u]!=v) addtree(v,u);
    }
    add(u);
    ans[u]=tans;
    if(flag){
    	deltree(u,fa);
    	tans=flg=0;
	}
}

signed main(){
    cin>>n;
    for(int i=1;i<=n;i++) cin>>c[i];
    for(int i=1,u,v;i<n;i++){
        cin>>u>>v;
        add(u,v);add(v,u);
    }
    dfs(1,0);
    solve(1,0,0);
    for(int i=1;i<=n;i++) cout<<ans[i]<<' ';
	return 0; 
}