树形DP

发布时间 2023-07-25 15:47:53作者: andy_lz

P3565 [POI2014] HOT-Hotels

\(solution 1\):


先说一下我想到的 \(O(n^2)\) 算法。

首先不难发现,如果三个点两两距离相等,那么一定存在一个中心点,使得中心点到这三个点的距离相等。

枚举每一个点作为中心节点,然后从它开始将整个树 \(DFS\) 一遍。在 \(DFS\) 的时候开一个动态数组 \(p\)\(p[i][j]\) 表示以当前中心节点为根的第 \(i\) 颗子树中,深度为 \(j\) 的节点的数量。此时的答案为 \(\sum_{d=1}\sum_{i=1}^{cnt}\sum_{j=i+1}^{cnt}\sum_{k=j+1}^{cnt}p[i][d]\times p[j][d]\times p[k][d]\)

如何计算它?可以考虑增量法。设 \(s1=\sum_{d=1}\sum_{i=1}^{cnt}p[i][d]\)\(s2=\sum_{d=1}\sum_{i=1}^{cnt}\sum_{j=i+1}^{cnt}p[i][d]\times p[j][d]\)\(s3=\sum_{d=1}\sum_{i=1}^{cnt}\sum_{j=i+1}^{cnt}\sum_{k=j+1}^{cnt}p[i][d]\times p[j][d]\times p[k][d]\) 。在外层枚举 \(d\) ,在内层枚举 \(i\) ,每当 \(i\) 加一,就令 \(s3+=p[i][d-1]\times s2\) , \(s2+=p[i][d-1]\times s1\)\(s1+=p[i][d-1]\) 。每次内层循环结束时,令 \(ans+=s3\) ,最后 \(ans\) 即为所求。

\(code:\)

void dfs(int x,int fa,int d){
	if(p[cnt].size()<d)
		p[cnt].push_back(1);
	else
		++p[cnt][d-1];
	for(int i=head[x];i;i=nxt[i])
		if(ver[i]!=fa)
			dfs(ver[i],x,d+1);
}
void work(int x){
	cnt=0;
	for(int i=head[x];i;i=nxt[i]){
		++cnt;p[cnt].clear();
		dfs(ver[i],x,1);
	}
	bool ok=1;
	for(int i=1;;++i){
		int s=0,s2=0,s3=0;
		for(int j=1;j<=cnt;++j){
			if(p[j].size()<i)
				continue;
			s3+=p[j][i-1]*s2;
			s2+=p[j][i-1]*s;
			s+=p[j][i-1];
		}
		if(!s3) break;
		ans+=s3;
	}
}
void solve(int x){
	vis[x]=1;sum=1e9;
	work(x);
	for(int i=head[x];i;i=nxt[i])
		if(!vis[ver[i]])
			solve(ver[i]);
}

\(solution 2\)


其实这个题是可以优化到 \(O(n)\) 的。

\(f[i][j]\) 表示以 \(i\) 为根的子树中,距离当前点为 \(j\) 的点数;\(g[i][j]\)表示以 \(i\) 为根的子树中,两个点到 \(LCA\) 的距离为 \(d\) ,并且他们的 \(LCA\)\(i\) 的距离为 \(d−j\) 的点对数。

状态转移方程:
\(ans+=g[i][0],ans+=g[i][j]\times f[son][j-1],f[i][j]+=f[son][j-1],g[i][j]+=g[son][j+1]\)

如果我们钦定一个儿子,那么 \(f\)\(g\) 数组是可以直接赋值的。

我们进行长链剖分,每次钦定从重儿子直接转移,那么我们还需要从轻儿子进行转移。

那么,复杂度拆分成两个部分:直接从重儿子转移\(O(1)\),从轻儿子转移\(O(\sum len)\)。发现每个点有且仅有一个父亲,因此一条重链算且仅被一个点暴力转移,而每次转移复杂度是链长。所以全局复杂度是\(\sum\)链长,也就是\(O(n)\),因此总复杂度就是\(O(n)\)

P4516 潜入行动

2023.6.18拷逝T4

树形背包好题。

\(f[x][i][0/1][0/1]\) 表示以 \(x\) 为根的子树中共放了 \(i\) 个监听装置,其中 \(x\) 点放没放装置, \(x\) 点有没有被监听到的方案数(在以 \(x\) 为根的子树中除 \(x\) 外的其它结点都被监听到了)

状态转移方程:

