P2664 树上游戏

发布时间 2023-08-21 14:43:55作者: Hiostream

Problem

lrb 有一棵树,树的每个节点有个颜色。给一个长度为 \(n\) 的颜色序列,定义 \(s(i,j)\)\(i\)\(j\) 的颜色数量。以及

\[sum_i=\sum_{j=1}^n s(i, j) \]

现在他想让你求出所有的 \(sum_i\)\(1\leq n\leq 10^5\)

传送门~

这不是黑题?


Solution

FlierKing orz

Treeloveswater orz

把点分治的外壳扔了之后,要解决的是如何 \(O(n)\) 处理出所有 lca 为 \(x\) 的链对答案作出了贡献。很遗憾没能从发明人角度想出来正解是如何得出的 (这 nm 怎么想)。果然还是我太弱了吗 QAQ

所以我能做的就是再说一遍题解……

一个必须想出来的东西就是,如果我们正在遍历到 \(y\),如果 \(color_y\) 是第一次出现在 \(x-y-\cdots\) 这条链上,那么对于链底位于 \(y\) 子树内的所有链来说,这条链上的这个 color 的贡献就算是 \(y\) 的。这样当我们在计算 \(ans_z\)\(z\)\(x\) 的另一颗不是 \(y\) 所在的子树内的任意一点)的时候,如果 \(x-z\) 这条链上没出现 \(color_y\)\(y\) 就可以给 \(ans_z\) 贡献 \(size_y\) 个答案,即如果链的另一端位于 \(y\) 的子树内,\(color_y\) 就会作出 \(1\) 的贡献,共有 \(size_y\) 个满足条件的另一端。我们就通过一遍 dfs 求出贡献数组 \(col_{color_y}=\sum size_y\)

计算答案的时候,根据题解得知可以直接遍历并计算每个点的答案。当前遍历到 \(y\)。要统计的链都是形如 \(y-x-z\) 的链,可以分为 \(y-x,x-z\) 两段。我们以 \(y-x\) 段为主,也就是说,如果同一个颜色同时出现在了两段内,贡献就算是左边的。

对于 \(y-x\) 这段,单次作出的贡献就是这条链上的颜色数量 \(num\),这条链是固定不变的,并且作了 \(size_x-size_y\) 倍贡献(换言之,\(z\)\(size_x-size_y\) 种选择,因为不能和 \(y\) 在同一子树所以要减)。所以第一段做的贡献为 \(num\times(size_x-size_y)\)

对于 \(x-z\) 这段,假如不会有部分颜色的贡献被上一段抢了的话,贡献和就直接是 \(\sum col_i\) 就可以了?不对,还应该去除一下 \(y\) 所在子树内的贡献,因为 \(z\) 不能和 \(y\) 在同一子树内。现在来考虑怎么去除被抢的那部分颜色的贡献,也就是说 \(col_{\texttt{被抢的颜色}}\) 不能被计算,发现可以在遍历 \(y\) 的同时,如果 \(y\) 是一个新颜色,就从 \(sum=\sum col_i\) 中减去 \(col_{color_y}\) 就解决了。对于这部分,贡献就是 \(sum\)

细节:观察到这两条链中的 \(x\) 是重复的,防止贡献重复计算,可以让 \(color_x\) 被前面的链抢走,但是在进行从 \(col\) 中去除 \(y\) 子树的贡献时,在 \(x\) 这里去除 \(size_y\) 即可。别忘了继续分治之前把贡献数组清空。

关于 \(O(n)\) 解法:不比点分治简单,正常人可以不学。

Code

码量惊人。

#include <bits/stdc++.h>
#define ll long long
using namespace std;

const int N = 100010;
int n, sum, rt, p[N], siz[N], f[N];
ll val, num, lst, c[N], col[N], ans[N];
bool vis[N];
inline void Max(int &x, int y) {if (x < y) x = y; }
struct edge{
	int to, nxt;
}e[N << 1];
int head[N], cnt = 0;
inline void add(int u, int v) {
	e[++cnt] = (edge) {v, head[u]}, head[u] = cnt;
}

void get_rt(int x, int fa) {
	siz[x] = 1, f[x] = 0;
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (y == fa || vis[y]) continue;
		get_rt(y, x);
		siz[x] += siz[y];
		Max(f[x], siz[y]);
	}
	Max(f[x], sum - siz[x]);
	if (f[rt] > f[x]) rt = x;
}

void dfs1(int x, int fa) {
	siz[x] = 1, ++c[p[x]];
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (y == fa || vis[y]) continue;
		dfs1(y, x);
		siz[x] += siz[y];
	}
	if (c[p[x]] == 1)
		val += siz[x], col[p[x]] += siz[x];
	--c[p[x]];
}

void change(int x, int fa, int mul) {
	++c[p[x]];
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (y == fa || vis[y]) continue;
		change(y, x, mul);
	}
	if (c[p[x]] == 1) {
		ll tmp = (ll)siz[x] * mul;
		val += tmp, col[p[x]] += tmp; 
	}
	--c[p[x]];
}

void dfs2(int x, int fa) {
	if (++c[p[x]] == 1)
		val -= col[p[x]], ++num;
	ans[x] += val + num * lst;
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (y == fa || vis[y]) continue;
		dfs2(y, x);
	}
	if (c[p[x]]-- == 1)
		val += col[p[x]], --num;
}

void clear(int x, int fa) {
	c[p[x]] = col[p[x]] = 0;
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (y == fa || vis[y]) continue;
		clear(y, x);
	}
}

void calc(int x) {
	dfs1(x, 0);
	ans[x] += val - col[p[x]] + siz[x];
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (vis[y]) continue;
		++c[p[x]], val -= siz[y], col[p[x]] -= siz[y];
		change(y, x, -1);
		--c[p[x]];
		lst = siz[x] - siz[y];
		dfs2(y, x);
		++c[p[x]], val += siz[y], col[p[x]] += siz[y];
		change(y, x, 1);
		--c[p[x]];
	}
	val = num = 0;
	clear(x, 0);
}

void work(int x) {
	calc(x);
	vis[x] = true;
	for (int i=head[x]; i; i=e[i].nxt) {
		int y = e[i].to;
		if (vis[y]) continue;
		sum = siz[y], rt = 0;
		get_rt(y, x);
		work(rt);
	}
}

int main() {
	scanf("%d", &n);
	for (int i=1; i<=n; ++i) scanf("%d", &p[i]);
	for (int i=1; i<n; ++i) {
		int u, v; scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	} 
	f[0] = sum = n, rt = 0;
	get_rt(1, 0);
	get_rt(rt, 0);
	work(rt);
	for (int i=1; i<=n; ++i) printf("%lld\n", ans[i]);
	return 0;
}