CF840E In a Trap

发布时间 2023-09-15 11:00:26作者: ImALAS

想了一会并不是很会,主要是信息利用率实在太不牛。

考虑树分块,我们取块长 \(B=2^8\),这个块长很有深意。注意这里的树分块只是形式分块,并不是树上关键点之类。

定义 \(f_{x,i}\) 表示 \(x\) 是第 \(i\) 个块的开头,询问深度为 \(dep_x+Bi\) 的块内最优解。那么 \(f_{x,i} = \max\limits_{j=0}^{B-1}\{ a_{anc_{x,j}}\mathrm{xor} (Bi+j)\}\)。考虑怎么维护这个东西。

事实上,取块长为 \(2^8\) 是因为这样可以给深度这个并不那么“二进制”的东西提供一个 trie 树上便利的维护方法。这也和这题询问为 dep 的特殊性质有关。

因为询问深度可以被拆解位 \(Bi+j\),前 8 位和后 8 位之间没有影响,我们可以把前 8 位扔到 trie 树上,查询和 \(i\) xor 起来的最大值,后面 8 位在 trie 树上打一个 tag 记录后八位的 max 即可。查询就直接往上跳。\(\mathcal O(n\sqrt n\log n+q\sqrt n)\)

#include <bits/stdc++.h>
#define pb emplace_back
#define fir first
#define sec second

using i64 = long long;
using pii = std::pair<int, int>;

const int sig = 10, B = 256, maxn = 5e4 + 5;
namespace trie {
	int trie[B * sig + 5][2], sz, mxv[B * sig + 5];
	void init() {
		for(int i = 0;i <= sz;++ i)
			trie[i][0] = trie[i][1] = mxv[i] = 0;
		return sz = 0, void();
	}
	void insert(int x, int idx) {
		int u = 0;
		for(int i = 7;~ i;-- i) {
			int c = x >> i & 1;
			if(!trie[u][c]) trie[u][c] = ++ sz;
			u = trie[u][c];
		}
		mxv[u] = std::max(mxv[u], idx);
	}
	int query(int x) {
		int ans = 0, u = 0;
		for(int i = 7;~ i;-- i) {
			int c = x >> i & 1;
			if(trie[u][!c]) {
				u = trie[u][!c];
				ans += 1 << i;
			} else {
				u = trie[u][c];
			}
		}
		return (ans << 8) + mxv[u];
	}
}
std::vector<int> G[maxn], ver;
int a[maxn], f[maxn][B], n, m, fa[maxn], siz[maxn], dep[maxn], bel[maxn];

void dfs(int u, int ff) {
	dep[u] = dep[ff] + 1; fa[u] = ff;
	for(auto& v : G[u]) {
		if(v == ff) continue ;
		dfs(v, u);
	}
	if(dep[u] >= B) ver.pb(u);
	return ;
}

int main() {
	scanf("%d %d", &n, &m);
	for(int i = 1;i <= n;++ i)
		scanf("%d", &a[i]);
	for(int i = 1;i < n;++ i) {
		int u, v; scanf("%d %d", &u, &v);
		G[u].pb(v); G[v].pb(u);
	}
	dfs(1, 0);
	for(auto& u : ver) {
		int x = u; trie::init();
		trie::insert(a[u] >> 8, a[u] % B);
		for(int i = 1;i < B;++ i)
			u = fa[u], trie::insert(a[u] >> 8, (a[u] % B) ^ i);
		bel[x] = u;
		for(int i = 0;i <= n / B;++ i)
			f[x][i] = trie::query(i);
	}
	while(m --) {
		int u, v; scanf("%d %d", &u, &v);
		int ans = 0, blk = 0, res = 0;
		while(dep[bel[v]] >= dep[u])
			ans = std::max(ans, f[v][blk ++]), v = fa[bel[v]];
		while(dep[v] >= dep[u])
			ans = std::max(ans, a[v] ^ (blk * B + res)), ++ res, v = fa[v];
		printf("%d\n", ans);
	}
	return 0;
}