点分树是什么/yiw
定义 \(s_i\) 为 \(i\) 子树内的权值和,默认 \(1\) 为根
首先考虑没有换根的解法
考虑把点权变换转化为加上一个数,即 \(val_{x}\leftarrow y\) 转化为 \(val_{x}\leftarrow val_{x} + (y - val_{x})\)
定义这个加上的数为 \(z\),考虑加上 \(z\) 对答案有什么影响
不难发现只会对 \(i \in \operatorname{path}(1, x)\) 上的 \(s_i\leftarrow s_i + z\)
所以多的贡献即为:
\(\sum\limits_{i = {path}(1, x)} (s_{i} + z)^2 - \sum\limits_{i = {path}(1, x)} s_{i}^2 = \sum\limits_{i = {path}(1, x)} (s_i^2 + 2s_i z + z^2) - \sum\limits_{i = {path}(1, x)} s_{i}^2 = \sum\limits_{i = {path}(1, x)} (2s_i z + z^2) = \operatorname{len}(\operatorname{path}(1, x)) z^2 + 2z\sum\limits_{i = {path}(1, x)} s_i\)
这个式子很明显树剖就行了
答案的初始值即为 \(\sum\limits_{i = 1}^n s_i^2\),每次修改加上对应的贡献即可
考虑有换根怎么做
则不难发现会改变的还是 \(i \in \operatorname{path}(1, x)\) 上的点
令 \(k = \operatorname{len}(\operatorname{path}(1, x))\),\(a_i,b_i\) 为 \(\operatorname{path}(1\to x)\) 上的 \(k\) 个点依次排列,其分别以 \(1,x\) 为根的 \(s\) 的值
然后能发现一个性质,\(a_{i + 1} + b_{i} = s_1 = a_1 = b_k\),因为在树上这两部分正好拼成了整个树,而 \(a_1,b_k\) 都是整个树
然后算贡献:
\(-\sum\limits_{i = 1}^k a_i^2 + \sum\limits_{i = 1}^k b_i^2 = -a_1^2 - \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 1}^{k - 1} (s_1 - a_{i + 1})^2 + b_k^2 = \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 2}^{k} (s_1 - a_i)^2 = \sum\limits_{i = 2}^k a_i^2 + \sum\limits_{i = 2}^{k} (s_1^2 - 2s_1a_i + a_i^2) = \sum\limits_{i = 2}^k (s_1^2 - 2s_1a_i) = (k - 1)s1^2 - 2s_1\sum\limits_{i = 2}^k a_i = (k + 1)s1^2 - 2s_1\sum\limits_{i = 1}^k a_i\)
所以 \(ans\) 加上这部分贡献就是答案啦
// lhzawa(https://www.cnblogs.com/lhzawa/)
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5 + 10;
vector<int> ev[N];
void add(int u, int v) {
ev[u].push_back(v);
return ;
}
int dep[N], siz[N], son[N], fa[N];
long long val[N];
long long s[N];
void dfsinit(int u, int f, long long &ans) {
siz[u] = 1, fa[u] = f;
s[u] = val[u];
for (unsigned int i = 0; i < ev[u].size(); i++) {
int v = ev[u][i];
if (v == fa[u]) {
continue;
}
dfsinit(v, u, ans);
siz[u] += siz[v], s[u] += s[v], son[u] = siz[v] > siz[son[u]] ? v : son[u];
}
ans += s[u] * s[u];
return ;
}
int dfn[N], top[N], dfnt;
long long dval[N];
void dfsdfn(int u, int hd) {
top[u] = hd, dfn[u] = ++dfnt, dval[dfnt] = s[u];
if (son[u]) {
dfsdfn(son[u], hd);
}
for (unsigned int i = 0; i < ev[u].size(); i++) {
int v = ev[u][i];
if (v == fa[u] || v == son[u]) {
continue;
}
dfsdfn(v, v);
}
return ;
}
struct segnode{
int l, r, ln;
long long w, ly;
};
segnode tr[N * 4];
void pushup(int k) {
tr[k].w = tr[k << 1].w + tr[k << 1 | 1].w;
return ;
}
void pushdown(int k) {
if (tr[k].ly) {
tr[k << 1].w += tr[k].ly * tr[k << 1].ln, tr[k << 1].ly += tr[k].ly;
tr[k << 1 | 1].w += tr[k].ly * tr[k << 1 | 1].ln, tr[k << 1 | 1].ly += tr[k].ly;
tr[k].ly = 0;
}
return ;
}
void build(int k, int l, int r) {
tr[k] = {l, r, r - l + 1, 0, 0};
if (l == r) {
tr[k].w = dval[l];
return ;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid), build(k << 1 | 1, mid + 1, r);
pushup(k);
return ;
}
void update(int k, int l, int r, long long x) {
if (r < tr[k].l || tr[k].r < l) {
return ;
}
if (l <= tr[k].l && tr[k].r <= r) {
tr[k].w += x * tr[k].ln, tr[k].ly += x;
return ;
}
pushdown(k);
update(k << 1, l, r, x), update(k << 1 | 1, l, r, x);
pushup(k);
return ;
}
long long query(int k, int l, int r) {
if (r < tr[k].l || tr[k].r < l) {
return 0;
}
if (l <= tr[k].l && tr[k].r <= r) {
return tr[k].w;
}
pushdown(k);
return query(k << 1, l, r) + query(k << 1 | 1, l, r);
}
int main() {
int n, q;
scanf("%d%d", &n, &q);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v), add(v, u);
}
for (int i = 1; i <= n; i++) {
scanf("%lld", &val[i]);
}
long long ans = 0;
dfsinit(1, 0, ans), dfsdfn(1, 1), build(1, 1, n);
function<void (int, long long)> add = [&ans](int x, long long y) -> void {
int k = 0;
long long h = 0;
for (int i = x; i; i = fa[top[i]]) {
k += dfn[i] - dfn[top[i]] + 1;
h += query(1, dfn[top[i]], dfn[i]);
update(1, dfn[top[i]], dfn[i], y);
}
ans += y * y * k + 2 * y * h;
return ;
};
function<long long (int)> qry = [](int x) -> long long {
int k = 0;
long long h = 0, s1 = query(1, 1, 1);
for (int i = x; i; i = fa[top[i]]) {
k += dfn[i] - dfn[top[i]] + 1;
h += query(1, dfn[top[i]], dfn[i]);
}
return (k + 1) * s1 * s1 - 2 * s1 * h;
};
for (; q; q--) {
int opt;
scanf("%d", &opt);
switch (opt) {
case 1: {
int x;
long long y;
scanf("%d%lld", &x, &y);
add(x, y - val[x]), val[x] = y;
break;
}
case 2: {
int x;
scanf("%d", &x);
printf("%lld\n", ans + qry(x));
break;
}
default: {
break;
}
}
}
return 0;
}