2023湖南省赛 E.ytree (线段树)

发布时间 2023-09-24 15:12:45作者: qdhys

传送门

大致思路:

1. 将操作1拆分为两个部分x(-1)^d + kd*(-1)d。对于操作1中的x*(-1)d部分而言。我们可以对式子进行拆分,把x拆出来,我们会发现和v号点距离为奇数的点会减去x,为偶数的点会加上x,所以我们可以在线段树上用一个sum1维护应该减去的值,sum2维护加上的值即可。

2. 随即就是如何维护线段树不同结点之间的sum1和sum2了。我们将整棵树按照dfs序建树,如此一来一颗子树的dfs序是会一段连续的区间,我们在线段树上维护结点的深度最小值mn,当我们将父节点fa上的标记下传到子节点son的时候就可以根据父节点和子节点的最小深度差来下传标记,如果son.mn - fa.mn是奇数,那么son这个结点应该加上的其实是在fa减去的x的值总和,应该减去的其实是在fa加上的x的值的总和,所以就应该这样更新son.sum1 += fa.sum2, son.sum1 += fa.sum2。如果是偶数同理推导

3.再看1中的kd(-1)^d部分如何维护。维护k1和k2两个值,分别表示应该减去的k的总和 和 应该加上的k的总和。还是考虑如何下传标记。当我们将父节点fa上的标记下传到子节点son的时候同样可以根据父节点和子节点的最小深度差来下传标记。如果son.mn - fa.mn = d是奇数,那么son这个结点的sum1应该加上d * fa.k2, sum2应该加上d * fa.k1, son.k1 += fa.k2, son.k2 += fa.k1。偶数同理推导。

  在d为奇数的时候,将2和3中的两个式子合并就是son.sum1 = son.sum1 + fa.sum2 + fa.k2 * d, son.sum2 = son.sum2 + fa.sum1 + fa.k1 * d。

4. 操作2就是线段树的单点查询。

5. 操作3我们可以创建一个set<array<int, 3>>将每个操作1按照{dfs序, x, k}的顺序丢进set。当我们遇到操作3中的v的时候只需要调用set的lowerbound来查找v的dfs序第一个出现的位置,并将这个提取进行修改操作,修改完之后从set删除即可。

#include <bits/stdc++.h>

const int N = 2e5 + 10;
const int MOD = 1e9 + 7;
using ll = long long;
typedef std::array<int, 3> PII;

int n, m;
#define ls u << 1
#define rs u << 1 | 1

int w[N], h[N], e[N], ne[N], idx;
int id[N], cnt;
int dep[N], sz[N], dfn[N];

struct Node {
	int l, r;                                                                                                               
	int k1, k2;//k1奇数k的和,k2偶数k的和
	int mn;//dfs序最小的点的深度
	int sum1, sum2;//sum1奇数和, sum2偶数和
}tr[N << 2];

inline void add(int &x) {
	if (x >= MOD) x -= MOD;
	x += MOD;
	if (x >= MOD) x -= MOD;
}

inline void pushup(int u) {
	tr[u].mn = std::min(tr[ls].mn, tr[rs].mn);
}

inline void pushdown(int u) {
	int d1 = tr[ls].mn - tr[u].mn;
	if (d1 & 1) {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].k1 += tr[u].k2;
		tr[ls].k2 += tr[u].k1;
		
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	} else {
		tr[ls].sum1 = (1ll * tr[ls].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d1) % MOD;
		tr[ls].sum2 = (1ll * tr[ls].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d1) % MOD;
		tr[ls].k1 += tr[u].k1;
		tr[ls].k2 += tr[u].k2;
		add(tr[ls].sum2), add(tr[ls].sum1), add(tr[ls].k1), add(tr[ls].k2);
	}
	
	int d2 = tr[rs].mn - tr[u].mn;
	
	if (d2 & 1) {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].k1 += tr[u].k2;
		tr[rs].k2 += tr[u].k1;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	} else {
		tr[rs].sum1 = (1ll * tr[rs].sum1 + tr[u].sum1 + 1ll * tr[u].k1 * d2) % MOD;
		tr[rs].sum2 = (1ll * tr[rs].sum2 + tr[u].sum2 + 1ll * tr[u].k2 * d2) % MOD;
		tr[rs].k1 += tr[u].k1;
		tr[rs].k2 += tr[u].k2;
		add(tr[rs].sum1), add(tr[rs].sum2), add(tr[rs].k1), add(tr[rs].k2);
	}
	
	tr[u].sum1 = tr[u].sum2 = tr[u].k1 = tr[u].k2 = 0;
}

