P8528 [Ynoi2003] 铃原露露

发布时间 2023-11-25 11:29:49作者: mfeitveer

一道很好的启发式合并题目。

思路

考虑一个事实。

我们想要求出对于每个点对不合法的情况。

例如现在考虑到了 \((x,y)\),它们的 \(\text{lca}\)\(z\)

有几种情况:

  • \(a_x< a_z< a_y\),那么是合法的。
  • \(a_x< a_y< a_z\),那么包含 \(a_x,a_y\) 不包含 \(a_z\) 的区间是不合法的,即 \(((1,a_x),(a_y,a_z])\)
  • \(a_z< a_x< a_y\),那么包含 \(a_x,a_y\) 不包含 \(a_z\) 的区间是不合法的,即 \(([a_z,a_x),(a_y,n))\)

我们考虑将询问离线,做一次扫描线。

考虑两种不合法的区间对于每个点来说都是连续的。

也就可以差分后做历史和。

同时考虑,对于一个相同的 \(\text{lca}\),那么我们的 \((a_x,a_y)\) 相距越近越好。

那么只要对于每一个点在同一个 \(lca\) 下找到前驱后继即可。

这一部分可以启发式合并来做,直接 \(set\) 维护做最普通的启发式合并即可,甚至不需要树上启发式合并记重儿子之类的。

而对于后面历史和部分。

我们使用线段树维护 \(0\) 的个数,历史 \(0\) 的个数就是答案。

Code

代码很好写。

/**
 * @file P8528.cpp
 * @author win114514
 * @date 2023-11-25
 * 
 * @copyright Copyright (c) 2023
 * 
 */
#include <bits/stdc++.h>
using namespace std;

#define x first
#define y second
#define mp(x, y) make_pair(x, y)
#define eb(...) emplace_back(__VA_ARGS__)
#define fro(i, x, y) for(int i = (x);i <= (y);i++)
#define pre(i, x, y) for(int i = (x);i >= (y);i--)
#define dbg cerr << "Line " << __LINE__ << ": "
#define EVAL(x) #x " = " << (x)

typedef int64_t i64;
typedef uint32_t u32;
typedef uint64_t u64;
typedef __int128_t i128;
typedef __uint128_t u128;
typedef pair<int, int> PII;

bool ed;

const int N = 200010;
const int mod = 998244353;

int n, m, a[N], sz[N], sn[N], fa[N], tn[N], dep[N];
vector<int> to[N]; set<int> son[N];
vector<tuple<int, int, int>> fin, w[N];
i64 ans[N]; vector<PII> que[N];
i64 t1[N<<1], t2[N<<1], t3[N<<1], tg1[N<<1], tg2[N<<1];

inline void push1(int p, int v)
	{ tg1[p] += v, t3[p] += v; }
inline void push2(int p, int v)
	{ tg2[p] += v, t1[p] += t2[p] * v; }
inline void pdo(int p, int mid)
{
	if(tg1[p]) push1(mid<<1, tg1[p]), push1(mid<<1|1, tg1[p]);
	if(tg2[p])
	{
		if(t3[mid<<1] <= t3[mid<<1|1]) push2(mid<<1, tg2[p]);
		if(t3[mid<<1] >= t3[mid<<1|1]) push2(mid<<1|1, tg2[p]);
	}
	tg1[p] = tg2[p] = 0;
}
inline void pup(int p, int mid)
{
	t1[p] = t1[mid<<1] + t1[mid<<1|1], t2[p] = 0;
	t2[p] += (t3[mid<<1] <= t3[mid<<1|1]) * t2[mid<<1];
	t2[p] += (t3[mid<<1] >= t3[mid<<1|1]) * t2[mid<<1|1];
	t3[p] = min(t3[mid<<1], t3[mid<<1|1]);
}
inline void build(int p, int l, int r)
{
	if(l == r) return t2[p] = 1, void();
	int mid = (l + r) >> 1;
	build(mid<<1, l, mid);
	build(mid<<1|1, mid + 1, r);
	pup(p, mid);
}
inline void upd1(int p, int l, int r, int ls, int rs, int k)
{
	if(ls <= l && r <= rs) return push1(p, k);
	int mid = (l + r) / 2; pdo(p, mid);
	if(mid >= ls) upd1(mid<<1, l, mid, ls, rs, k);
	if(mid <  rs) upd1(mid<<1|1, mid + 1, r, ls, rs, k);
	pup(p, mid);
}
inline void upd2(int p, int l, int r, int ls, int rs, int k)
{
	if(t3[p] > 0) return;
	if(ls <= l && r <= rs) return push2(p, k);
	int mid = (l + r) / 2; pdo(p, mid);
	if(mid >= ls) upd2(mid<<1, l, mid, ls, rs, k);
	if(mid <  rs) upd2(mid<<1|1, mid + 1, r, ls, rs, k);
	pup(p, mid);
}
inline i64 ask(int p, int l, int r, int ls, int rs)
{
	if(ls <= l && r <= rs) return t1[p];
	int mid = (l + r) / 2; i64 sum{}; pdo(p, mid);
	if(mid >= ls) sum += ask(mid<<1, l, mid, ls, rs);
	if(mid <  rs) sum += ask(mid<<1|1, mid + 1, r, ls, rs);
	return sum;
}
inline void merge(int x, int y)
{
	if(son[x].size() < son[y].size())
		swap(son[x], son[y]);
	for(auto i : son[y])
	{
		auto it = son[x].lower_bound(i);
		if(it != son[x].end())
			fin.eb(*it, i, a[x]);
		if(it != son[x].begin())
			fin.eb(*prev(it), i, a[x]);
	}
	for(auto i : son[y]) son[x].insert(i); son[y].clear();
}
inline void dfs(int x)
{
	son[x].insert(a[x]);
	for(auto i : to[x])
	{
		if(i == fa[x]) continue;
		dfs(i), merge(x, i);
	}
}
inline void solve()
{
	cin >> n >> m;
	fro(i, 1, n) cin >> a[i];
	fro(i, 2, n) cin >> fa[i];
	fro(i, 2, n) to[fa[i]].eb(i);
	dfs(1), build(1, 1, n);
	for(auto [x, y, z] : fin)
	{
		if(x > y) swap(x, y);
		if(x <= z && z <= y) continue;
		if(y < z)
			w[y].eb(1, x, 1),
			w[z].eb(1, x, -1);
		if(z < x)
			w[y].eb(z + 1, x, 1);
	}
	fro(i, 1, m)
	{
		int l, r;
		cin >> l >> r;
		que[r].eb(l, i);
	}
	fro(i, 1, n)
	{
		for(auto [x, y, z] : w[i])
			upd1(1, 1, n, x, y, z);
		upd2(1, 1, n, 1, i, 1);
		for(auto [l, id] : que[i])
			ans[id] = ask(1, 1, n, l, i);
	}
	fro(i, 1, m) cout << ans[i] << "\n";
}

bool st;

signed main()
{
	ios::sync_with_stdio(0), cin.tie(0);
	double Mib = fabs((&ed-&st)/1048576.), Lim = 1024;
	cerr << " Memory: " << Mib << "\n", assert(Mib<=Lim);
	solve();
	return 0;
}