ABC331G 题解

发布时间 2023-12-04 19:16:34作者: 2008verser

盒子里有 \(n\)\(m\) 种卡片,第 \(i\) 种卡片有 \(c_i\) 张。\(\sum c_i=n\)

每次均匀随机选一张,再放回去。求拿出过的卡片包含全部种类所需要的取出次数的期望。

\(998244353\) 取模。

\(1\leq n,m\leq 2e5,c_i\gt0\)

首先观察到,对于任意终止局面,最后取出的卡片一定的种类,一定是第一次被取出。

那么我们只需要知道每种卡片第一次被取出的时间 \(S_i\),剩下每个空位置都可以选前面任意一种卡片。

这个思路是行不通的。

发现答案等于“每个卡片第一次被取出的时间”之中的最大值。即 \(\max S_i\)

因为 \(S\) 是不确定的。即,求 \(S\) 里面最大值的期望。

联系 Min-Max 容斥:

\(U=\{1,2,...,m\}\)\(U\) 是全部种类的集合。

\[E(\max_{i\in U} S_i)=\sum_{T\subseteq U}(-1)^{|T|-1}E(\min_{i\in T} S_i) \]

此时我们要算的就是,任选一个种类集合 \(T\),(接着我们开始对全部卡牌进行游戏,每种卡牌都会有一个 \(S\) 值),求出这些种类里面最早被取出的那个种类的取出时间(即最小的 \(S\)的期望,求和。

当然还有系数 \((-1)^{|T|-1}\)。我们先考虑上面那个怎么求。

我们考虑计算最早取出时间为 \(i\) 的概率,与 \(i\) 相乘,然后对于 \(i=1\to \inf\) 求和。

这个东西求出来是

\[\dfrac{n}{\sum_{i\in T} c_i} \]

我们的式子变成了

\[E(\max_{i\in U} S_i)=\sum_{T\subseteq U}(-1)^{|T|-1}\times\dfrac{n}{\sum_{i\in T} c_i} \]

使用一个常见的拆贡献方法,设 \(\sum_{i\in T} c_i=x\),则原式变为

\[\sum_{x=1}^n F(x)\times\frac{n}{x} \]

\(F(x)\) 是一个负或正的整数。因为原式中全部 \(\frac{n}{x}\) 相等的项可以合并同类项。\(F(x)\) 就是合并完以后他们的系数。

如果把 \(F(x)\) 求出来就可以 \(O(n\log n)\) 计算答案了。之所以带 \(\log\) 是因为我不会线性求逆元。

现在考虑求 \(F(x)\)。因为 \(F(x)\) 表示的是若干个 \(c\) 相加为 \(x\) 的一些系数,考虑卷积。

一开始我们有 \(n\) 个数列,第 \(k\) 个数列为:

\[[0,0,...,0,1,0,0,...0,0] \]

\(c_k\) 项是 1,其余是 0。

数列第 \(x\) 项代表在 \(\{k\}\) 这一个集合的全部子集 \(T\subseteq\{k\}\) 中,\(\frac{n}{x}\) 那一项的系数,也就是 \(F(x)\),只不过我们此时的 \(U=\{k\}\)

那么显然,此时 \(T\) 只有 \(\{k\}\) 一种取法,而 \((-1)^{|T|-1}=1,\sum_{i\in T} c_i=c_k\),所以第 \(c_k\) 项是 1,其余是 0。

这一部分的关键在于,每个多项式都有隐含的一个意义:\(U\) 的元素。初始第 \(k\) 个多项式 \(U=\{k\}\)

当我们把两个数列(多项式)卷起来。设这两个数列的 \(U\) 分别等于 \(A,B\),成绩结果为 \(C\)

乘积多项式的意义应当是当 \(U=A\cup B\) 时的 \(F(x)\)

但我们发现这两个卷起来以后得到的结果是把 \(T\subseteq A\)\(T\subseteq B\) 排除在外的!

所以我们要让 \(C=A\times B+A+B\)

但我们发现,对于 \(-1\) 的处理,在 \(A\) 中一个因子是 \((-1)^{|T1|-1}\),在 \(B\) 中一个因子是 \((-1)^{|T2|-1}\),两个合起来应当为 \((-1)^{|T1|+|T2|-1}\)

但是我们做的运算是 \((-1)^{|T1|-1}\times (-1)^{|T2|-1}=(-1)^{|T1|+|T2|-1}\)

所以 \(C\) 应当等于 \(-(A\times B)+A+B\)

总结一下:

我们要求

\[\sum_{x=1}^n F(x)\times\frac{n}{x} \]

所以要求 \(F\)。每个多项式第 \(x\) 项代表对于它的 \(U\) 的全部子集 \(T\subseteq U\),对 \((-1)^{|T|-1}\times\dfrac{n}{\sum_{i\in T} c_i(=x)}\) 求和,\(\frac{n}{x}\) 那一项的系数。

此时我们把这 \(n\) 个多项式启发式卷起来。卷不是正常的卷,\(F\)\(G\) 等于 \(-(F\times G)+F+G\)

最终的多项式就是要求的 \(F(x)\)

时间 \(O(n\log^2n)\)

#include<bits/stdc++.h>
#define pb push_back
#define rg register
#define ld long double
#define ull unsigned int
#define epb emplace_back
#define getc getchar
#define putc putchar
using namespace std;
inline int re() {
	rg int x=0,p=0;rg char c=getchar();
	while(c<'0'||c>'9') (!p)?(p=c=='-'):(p=p),c=getchar();
	while('0'<=c&&c<='9') (x*=10)+=c-48,c=getchar();
	if(p) x=-x;
	return x;
}
inline void wt(rg int x) { if(x>9) wt(x/10);putc(x%10+48); }

const int N=2e5+5;
const int mod=998244353;
int n,m,c[N];
int ans;
struct poly { vector<int>f; };
bool operator<(poly a,poly b) { return a.f.size()>b.f.size(); }
priority_queue<poly>q;
const int G=3;
int qp(int a,int b=mod-2) {
	int s=1;
	while(b) {
		if(b&1) s=1ll*s*a%mod;
		b>>=1;
		a=1ll*a*a%mod;
	}
	return s;
}
const int iG=qp(G);
int gt[N<<2];
void NTT(vector<int>&a,int n,int ty) {
	for(int i=0;i<n;i++) gt[i]=(gt[i>>1]>>1)|((i&1)?(n>>1):0);
	for(int i=0;i<n;i++) if(i<gt[i]) swap(a[i],a[gt[i]]);
	for(int len=2;len<=n;len<<=1) {
		int on1=qp(ty?G:iG,(mod-1)/len);
		for(int l=0;l<n;l+=len) {
			int r=l+len-1,mid=l+r>>1;
			int yg=1;
			for(int i=l;i<=mid;i++) {
				int t=1ll*yg*a[i+len/2]%mod;
				a[i+len/2]=(a[i]-t+mod)%mod;
				a[i]=(a[i]+t)%mod;
				yg=1ll*yg*on1%mod;
			}
		}
	}
}
int main() {
	n=re(),m=re();
	for(int i=1;i<=m;i++) {
		c[i]=re();
		poly p;
		p.f.resize(c[i]+1);
		p.f[c[i]]=1;
		q.push(p);
	}
	while(q.size()>1) {
		poly f=q.top();q.pop();
		poly g=q.top();q.pop();
		poly f1=f,g1=g;
		
		int len=f.f.size()+g.f.size()-1;
		len=1<<int(log2(len-1)+1);
		int invn=qp(len);
		f.f.resize(len);g.f.resize(len);
		NTT(f.f,len,1);NTT(g.f,len,1);
		for(int i=0;i<len;i++) f.f[i]=1ll*f.f[i]*g.f[i]%mod;
		NTT(f.f,len,0);
		for(int i=0;i<len;i++) f.f[i]=1ll*f.f[i]*invn%mod;
		
		for(int i=0;i<f.f.size();i++) f.f[i]=mod-f.f[i];
		for(int i=0;i<f1.f.size();i++) f.f[i]=(f.f[i]+f1.f[i])%mod;
		for(int i=0;i<g1.f.size();i++) f.f[i]=(f.f[i]+g1.f[i])%mod;
		
		q.push(f);
	}
	vector<int>f=q.top().f;
	for(int i=1;i<=n;i++) {
		(ans+=1ll*n*qp(i)%mod*f[i]%mod)%=mod;
	}
	printf("%d",ans);
}