Solution Set - 矩阵加速

发布时间 2023-06-02 20:03:41作者: by_chance

A[洛谷P4719]一棵树,点有权,单点修改,求最大权独立集。
B[洛谷P6021]一棵树,点有权,单点修改,求在某棵子树中选出一些点,使得所有叶子与根不连通的最小权值和。
C[洛谷P5024]一棵树,点有权,给定某两个点的选择状况,求最小权覆盖集。


动态DP:(通常在树上)用矩阵刻画DP转移。做树链剖分,然后对每个点记录轻儿子的转移矩阵之积。修改时从一个点向上跳,修改每条轻边的贡献;询问是把1所在重链的矩阵相乘。

ABC都是模板。C要注意撤销修改时的顺序。


点击查看A题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,INF=1e9;
int n,m,a[N],f[N][2];
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
int fa[N],top[N],L[N],dfn,R[N],siz[N],son[N],id[N];
struct Matrix{
	int a[2][2];
	Matrix (){memset(a,0,sizeof(a));}
	Matrix operator *(const Matrix &b)const{
		Matrix c;
		for(int i=0;i<2;i++)
			for(int j=0;j<2;j++)
				for(int k=0;k<2;k++)
					c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
		return c;
	}
}g[N];
struct SegmentTree{
	Matrix a[N<<2];
	#define ls p<<1
	#define rs p<<1|1
	#define mid (l+r>>1)
	void build(int p,int l,int r){
		if(l==r){a[p]=g[id[l]];return;}
		build(ls,l,mid);
		build(rs,mid+1,r);
		a[p]=a[ls]*a[rs];
	}
	Matrix query(int p,int l,int r,int L,int R){
		if(l>=L&&r<=R)return a[p];
		if(R<=mid)return query(ls,l,mid,L,R);
		if(L>mid)return query(rs,mid+1,r,L,R);
		return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R);
	}
	void modify(int p,int l,int r,int x){
		if(l==r){a[p]=g[id[l]];return;}
		if(x<=mid)modify(ls,l,mid,x);
		else modify(rs,mid+1,r,x);
		a[p]=a[ls]*a[rs];
	}
	#undef ls
	#undef rs
	#undef mid
}seg;
void dfs(int u){
	f[u][1]=a[u];siz[u]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]){
			fa[v]=u;
			dfs(v);siz[u]+=siz[v];
			if(siz[v]>siz[son[u]])son[u]=v;
			f[u][0]+=max(f[v][0],f[v][1]);
			f[u][1]+=f[v][0];
		}
}
void rdfs(int u,int tp){
	g[u].a[1][0]=a[u];g[u].a[1][1]=-INF;
	L[u]=R[u]=++dfn;R[tp]=dfn;id[dfn]=u;top[u]=tp;
	if(son[u])rdfs(son[u],tp);
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]&&v!=son[u]){
			rdfs(v,v);
			g[u].a[0][0]+=max(f[v][0],f[v][1]);
			g[u].a[1][0]+=f[v][0];
		}
	g[u].a[0][1]=g[u].a[0][0];
}
void update(int u,int val){
	g[u].a[1][0]+=val-a[u];a[u]=val;
	while(u){
		Matrix lst=seg.query(1,1,n,L[top[u]],R[top[u]]);
		seg.modify(1,1,n,L[u]);
		Matrix now=seg.query(1,1,n,L[top[u]],R[top[u]]);
		u=fa[top[u]];
		g[u].a[0][0]+=max(now.a[0][0],now.a[1][0])-max(lst.a[0][0],lst.a[1][0]);
		g[u].a[1][0]+=now.a[0][0]-lst.a[0][0];
		g[u].a[0][1]=g[u].a[0][0];
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)scanf("%d",a+i);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		add(u,v);add(v,u);
	}
	dfs(1);rdfs(1,1);seg.build(1,1,n);
	for(int i=1,u,val;i<=m;i++){
		scanf("%d%d",&u,&val);
		update(u,val);
		Matrix ans=seg.query(1,1,n,1,R[1]);
		printf("%d\n",max(ans.a[0][0],ans.a[1][0]));
	}
	return 0;
}
点击查看B题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e5+5,INF=1e9+7;
int n,m;ll a[N],f[N];char op;
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
struct Matrix{
	ll a[2][2];
	Matrix (){memset(a,0x3f,sizeof(a));}
	Matrix operator *(const Matrix&b)const{
		Matrix c;
		for(int i=0;i<2;i++)
			for(int j=0;j<2;j++)
				for(int k=0;k<2;k++)
					c.a[i][j]=min(c.a[i][j],a[i][k]+b.a[k][j]);
		return c;
	}
}g[N];
int fa[N],siz[N],top[N],L[N],dfn,R[N],son[N],id[N];
void dfs(int u){
	siz[u]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]){
			fa[v]=u;dfs(v);siz[u]+=siz[v];
			if(siz[v]>siz[son[u]])son[u]=v;
			f[u]+=min(f[v],a[v]);
		}
	if(!nxt[head[u]])f[u]=INF;
}
void rdfs(int u,int tp){
	g[u].a[0][1]=a[u];g[u].a[0][0]=g[u].a[1][1]=0;
	L[u]=++dfn;id[dfn]=u;R[tp]=dfn;top[u]=tp;
	if(son[u])rdfs(son[u],tp);
	else g[u].a[0][0]=INF;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]&&v!=son[u]){
			rdfs(v,v);
			g[u].a[0][0]+=min(f[v],a[v]);
		}
}
struct SegmentTree{
	Matrix a[N<<2];
	#define ls p<<1
	#define rs p<<1|1
	#define mid (l+r>>1)
	void build(int p,int l,int r){
		if(l==r){a[p]=g[id[l]];return;}
		build(ls,l,mid);
		build(rs,mid+1,r);
		a[p]=a[ls]*a[rs];
	}
	Matrix query(int p,int l,int r,int L,int R){
		if(l>=L&&r<=R)return a[p];
		if(R<=mid)return query(ls,l,mid,L,R);
		if(L>mid)return query(rs,mid+1,r,L,R);
		return query(ls,l,mid,L,R)*query(rs,mid+1,r,L,R);
	}
	void modify(int p,int l,int r,int x){
		if(l==r){a[p]=g[id[l]];return;}
		if(x<=mid)modify(ls,l,mid,x);
		else modify(rs,mid+1,r,x);
		a[p]=a[ls]*a[rs];
	}
	#undef ls
	#undef rs
	#undef mid
}seg;
void update(int u,int val){
	g[u].a[0][1]+=val;a[u]+=val;
	while(u){
		Matrix lst=seg.query(1,1,n,L[top[u]],R[top[u]]);
		seg.modify(1,1,n,L[u]);
		Matrix now=seg.query(1,1,n,L[top[u]],R[top[u]]);
		u=fa[top[u]];
		g[u].a[0][0]+=min(now.a[0][0],now.a[0][1])-min(lst.a[0][0],lst.a[0][1]);
	}
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++)scanf("%lld",a+i);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		add(u,v);add(v,u);
	}
	dfs(1);rdfs(1,1);seg.build(1,1,n);
	scanf("%d",&m);
	for(int i=1,u,val;i<=n;i++){
		while(op=getchar(),op!='Q'&&op!='C');
		scanf("%d",&u);
		if(op=='Q'){
			Matrix ans=seg.query(1,1,n,L[u],R[top[u]]);
			printf("%lld\n",min(ans.a[0][0],ans.a[0][1]));
		}
		else{scanf("%d",&val);update(u,val);}
	}
	return 0;
}

