「GYM103470G」Paimon's Tree

发布时间 2023-08-04 16:22:53作者: _kkio

树上区间dp。

由于dp转移跟左右端点有关,考虑怎样转移端点。

左右端点只有被染色了才能被转移,那就多记个两维,表示左右端点是否已经被染色就好了。

\(dp_{u,v,t,0/1,0/1}\) 表示左右端点 \(u\)\(v\) 当前已经染了 \(t\) 个点,左右端点染色四种情况的路径长度。

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=155,inf=1e18;
int n;
int f[maxn][maxn][maxn][4],a[maxn];
int fp[maxn][maxn];
int siz[maxn][maxn];
vector<int> G[maxn];
void dfs(int u,int F)
{
	siz[F][u]=1;
    for(int v:G[u])
        if(v!=fp[F][u])
        {
            fp[F][v]=u;
            dfs(v,F);
            siz[F][u]+=siz[F][v];
        }
}
inline void upd(int u,int v,int t,int k,int w)
{if(f[u][v][t][k]<w)f[u][v][t][k]=w;}
void solve()
{
    scanf("%lld",&n);n++;
    for(int i=1;i<n;i++)scanf("%lld",&a[i]);
    for(int i=1;i<=n;i++)G[i].clear();
    for(int i=1;i<n;i++)
    {
        static int u,v;scanf("%lld%lld",&u,&v);
        G[u].push_back(v);G[v].push_back(u);
    }
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)fp[i][j]=siz[i][j]=0;
    for(int i=1;i<=n;i++)dfs(i,i);
    for(int i=1;i<=n;i++)for(int j=1;j<=n;j++)for(int t=0;t<=n;t++)for(int k=0;k<=3;k++)f[i][j][t][k]=-inf;
	for(int i=1;i<=n;i++)f[i][i][0][3]=0;
	for(int t=0;t<n-1;t++)
		for(int k=3;k>=0;k--)
   			for(int u=1;u<=n;u++)
   				for(int v=1;v<=n;v++)
			   	{
					if(f[u][v][t][k]==-inf)continue;
					int fu=fp[v][u],fv=fp[u][v],w=f[u][v][t][k],g=(u==v)?0:n-siz[v][u]-siz[u][v]+__builtin_popcount(k)-1;
					if(t==n-1)continue;
					if(k==3)
					{
						for(int su:G[u])if(su!=fu)upd(su,v,t,1,w);
						for(int sv:G[v])if(sv!=fv)upd(u,sv,t,2,w);
						if(t+1<=g)upd(u,v,t+1,3,w);
					}
					if(k==2)
					{
						for(int su:G[u])if(su!=fu)upd(su,v,t,0,w);
						upd(u,v,t+1,3,w+a[t+1]);
						if(t+1<=g)upd(u,v,t+1,2,w);
					}
					if(k==1)
					{
						for(int sv:G[v])if(sv!=fv)upd(u,sv,t,0,w);
						upd(u,v,t+1,3,w+a[t+1]);
						if(t+1<=g)upd(u,v,t+1,1,w); 
					}
					if(k==0)
					{
						upd(u,v,t+1,1,w+a[t+1]);
						upd(u,v,t+1,2,w+a[t+1]);
						if(t+1<=g)upd(u,v,t+1,0,w);
					}
			   	}
	int ans=-1;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			ans=max(ans,f[i][j][n-1][3]);
    printf("%lld\n",ans);
    return;
}

signed main()
{
    int T;
    scanf("%lld",&T);
    while(T--)solve();
    return 0;
}