①:\(x\)没有被监听,也没有放装置。此时子节点一定不能放装置。\(f[x][i+j][0][0]= \sum f[x][i][0][0]\times f[v][j][0][1]\)

②:\(x\)没有被监听,但放了装置。此时子节点同样一定不能放装置,但它有没有被以它为根的子树内的节点监听有无所谓了。

\(f[x][i+j][1][0]= \sum f[x][i][1][0]\times (f[v][j][0][0]+f[v][j][0][1])\)

③:\(x\)没放装置,但被监听了。

如果\(x\)在这之前已经被监听,那么\(v\)放不放装置无所谓;而如果\(x\)在这之前没有被监听,那么\(v\)处必须放装置。因为\(x\)没放装置,所以\(v\)必须被以它为根的子树内的节点监听。

\(f[x][i+j][0][1]= \sum f[x][i][0][1]\times (f[v][j][0][1]+f[v][j][1][1])+f[x][i][0][0]\times f[v][j][1][1]\)

④:\(x\)放了装置,也被监听了。

如果\(x\)在这之前已经被监听,那么\(v\)随意;而如果\(x\)在这之前没有被监听,那么\(v\)处必须放装置。因为\(x\)放了装置,所以\(v\)是否被以它为根的子树内的节点监听无所谓。

\(f[x][i+j][1][1]= \sum f[x][i][1][0]\times (f[v][j][1][0]+f[v][j][1][1])+f[x][i][1][1]\times (f[v][j][1][1]+f[v][j][1][0]+f[v][j][0][1]+f[v][j][0][0])\)

\(code:\)

#include<iostream>
#include<cstdio>
using namespace std;
const int mod=1e9+7,l=1e5+5;
int n,k,a,b,tot,head[l<<1],ver[l<<1],nxt[l<<1];
int dp[l][105][2][2],tmp[l][2][2],siz[l];
void add(int x,int y){
	nxt[++tot]=head[x];head[x]=tot;ver[tot]=y;
}
void dfs(int x,int fa){
	siz[x]=dp[x][0][0][0]=dp[x][1][1][0]=1;
	for(int t=head[x];t;t=nxt[t]){
		int v=ver[t];
		if(v!=fa){
			dfs(v,x);
			for(int i=0;i<=min(k,siz[x]);++i){
				tmp[i][0][0]=dp[x][i][0][0];dp[x][i][0][0]=0;
				tmp[i][0][1]=dp[x][i][0][1];dp[x][i][0][1]=0;
				tmp[i][1][0]=dp[x][i][1][0];dp[x][i][1][0]=0;
				tmp[i][1][1]=dp[x][i][1][1];dp[x][i][1][1]=0;
			}
			for(int i=0;i<=min(k,siz[x]);++i)
				for(int j=0;j<=min(k-i,siz[v]);++j){
					dp[x][i+j][0][0]=(dp[x][i+j][0][0]+1ll*tmp[i][0][0]*dp[v][j][0][1]%mod)%mod;
					dp[x][i+j][0][1]=(dp[x][i+j][0][1]+1ll*tmp[i][0][1]*(dp[v][j][0][1]+dp[v][j][1][1])%mod)%mod;
					dp[x][i+j][0][1]=(dp[x][i+j][0][1]+1ll*tmp[i][0][0]*dp[v][j][1][1]%mod)%mod;
					dp[x][i+j][1][0]=(dp[x][i+j][1][0]+1ll*tmp[i][1][0]*(dp[v][j][0][0]+dp[v][j][0][1])%mod)%mod;
					dp[x][i+j][1][1]=(dp[x][i+j][1][1]+1ll*tmp[i][1][0]*(dp[v][j][1][1]+dp[v][j][1][0])%mod)%mod;
					dp[x][i+j][1][1]=(dp[x][i+j][1][1]+1ll*tmp[i][1][1]*(1ll*(dp[v][j][0][0]+dp[v][j][0][1])+1ll*(dp[v][j][1][0]+dp[v][j][1][1]))%mod)%mod;
				}
			siz[x]+=siz[v];
		}
	}
	return ;
}
int main(){
	scanf("%d%d",&n,&k);
	for(int i=1;i<n;++i)
		scanf("%d%d",&a,&b),add(a,b),add(b,a);
	dfs(1,0);
	printf("%d\n",(dp[1][k][1][1]+dp[1][k][0][1])%mod);
	fclose(stdin);fclose(stdout);
	return 0;
}