P5901 [IOI2009] Regions

发布时间 2023-10-26 17:21:51作者: WrongAnswer_90

P5901 [IOI2009] Regions

更好的阅读体验

根号分治,过掉不难,但是想 \(\mathcal O(n\sqrt n)\) 还是有一些思维含量的。

首先考虑一种暴力:预处理两两颜色间的答案,\(\mathcal O(1)\) 查询。首先枚举颜色数,然后每种颜色 \(\mathcal O(n)\) 扫一遍,这样预处理的复杂度是 \(\mathcal O(n\times \text{处理的颜色数})\)。直接硬预处理显然过不去,所以可以考虑对出现次数较多的颜色预处理。设块长为 \(B\)

对于大块与大块,大块与小块,小块与大块之间的答案还是想上面一样预处理,预处理复杂度是 \(\mathcal O(n\times \dfrac{n}{B})\)

对于小块,内部点数是 \(\mathcal O(B)\) 级别的,树上的祖先后代问题用 dfn 序转化为区间问题:

\(\sum_{i\in col_x}\sum_{j\in col_y}[in_i\leq dfn_j\leq out_i]\)

直接做有 \(in,out\) 两个限制,但是一个 \(dfn\) 不可能即 \(>out_i\)\(<in_i\),所以可以容斥一下,用总的点对数减去不合法的点对数,问题变为

\(\sum_{i\in col_x}\sum_{j\in col_y}[dfn_j<in_i]\)

可以树状数组做到 \(B\log n\),但是更优秀的做法是预处理排序好的 \(in\)\(dfn\),用类似归并排序求逆序对的思路双指针解决,对于 \(>out_i\) 的部分也是同理。

总复杂度 \(\mathcal O(n\times \dfrac{n}{B}+nB)\)\(B\)\(\sqrt n\) 时复杂度为 \(\mathcal O(n\sqrt n)\)

	int n,m,q,cnt,B=360,tot,dfn[200001],nfd[200001],pos[200001],sum[200001],a[200001],head[200001],to[200001],nex[200001],siz[25001],ans1[160][25001],ans2[160][25001];
	inline void add(int x,int y){to[++cnt]=y,nex[cnt]=head[x],head[x]=cnt;}
	vector<int> larg,ve[25001],ve2[25001];
	void dfs(int k){dfn[k]=++tot;for(int i=head[k];i;i=nex[i])dfs(to[i]);nfd[k]=tot;}
	void dfs1(int k){for(int i=head[k];i;i=nex[i])dfs1(to[i]),sum[k]+=sum[to[i]];}
	void dfs2(int k){for(int i=head[k];i;i=nex[i])sum[to[i]]+=sum[k],dfs2(to[i]);}
	inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
	inline bool cmp2(int x,int y){return nfd[x]<nfd[y];}
	inline void mian()
	{
		cin>>n>>m>>q>>a[1],++siz[a[1]],ve[a[1]].eb(1);int x,y;
		for(int i=2;i<=n;++i)cin>>x>>a[i],++siz[a[i]],ve[a[i]].eb(i),add(x,i);
		for(int i=1;i<=m;++i)if(siz[i]>B)pos[i]=larg.size(),larg.eb(i);
		dfs(1);
		for(int i=1;i<=m;++i)ve[i].eb(0),ve2[i]=ve[i],sort(ve2[i].begin(),ve2[i].end(),cmp2),sort(ve[i].begin(),ve[i].end(),cmp);
		for(int i=0;i<larg.size();++i)
		{
			for(auto j:ve[larg[i]])++sum[j];
			dfs1(1);
			for(int j=1;j<=n;++j)ans1[i][a[j]]+=sum[j];
			memset(sum,0,sizeof(sum));
			for(auto j:ve[larg[i]])++sum[j];
			dfs2(1);
			for(int j=1;j<=n;++j)ans2[i][a[j]]+=sum[j];
			memset(sum,0,sizeof(sum));
		}
		while(q--)
		{
			cin>>x>>y;
			if(pos[x])cout<<ans2[pos[x]][y]<<endl;
			else if(pos[y])cout<<ans1[pos[y]][x]<<endl;
			else
			{
				int ans=siz[x]*siz[y];
				for(int i=1,j=1;i<=siz[x]&&j<=siz[y];++j)
				{
					while(i<=siz[x]&&dfn[ve[x][i]]<=dfn[ve[y][j]])++i;
					ans-=siz[x]-i+1;
				}
				for(int i=siz[x],j=siz[y];i>0&&j>0;--j)
				{
					while(i>0&&nfd[ve2[x][i]]>=dfn[ve[y][j]])--i;
					ans-=i;
				}
				cout<<ans<<endl;
			}
			fflush(stdout);
		}
	}