点击查看C题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5;const ll INF=1e18;
int n,m,p[N];ll f[N][2];char type[5];
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
struct Matrix{
	ll a[2][2];
	Matrix (){a[0][0]=a[0][1]=a[1][0]=a[1][1]=INF;}
	Matrix operator *(const Matrix&b)const{
		Matrix c;
		c.a[0][0]=min(a[0][0]+b.a[0][0],a[0][1]+b.a[1][0]);
		c.a[1][0]=min(a[1][0]+b.a[0][0],a[1][1]+b.a[1][0]);
		c.a[0][1]=min(a[0][0]+b.a[0][1],a[0][1]+b.a[1][1]);
		c.a[1][1]=min(a[1][0]+b.a[0][1],a[1][1]+b.a[1][1]);
		return c;
	}
}g[N];
int fa[N],top[N],siz[N],L[N],dfn,R[N],son[N],id[N];
void dfs(int u){
	f[u][1]=p[u];siz[u]=1;
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]){
			fa[v]=u;dfs(v);siz[u]+=siz[v];
			if(siz[v]>siz[son[u]])son[u]=v;
			f[u][0]+=f[v][1];
			f[u][1]+=min(f[v][0],f[v][1]);
		}
}
void rdfs(int u,int tp){
	g[u].a[0][0]=p[u];g[u].a[1][0]=0;
	L[u]=++dfn;id[dfn]=u;top[u]=tp;R[tp]=dfn;
	if(son[u])rdfs(son[u],tp);
	for(int i=head[u],v;i;i=nxt[i])
		if((v=ver[i])!=fa[u]&&v!=son[u]){
			rdfs(v,v);
			g[u].a[0][0]+=min(f[v][0],f[v][1]);
			g[u].a[1][0]+=f[v][1];
		}
	g[u].a[0][1]=g[u].a[0][0];
}
struct SegmentTree{
	Matrix a[N<<2];
	#define mid (l+r>>1)
	void build(int p,int l,int r){
		if(l==r){a[p]=g[id[l]];return;}
		build(p<<1,l,mid);
		build(p<<1|1,mid+1,r);
		a[p]=a[p<<1]*a[p<<1|1];
	}
	void modify(int p,int l,int r,int x){
		if(l==r){a[p]=g[id[l]];return;}
		if(x<=mid)modify(p<<1,l,mid,x);
		else modify(p<<1|1,mid+1,r,x);
		a[p]=a[p<<1]*a[p<<1|1];
	}
	Matrix query(int p,int l,int r,int L,int R){
		if(l>=L&&r<=R)return a[p];
		if(R<=mid)return query(p<<1,l,mid,L,R);
		if(L>mid)return query(p<<1|1,mid+1,r,L,R);
		return query(p<<1,l,mid,L,R)*query(p<<1|1,mid+1,r,L,R);
	}
	#undef mid
}seg;
void update(int u,int op,Matrix c){
	if(op==1)g[u].a[1][0]=INF;
	else if(op==2)g[u]=c;
	else g[u].a[0][0]=g[u].a[0][1]=INF;
	while(u){
		Matrix lst=seg.query(1,1,n,L[top[u]],R[top[u]]);
		seg.modify(1,1,n,L[u]);
		Matrix now=seg.query(1,1,n,L[top[u]],R[top[u]]);
		u=fa[top[u]];
		g[u].a[0][0]+=min(now.a[0][0],now.a[1][0])-min(lst.a[0][0],lst.a[1][0]);
		g[u].a[1][0]+=now.a[0][0]-lst.a[0][0];g[u].a[0][1]=g[u].a[0][0];
	}
}
int main(){
//	freopen("defense.in","r",stdin);
//	freopen("defense.out","w",stdout);
	scanf("%d%d",&n,&m);scanf("%s",type);
	for(int i=1;i<=n;i++)scanf("%d",p+i);
	for(int i=1,u,v;i<n;i++){
		scanf("%d%d",&u,&v);
		add(u,v);add(v,u);
	}
	dfs(1);rdfs(1,1);seg.build(1,1,n);
	for(int i=1,a,x,b,y;i<=m;i++){
		scanf("%d%d%d%d",&a,&x,&b,&y);
		Matrix tmpa=g[a];update(a,x,tmpa);
		Matrix tmpb=g[b];update(b,y,tmpb);
		Matrix ans=seg.query(1,1,n,1,R[1]);
		ll res=min(ans.a[0][0],ans.a[1][0]);
		printf("%lld\n",res>1e15?-1:res);
		update(b,2,tmpb);update(a,2,tmpa);
	}
	return 0;
}