题解 - Luogu P3676 小清新数据结构题

发布时间 2023-06-01 22:02:45作者: lhzawa

点分树是什么/yiw

定义 \(s_i\)\(i\) 子树内的权值和,默认 \(1\) 为根

首先考虑没有换根的解法
考虑把点权变换转化为加上一个数,即 \(val_{x}\leftarrow y\) 转化为 \(val_{x}\leftarrow val_{x} + (y - val_{x})\)
定义这个加上的数为 \(z\),考虑加上 \(z\) 对答案有什么影响
不难发现只会对 \(i \in \operatorname{path}(1, x)\) 上的 \(s_i\leftarrow s_i + z\)
所以多的贡献即为:
\(\sum\limits_{i = {path}(1, x)} (s_{i} + z)^2 - \sum\limits_{i = {path}(1, x)} s_{i}^2 = \sum\limits_{i = {path}(1, x)} (s_i^2 + 2s_i z + z^2) - \sum\limits_{i = {path}(1, x)} s_{i}^2 = \sum\limits_{i = {path}(1, x)} (2s_i z + z^2) = \operatorname{len}(\operatorname{path}(1, x)) z^2 + 2z\sum\limits_{i = {path}(1, x)} s_i\)

这个式子很明显树剖就行了
答案的初始值即为 \(\sum\limits_{i = 1}^n s_i^2\),每次修改加上对应的贡献即可

考虑有换根怎么做
则不难发现会改变的还是 \(i \in \operatorname{path}(1, x)\) 上的点
\(k = \operatorname{len}(\operatorname{path}(1, x))\)\(a_i,b_i\)\(\operatorname{path}(1\to x)\) 上的 \(k\) 个点依次排列,其分别以 \(1,x\) 为根的 \(s\) 的值

然后能发现一个性质,\(a_{i + 1} + b_{i} = s_1 = a_1 = b_k\),因为在树上这两部分正好拼成了整个树,而 \(a_1,b_k\) 都是整个树
然后算贡献:
\(-\sum\limits_{i = 1}^k a_i^2 + \sum\limits_{i = 1}^k b_i^2 = -a_1^2 - \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 1}^{k - 1} (s_1 - a_{i + 1})^2 + b_k^2 = \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 2}^{k} (s_1 - a_i)^2 = \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 2}^{k} (s_1^2 - 2s_1a_i + a_i^2) = \sum\limits_{i = 2}^k (s_1^2 - 2s_1a_i) = (k - 1)s1^2 - 2s_1\sum\limits_{i = 2}^k a_i = (k + 1)s1^2 - 2s_1\sum\limits_{i = 1}^k a_i\)

所以 \(ans\) 加上这部分贡献就是答案啦

// lhzawa(https://www.cnblogs.com/lhzawa/)
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
vector<int> ev[N];
void add(int u, int v) {
	ev[u].push_back(v);
	return ;
}
int dep[N], siz[N], son[N], fa[N];
long long val[N];
long long s[N];
void dfsinit(int u, int f, long long &ans) {
	siz[u] = 1, fa[u] = f;
	s[u] = val[u];
	for (unsigned int i = 0; i < ev[u].size(); i++) {
		int v = ev[u][i];
		if (v == fa[u]) {
			continue;
		}
		dfsinit(v, u, ans);
		siz[u] += siz[v], s[u] += s[v], son[u] = siz[v] > siz[son[u]] ? v : son[u]; 
	}
	ans += s[u] * s[u];
	return ;
}
int dfn[N], top[N], dfnt;
long long dval[N];
void dfsdfn(int u, int hd) {
	top[u] = hd, dfn[u] = ++dfnt, dval[dfnt] = s[u];
	if (son[u]) {
		dfsdfn(son[u], hd);
	}
	for (unsigned int i = 0; i < ev[u].size(); i++) {
		int v = ev[u][i];
		if (v == fa[u] || v == son[u]) {
			continue;
		}
		dfsdfn(v, v);
	}
	return ;
}
struct segnode{
	int l, r, ln;
	long long w, ly;
};
segnode tr[N * 4];
void pushup(int k) {
	tr[k].w = tr[k << 1].w + tr[k << 1 | 1].w;
	return ;
}
void pushdown(int k) {
	if (tr[k].ly) {
		tr[k << 1].w += tr[k].ly * tr[k << 1].ln, tr[k << 1].ly += tr[k].ly;
		tr[k << 1 | 1].w += tr[k].ly * tr[k << 1 | 1].ln, tr[k << 1 | 1].ly += tr[k].ly;
		tr[k].ly = 0;
	}
	return ; 
}
void build(int k, int l, int r) {
	tr[k] = {l, r, r - l + 1, 0, 0};
	if (l == r) {
		tr[k].w = dval[l];
		return ;
	}
	int mid = (l + r) >> 1;
	build(k << 1, l, mid), build(k << 1 | 1, mid + 1, r);
	pushup(k);
	return ;
}
void update(int k, int l, int r, long long x) {
	if (r < tr[k].l || tr[k].r < l) {
		return ;
	}
	if (l <= tr[k].l && tr[k].r <= r) {
		tr[k].w += x * tr[k].ln, tr[k].ly += x;
		return ;
	}
	pushdown(k);
	update(k << 1, l, r, x), update(k << 1 | 1, l, r, x);
	pushup(k);
	return ;
}
long long query(int k, int l, int r) {
	if (r < tr[k].l || tr[k].r < l) {
		return 0;
	}
	if (l <= tr[k].l && tr[k].r <= r) {
		return tr[k].w;
	}
	pushdown(k);
	return query(k << 1, l, r) + query(k << 1 | 1, l, r);
}
int main() {
	int n, q;
	scanf("%d%d", &n, &q);
	for (int i = 1; i < n; i++) {
		int u, v;
		scanf("%d%d", &u, &v);
		add(u, v), add(v, u);
	}
	for (int i = 1; i <= n; i++) {
		scanf("%lld", &val[i]);
	}
	long long ans = 0;
	dfsinit(1, 0, ans), dfsdfn(1, 1), build(1, 1, n);
	function<void (int, long long)> add = [&ans](int x, long long y) -> void {
		int k = 0;
		long long h = 0;
		for (int i = x; i; i = fa[top[i]]) {
			k += dfn[i] - dfn[top[i]] + 1;
			h += query(1, dfn[top[i]], dfn[i]);
			update(1, dfn[top[i]], dfn[i], y);
		}
		ans += y * y * k + 2 * y * h;
		return ;
	};
	function<long long (int)> qry = [](int x) -> long long {
		int k = 0;
		long long h = 0, s1 = query(1, 1, 1);
		for (int i = x; i; i = fa[top[i]]) {
			k += dfn[i] - dfn[top[i]] + 1;
			h += query(1, dfn[top[i]], dfn[i]);
		}
		return (k + 1) * s1 * s1 - 2 * s1 * h;
	};
	for (; q; q--) {
		int opt;
		scanf("%d", &opt);
		switch (opt) {
			case 1: {
				int x;
				long long y;
				scanf("%d%lld", &x, &y);
				add(x, y - val[x]), val[x] = y;
				break;
			}
			case 2: {
				int x;
				scanf("%d", &x);
				printf("%lld\n", ans + qry(x));
				break;
			}
			default: {
				break;
			}
		}
	}
	return 0;
}