ucup nanjing 题解

发布时间 2023-12-03 11:45:10作者: Farmer_D

比赛链接

D

收获很大的一道题
首先考虑朴素的 \(dp\),令 \(f_{x,i}\)\(x\) 子树中的每一个叶子到 \(x\) 的距离都为 \(i\) 的最小代价
不难列出 \(dp\) 式子为:\(f_{x,i}=\min\limits_{i\in \{0,1\}}\{cost(u,i)+\sum\limits_{v\in son(u)}f_{v,x-i}\}\),其中 \(cost(u,i)\) 为把 \(u\) 变成颜色 \(0/1\) 的代价( \(black=1,red=0\)

可以发现 \(f_u\) 是下凸的
因为 \(cost_{u,0/1}\) 是下凸的,凸序列按位 \(\sum\) 仍是凸的,凸序列做 \(\min +\) 卷积也是凸的

考虑一个很妙的东西:维护下凸序列的差分序列(把正差分,负差分,\(0\) 差分分开存),是单调不降的
合并时好合并的,直接按位加即可
考虑如何添加进 \(u\) 的贡献
\(col_u=1\) 时,\(f_{v,i} \to f_{x,i}\) 需要多 \(1\) 的贡献,在凸壳上分析一下可得,在差分 \(<0\) 时,这样是优的,所以直接在负差分的后面加入 \(-1\) 即可
\(col_u=0\) 时,类似分析可得,在正差分前面加入 \(1\) 即可

来分析一下时间复杂度
显然,长度为 \(a,b\) 的凸序列合并,我们只需要保留前 \(\min(a,b)\)
考虑删除元素时的贡献,即 \(\max(a,b)\) 个元素会产生 \(a+b\) 的贡献,所以分给每个元素的平均贡献不超过 \(2\)
所以总贡献是 \(O(n)\) 级别的
时间复杂度为 \(O(n)\)

#include <bits/stdc++.h>
#define F(i,x,y) for(int i=(x);i<=(y);i++)
#define DF(i,x,y) for(int i=(x);i>=(y);i--)
#define ms(x,y) memset(x,y,sizeof(x))
#define SZ(x) (int)x.size()-1
#define pb push_back
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
typedef pair<int,int> pii;
template<typename T> void chkmax(T &x,T y){ x=max(x,y);}
template<typename T> void chkmin(T &x,T y){ x=min(x,y);}
inline int read(){
    int FF=0,RR=1;
    char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
    for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
    return FF*RR;
}
const int N=100100;
int ans[N];
char str[N];
vector<int> G[N];
struct Node{
    int f0,negtot,s0;
    vector<int> dif1,dif2;//dif1:负差分  dif2:正差分
}f[N];
Node merge(Node x,Node y){
    Node ret;ret.f0=x.f0+y.f0,ret.s0=0,ret.negtot=0;
    int len=min(x.dif1.size()+x.dif2.size()+x.s0,y.dif1.size()+y.dif2.size()+y.s0);
    vector<int> rec(len);
    int cur=0;
    for(int v:x.dif1){ if(cur>=len) break;rec[cur++]+=v;}
    cur=min(len,cur+x.s0);
    reverse(x.dif2.begin(),x.dif2.end());
    for(int v:x.dif2){ if(cur>=len) break;rec[cur++]+=v;}
    cur=0;
    for(int v:y.dif1){ if(cur>=len) break;rec[cur++]+=v;}
    cur=min(len,cur+y.s0);
    reverse(y.dif2.begin(),y.dif2.end());
    for(int v:y.dif2){ if(cur>=len) break;rec[cur++]+=v;}
    for(int v:rec){
        if(v<0) ret.dif1.pb(v),ret.negtot+=v;
        else if(v==0) ret.s0++;
        else ret.dif2.pb(v);
    }
    reverse(ret.dif2.begin(),ret.dif2.end());
    return ret;
}
void dfs(int u){
    if(!G[u].empty()){
        dfs(G[u].back()),swap(f[u],f[G[u].back()]),G[u].pop_back();
        for(int v:G[u]) dfs(v),f[u]=merge(f[u],f[v]);
    }
    else f[u].s0=0,f[u].negtot=f[u].f0=0,f[u].dif1.clear(),f[u].dif2.clear();
    if(str[u]=='1') f[u].f0++,f[u].negtot--,f[u].dif1.pb(-1);
    else f[u].dif2.pb(1);
    // cout<<f[u].f0<<" "<<f[u].negtot<<' '<<f[u].f0<<' '<<f[u].dif1.size()<<' '<<f[u].dif2.size()<<'\n';
    ans[u]=f[u].f0+f[u].negtot;
}
void work(){
    int n=read();scanf("%s",str+1);
    F(i,1,n) G[i].clear();
    F(i,2,n){ int fa=read();G[fa].pb(i);}
    dfs(1);
    F(i,1,n) printf("%d ",ans[i]);puts("");
}
int main(){
    int T=read();
    while(T--) work();
    fprintf(stderr,"%d ms\n",int(1e3*clock()/CLOCKS_PER_SEC));
    return 0;
}