inline void build(int u, int l, int r){
	tr[u] = {l, r};
	if(l == r)	{
		tr[u].mn = dep[dfn[l]];
		return ;
	}
	int mid = l + r >> 1;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	
	pushup(u);
}

inline void init(){
	memset(h, -1, sizeof h);
}

inline void add(int a, int b){
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int rr[N];
inline void dfs(int u, int father, int depth){
    dep[u] = depth, id[u] = ++ cnt, sz[u] = 1;
    dfn[cnt] = u;
    for (int i = h[u]; ~i; i = ne[i]){
        int j = e[i];
        if (j == father) continue;
        dfs(j, u, depth + 1);
        sz[u] += sz[j];
    }
    rr[u] = cnt;
}

inline void modify(int u, int L, int R, int x, int k, int depth) {
	if (tr[u].l >= L && tr[u].r <= R) {
		int d = tr[u].mn - depth;
		if (d & 1) {
			tr[u].sum1 = (ll(tr[u].sum1) + x + 1ll * d * k) % MOD;
			tr[u].k1 += k;
			add(tr[u].sum1);
			add(tr[u].k1);
		} else {
			tr[u].sum2 = (ll(tr[u].sum2) + x + 1ll * d * k) % MOD;
			tr[u].k2 += k;
			add(tr[u].sum2);
			add(tr[u].k2);
		}
 		return ;
	}
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	if (L <= mid) modify(ls, L, R, x, k, depth);
	if (R > mid) modify(rs, L, R, x, k, depth);
}

inline int query(int u, int x) {
	if (tr[u].l == tr[u].r) return (tr[u].sum2 - tr[u].sum1 + MOD) % MOD;
	
	pushdown(u);
	
	int mid = tr[u].l + tr[u].r >> 1;
	
	if (x <= mid) return query(ls, x);
	return query(rs, x);
}

inline void solve() {
	memset(h, -1, sizeof h);
	std::cin >> n >> m;
	
	for (int i = 2; i <= n; i ++) {
		int x;
		std::cin >> x;
		add(x, i);
	}
	
	dfs(1, -1, 1);
	
	std::multiset<PII> st;

	build(1, 1, n);
	constexpr int INF = 0x3f3f3f3f;
	auto get = [&](int sb) {
		auto it = st.lower_bound({sb, -INF, -INF});
		if (it == st.end()) {
			PII sn = {10000000, 0, 0};
			return sn;
		}
		return *it;
	};
	
	while (m --) {
		int op;
		std::cin >> op;
		if (op == 1) {
			int x, v, k;
			std::cin >> v >> x >> k;
			modify(1, id[v], id[v] + sz[v] - 1, x, k, dep[v]);
			st.insert({id[v], x, k});
		} else if (op == 2){
			int v;
			std::cin >> v;
			std::cout << query(1, id[v]) << '\n';
		} else {
			int z;
			std::cin >> z;
			for (int t = get(id[z])[0]; t <= rr[z]; t = get(t)[0]) {
				auto [dfnn, x, k] = get(t);
				modify(1, dfnn, dfnn + sz[dfn[dfnn]] - 1, -x, -k, dep[dfn[dfnn]]);
				st.erase(st.find({dfnn, x, k}));
			}
		}
	}
}

signed main(void) {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);
	
	int _ = 1;
	
	//std::cin >> _;
	while (_ --) solve();
	
	return 0;
}