P3352 [ZJOI2016] 线段树 思考--zhengjun

发布时间 2023-07-23 13:59:50作者: A_zjzj

有一个显然的 \(O(n^3q)\) 的做法:

  • \(f_{i,l,r,x}\) 表示 \(i\) 次操作过后,区间 \([l,r]\) 的数 \(\le x\)\(a_{l-1},a_{r+1}>x\) 的方案数。

  • 转移:$$f_{i,l,r,x}=f_{i-1,l,r,x}\times g_{l,r}+\sum\limits_{j<l}f_{i-1,j,r,x}\times(j-1)+\sum\limits_{j>r}f_{i-1,l,j,x}\times(n-j)$$

  • 其中 \(g_{l,r}=\frac{(r-l+1)(r-l+2)}{2}+\frac{l(l-1)}{2}+\frac{(n-r)(n-r+1)}{2}\),为无效操作的方案数。

  • 最后算答案只需差分一下即可:$$ans_i=\sum\limits_{l\le i\le r}b_x\times \sum \limits_{x}f_{q,l,r,x}-f_{q,l,r,x-1}$$

  • 其中,\(b_x\) 为离散化过后的值域数组。

当然,有一个数据随机的限制,复杂度实际上为期望 \(O(n^2q)\),可过。

但是,其实可以做到任意数据下 \(O(n^2q)\),需要一点技巧。

开始推狮子:

\[ans_i=\sum\limits_{l\le i\le r}b_x\times \sum \limits_{x}f_{q,l,r,x}-f_{q,l,r,x-1}\\ =\sum\limits_{l\le i\le r}\sum\limits_{x}f_{q,l,r,x}\times (b_x-b_{x+1}) \]

发现转移其实和 \(x\) 没有关系,所以考虑优化这一维。

\(f'_{i,l,r}=\sum\limits_{x}f_{i,l,r,x}\times(b_x-b_{x+1})\)

那么 \(ans_i=\sum\limits_{l\le i\le r}f'_{q,l,r}\)

转移变成了:

\[f'_{i,l,r}=f'_{i-1,l,r}\times g_{l,r}+\sum\limits_{j<l}f'_{i-1,j,r}\times(j-1)+\sum\limits_{j>r}f'_{i-1,l,j}\times(n-j) \]

复杂度即可降为 \(O(n^2q)\)

代码

#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const int N=4e2+10,mod=1e9+7;
int n,m,q,a[N],b[N];
int mx[N][N],f[2][N][N],g[N][N],h[N][N];
int calc(int x){
	return x*(x+1)/2;
}
int main(){
	freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
	cin>>n>>q;
	for(int i=1;i<=n;i++)cin>>a[i];
	copy(a,a+1+n,b),sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
	for(int i=1;i<=n;i++)a[i]=lower_bound(b+1,b+1+m,a[i])-b;
	for(int i=1;i<=n;i++){
		mx[i][i]=a[i];
		for(int j=i+1;j<=n;j++)mx[i][j]=max(mx[i][j-1],a[j]);
	}
	a[0]=a[n+1]=m+1;
	for(int i=1;i<=n;i++){
		for(int j=i;j<=n;j++){
			if(mx[i][j]<min(a[i-1],a[j+1]))
				f[0][i][j]=(b[mx[i][j]]-b[min(a[i-1],a[j+1])]+mod)%mod;
		}
	}
	for(int i=1,now=1,las=0;i<=q;i++,swap(now,las)){
		for(int l=1;l<=n;l++)for(int r=n;r>=l;r--){
			g[l][r]=(g[l-1][r]+f[las][l][r]*(l-1ll))%mod;
			h[l][r]=(h[l][r+1]+1ll*f[las][l][r]*(n-r))%mod;
		}
		for(int l=1;l<=n;l++)for(int r=l;r<=n;r++){
			f[now][l][r]=(1ll*f[las][l][r]*(calc(r-l+1)+calc(l-1)+calc(n-r))+g[l-1][r]+h[l][r+1])%mod;
		}
	}
	for(int i=1;i<=n;i++){
		int ans=0;
		for(int l=1;l<=i;l++){
			for(int r=i;r<=n;r++){
				(ans+=f[q&1][l][r])%=mod;
			}
		}
		cout<<ans<<' ';
	}
	return 0;
}