Problem
lrb 有一棵树,树的每个节点有个颜色。给一个长度为 \(n\) 的颜色序列,定义 \(s(i,j)\) 为 \(i\) 到 \(j\) 的颜色数量。以及
现在他想让你求出所有的 \(sum_i\)。\(1\leq n\leq 10^5\)。
这不是黑题?
Solution
FlierKing orz
Treeloveswater orz
把点分治的外壳扔了之后,要解决的是如何 \(O(n)\) 处理出所有 lca 为 \(x\) 的链对答案作出了贡献。很遗憾没能从发明人角度想出来正解是如何得出的 (这 nm 怎么想)。果然还是我太弱了吗 QAQ
所以我能做的就是再说一遍题解……
一个必须想出来的东西就是,如果我们正在遍历到 \(y\),如果 \(color_y\) 是第一次出现在 \(x-y-\cdots\) 这条链上,那么对于链底位于 \(y\) 子树内的所有链来说,这条链上的这个 color 的贡献就算是 \(y\) 的。这样当我们在计算 \(ans_z\)(\(z\) 是 \(x\) 的另一颗不是 \(y\) 所在的子树内的任意一点)的时候,如果 \(x-z\) 这条链上没出现 \(color_y\),\(y\) 就可以给 \(ans_z\) 贡献 \(size_y\) 个答案,即如果链的另一端位于 \(y\) 的子树内,\(color_y\) 就会作出 \(1\) 的贡献,共有 \(size_y\) 个满足条件的另一端。我们就通过一遍 dfs 求出贡献数组 \(col_{color_y}=\sum size_y\)。
计算答案的时候,根据题解得知可以直接遍历并计算每个点的答案。当前遍历到 \(y\)。要统计的链都是形如 \(y-x-z\) 的链,可以分为 \(y-x,x-z\) 两段。我们以 \(y-x\) 段为主,也就是说,如果同一个颜色同时出现在了两段内,贡献就算是左边的。
对于 \(y-x\) 这段,单次作出的贡献就是这条链上的颜色数量 \(num\),这条链是固定不变的,并且作了 \(size_x-size_y\) 倍贡献(换言之,\(z\) 有 \(size_x-size_y\) 种选择,因为不能和 \(y\) 在同一子树所以要减)。所以第一段做的贡献为 \(num\times(size_x-size_y)\)。
对于 \(x-z\) 这段,假如不会有部分颜色的贡献被上一段抢了的话,贡献和就直接是 \(\sum col_i\) 就可以了?不对,还应该去除一下 \(y\) 所在子树内的贡献,因为 \(z\) 不能和 \(y\) 在同一子树内。现在来考虑怎么去除被抢的那部分颜色的贡献,也就是说 \(col_{\texttt{被抢的颜色}}\) 不能被计算,发现可以在遍历 \(y\) 的同时,如果 \(y\) 是一个新颜色,就从 \(sum=\sum col_i\) 中减去 \(col_{color_y}\) 就解决了。对于这部分,贡献就是 \(sum\)。
细节:观察到这两条链中的 \(x\) 是重复的,防止贡献重复计算,可以让 \(color_x\) 被前面的链抢走,但是在进行从 \(col\) 中去除 \(y\) 子树的贡献时,在 \(x\) 这里去除 \(size_y\) 即可。别忘了继续分治之前把贡献数组清空。
关于 \(O(n)\) 解法:不比点分治简单,正常人可以不学。
Code
码量惊人。
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 100010;
int n, sum, rt, p[N], siz[N], f[N];
ll val, num, lst, c[N], col[N], ans[N];
bool vis[N];
inline void Max(int &x, int y) {if (x < y) x = y; }
struct edge{
int to, nxt;
}e[N << 1];
int head[N], cnt = 0;
inline void add(int u, int v) {
e[++cnt] = (edge) {v, head[u]}, head[u] = cnt;
}
void get_rt(int x, int fa) {
siz[x] = 1, f[x] = 0;
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
get_rt(y, x);
siz[x] += siz[y];
Max(f[x], siz[y]);
}
Max(f[x], sum - siz[x]);
if (f[rt] > f[x]) rt = x;
}
void dfs1(int x, int fa) {
siz[x] = 1, ++c[p[x]];
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
dfs1(y, x);
siz[x] += siz[y];
}
if (c[p[x]] == 1)
val += siz[x], col[p[x]] += siz[x];
--c[p[x]];
}
void change(int x, int fa, int mul) {
++c[p[x]];
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
change(y, x, mul);
}
if (c[p[x]] == 1) {
ll tmp = (ll)siz[x] * mul;
val += tmp, col[p[x]] += tmp;
}
--c[p[x]];
}
void dfs2(int x, int fa) {
if (++c[p[x]] == 1)
val -= col[p[x]], ++num;
ans[x] += val + num * lst;
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
dfs2(y, x);
}
if (c[p[x]]-- == 1)
val += col[p[x]], --num;
}
void clear(int x, int fa) {
c[p[x]] = col[p[x]] = 0;
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (y == fa || vis[y]) continue;
clear(y, x);
}
}
void calc(int x) {
dfs1(x, 0);
ans[x] += val - col[p[x]] + siz[x];
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (vis[y]) continue;
++c[p[x]], val -= siz[y], col[p[x]] -= siz[y];
change(y, x, -1);
--c[p[x]];
lst = siz[x] - siz[y];
dfs2(y, x);
++c[p[x]], val += siz[y], col[p[x]] += siz[y];
change(y, x, 1);
--c[p[x]];
}
val = num = 0;
clear(x, 0);
}
void work(int x) {
calc(x);
vis[x] = true;
for (int i=head[x]; i; i=e[i].nxt) {
int y = e[i].to;
if (vis[y]) continue;
sum = siz[y], rt = 0;
get_rt(y, x);
work(rt);
}
}
int main() {
scanf("%d", &n);
for (int i=1; i<=n; ++i) scanf("%d", &p[i]);
for (int i=1; i<n; ++i) {
int u, v; scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
f[0] = sum = n, rt = 0;
get_rt(1, 0);
get_rt(rt, 0);
work(rt);
for (int i=1; i<=n; ++i) printf("%lld\n", ans[i]);
return 0;
}