子集反演

发布时间 2023-06-04 21:31:41作者: Diavolo-Kuang

什么是子集反演?

当然在此之前说明子集反演是什么 : 子集反演用于在 恰好是某个子集在这个子集中钦定/钦定这个子集 之间转换。并且子集反演有两种形式。

第一种:现在有一个集合 \(A\) ,我们定义 \(f(A)\) 表示集合 \(A\) 的答案, \(g(A)\) 表示 \(A\) 的所有子集的答案。如果有 $$g(A)=\sum_{B\subseteq A}f(B)$$

那么就可以有

\[f(A)=\sum_{B \subseteq A} (-1)^{|A|-|B|}g(B) \]

第二种(和第一种反过来):现在有一个集合 \(A\) ,我们定义 \(f(A)\) 表示集合 \(A\) 的答案,\(g(A)\) 表示 \(A\) 的所有子集的答案。如果有

\[g(A)=\sum_{A\subseteq B}f(B) \]

那么就可以有

\[f(A)=\sum_{A \subseteq B} (-1)^{|B|-|A|}g(B) \]

OK,现在我们知道什么是子集反演了,那么我们看几道例题。

[ZJOI2016]小星星

题意描述

现在给你一个 \(n\) 个节点,\(m\) 条边的图和一个 \(n\) 个节点的树。你需要给树上的每一个节点标号,满足:

  • 给节树上点标的号是一个 \(1\)\(n\) 的排列
  • 标号之后的树,树上有的边,原图中也得有。

求标号方案数。

\(n \leqslant 17\)

思路点拨

可以发现 \(n\) 非常的小,并且是一道计数题,我们自然而然想到了状态压缩dp。我们定义状态 \(f_{i,j,S}\) 表示在 \(i\) 的子树中 \(i\) 映射的点是 \(j\) ,子树中使用了 \(S\) 这些点的映射的方案数(每一个图上的点只被使用一次)。\(S\) 是一个二进制数表示方案。转移十分简单,这里略去,我们只需要知道这样会TLE到爆就可以了。

接下来就要讲到这道题目最神的地方了,我们抛弃标号是排列的限制。你就惊人的发现,这样最终答案可以子集反演,并且我们需要用到的状态 \(f_{i,j,S}\) 表示在 \(i\) 的子树中 \(i\) 映射的点是 \(j\) ,子树中使用了 \(S\) 这些点的映射的方案数(每一个图上的点可能使用多次)是十分好求的。

考虑转移(为了方便我们抛弃 \(S\) 这一无用维):$$f_{i,j}=\prod_{k \in son} (\sum_{l=1}^{n} f_{k,l} [l \in S , 原来的图中j,l有连边])$$

这里给出一份代码(如果不理解可以看代码,很短):

#include<bits/stdc++.h>
#define int long long
using namespace std;
inline int read(){
	int x=0,f=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') f=-f;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+ch-'0';
		ch=getchar();
	}
	return x*f;
}
const int MAXN=18;
int n,m;
bool vis[MAXN][MAXN];
vector<int> e[MAXN];
int f[MAXN][MAXN];
int s[MAXN],top;
void dfs(int x,int dad){
	for(int i=1;i<=top;i++)
		f[x][s[i]]=1;
	for(int i=0;i<e[x].size();i++){
		int to=e[x][i];
		if(to==dad) continue;
		dfs(to,x);
		for(int i=1;i<=top;i++){
			int cnt=0;
			for(int j=1;j<=top;j++)
				if(vis[s[i]][s[j]])
					cnt+=f[to][s[j]];
			f[x][s[i]]*=cnt;
		}
	}
}
signed main(){
	n=read(),m=read();
	for(int i=1;i<=m;i++){
		int u=read(),v=read();
		vis[u][v]=vis[v][u]=1;
	}
	for(int i=1;i<n;i++){
		int u=read(),v=read();
		e[u].push_back(v);
		e[v].push_back(u);
	}
	int ans=0;
	for(int i=1;i<(1<<n);i++){
		memset(f,0,sizeof(f));
		top=0;
		for(int j=0;j<n;j++)
			if(i&(1<<j))
				s[++top]=j+1;
		dfs(1,-1);
		int cnt=0;
		for(int i=1;i<=top;i++)
			cnt+=f[1][s[i]];
		if((n-top)&1) ans-=cnt;
		else ans+=cnt;
	}
	cout<<ans;
	return 0;
}