「解题报告」CF768G The Winds of Winter

发布时间 2023-06-03 19:38:29作者: APJifengc

真的不难,为啥是 3300*。还是模拟赛 T3,很气啊,为什么不先看这个题。

首先贪心很容易发现一定是将当前子树大小最大的那棵树的某个子树移动到最小的那个树内。那么我们记移动的这个子树的大小为 \(x\),所有树中最小的树大小为 \(a\),最大的为 \(c\),次大的为 \(b\),那么我们就是在最小化 \(\max\{a + x, b, c - x\}\)。这东西显然当 \(x = \frac{c - a}{2}\) 的时候取得最小值,但是 \(x\) 并不是什么都能取的,必须是这棵树中的某个子树的大小。那么我们数据结构维护一下某个子树内的所有树大小,那么我们就只需要找前驱后继即可。具体维护就直接用线段树合并即可。

然后就没东西了。具体维护可能有点细节,懒得写了。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 100005;
int n;
vector<int> e[MAXN];
int siz[MAXN];
int rt;
void dfs1(int u, int pre) {
    siz[u] = 1;
    for (int v : e[u]) if (v != pre) {
        dfs1(v, u);
        siz[u] += siz[v];
    }
}
struct SegmentTree {
    struct Node {
        int cnt, lc, rc;
    } t[MAXN * 250];
    int tot;
    void pushUp(int p) {
        t[p].cnt = t[t[p].lc].cnt + t[t[p].rc].cnt;
    }
    void insert(int d, int &p, int l = 1, int r = n) {
        if (!p) p = ++tot, t[p].cnt = 1;
        else t[++tot] = t[p], p = tot, t[p].cnt++;
        if (l == r) return;
        int mid = (l + r) >> 1;
        if (d <= mid) insert(d, t[p].lc, l, mid);
        else insert(d, t[p].rc, mid + 1, r);
        pushUp(p);
    }
    int merge(int x, int y, int l = 1, int r = n) {
        if (!x || !y) return x + y;
        if (l == r) {
            int z = ++tot;
            t[z].cnt = t[x].cnt + t[y].cnt;
            return z;
        }
        int mid = (l + r) >> 1;
        int z = ++tot;
        t[z].lc = merge(t[x].lc, t[y].lc, l, mid);
        t[z].rc = merge(t[x].rc, t[y].rc, mid + 1, r);
        pushUp(z);
        return z;
    }
    int pre(int v, int p, int l, int r) {
        if (v < l) return 0;
        if (!t[p].cnt) return 0;
        if (r <= v) {
            if (l == r) return l;
            int mid = (l + r) >> 1;
            return pre(v, t[p].rc, mid + 1, r) ?: pre(v, t[p].lc, l, mid);
        }
        int mid = (l + r) >> 1;
        if (v <= mid) return pre(v, t[p].lc, l, mid);
        return pre(v, t[p].rc, mid + 1, r) ?: pre(v, t[p].lc, l, mid);
    }
    int pre(int v, int p1, int p2, int p3, int l, int r) {
        if (v < l) return 0;
        if (!(t[p1].cnt - t[p2].cnt - t[p3].cnt)) return 0;
        if (r <= v) {
            if (l == r) return l;
            int mid = (l + r) >> 1;
            return pre(v, t[p1].rc, t[p2].rc, t[p3].rc, mid + 1, r) ?: pre(v, t[p1].lc, t[p2].lc, t[p3].lc, l, mid);
        }
        int mid = (l + r) >> 1;
        if (v <= mid) return pre(v, t[p1].lc, t[p2].lc, t[p3].lc, l, mid);
        return pre(v, t[p1].rc, t[p2].rc, t[p3].rc, mid + 1, r) ?: pre(v, t[p1].lc, t[p2].lc, t[p3].lc, l, mid);
    }
    int suf(int v, int p, int l, int r) {
        if (v > r) return 0;
        if (!t[p].cnt) return 0;
        if (v <= l) {
            if (l == r) return l;
            int mid = (l + r) >> 1;
            return suf(v, t[p].lc, l, mid) ?: suf(v, t[p].rc, mid + 1, r);
        }
        int mid = (l + r) >> 1;
        if (v > mid) return suf(v, t[p].rc, mid + 1, r);
        return suf(v, t[p].lc, l, mid) ?: suf(v, t[p].rc, mid + 1, r);
    }
    int suf(int v, int p1, int p2, int p3, int l, int r) {
        if (v > r) return 0;
        if (!(t[p1].cnt - t[p2].cnt - t[p3].cnt)) return 0;
        if (v <= l) {
            if (l == r) return l;
            int mid = (l + r) >> 1;
            return suf(v, t[p1].lc, t[p2].lc, t[p3].lc, l, mid) ?: suf(v, t[p1].rc, t[p2].rc, t[p3].rc, mid + 1, r);
        }
        int mid = (l + r) >> 1;
        if (v > mid) return suf(v, t[p1].rc, t[p2].rc, t[p3].rc, mid + 1, r);
        return suf(v, t[p1].lc, t[p2].lc, t[p3].lc, l, mid) ?: suf(v, t[p1].rc, t[p2].rc, t[p3].rc, mid + 1, r);
    }
} st;
int root[MAXN];
void dfs2(int u, int pre) {
    st.insert(siz[u], root[u]);
    for (int v : e[u]) if (v != pre) {
        dfs2(v, u);
        root[u] = st.merge(root[u], root[v]);
    }
}
int ans[MAXN];
void dfs3(int u, int pre, int rt) {
    vector<pair<int, int>> x; 
    if (pre) x.push_back({ n - siz[u], 0 });
    for (int v : e[u]) if (v != pre) {
        x.push_back({ siz[v], v });
    }
    sort(x.begin(), x.end());
    if (x.size() == 1 || x.back().first == x[x.size() - 2].first) {
        ans[u] = x.back().first;
    } else {
        int a = x.front().first, b = x[x.size() - 2].first, c = x.back().first;
        int p = (c - a) / 2;
        ans[u] = c;
        if (x.back().second == 0) {
            int l = st.pre(p, root[::rt], root[u], rt, 1, n);
            int r = st.suf(p, root[::rt], root[u], rt, 1, n);
            int x = st.pre(p + siz[u], rt, 1, n);
            int y = st.suf(p + siz[u], rt, 1, n);
            if (l) ans[u] = min(ans[u], max({ a + l, b, c - l }));
            if (r) ans[u] = min(ans[u], max({ a + r, b, c - r }));
            if (x && x - siz[u] >= 0) ans[u] = min(ans[u], max({ a + (x - siz[u]), b, c - (x - siz[u]) }));
            if (y && y - siz[u] >= 0) ans[u] = min(ans[u], max({ a + (y - siz[u]), b, c - (y - siz[u]) }));
        } else {
            int l = st.pre(p, root[x.back().second], 1, n);
            int r = st.suf(p, root[x.back().second], 1, n);
            if (l) ans[u] = min(ans[u], max({ a + l, b, c - l }));
            if (r) ans[u] = min(ans[u], max({ a + r, b, c - r }));
        }
    }
    st.insert(siz[u], rt);
    for (int v : e[u]) if (v != pre) {
        dfs3(v, u, rt);
    }
}
int main() {
    scanf("%d", &n);
    if (n == 1) {
        printf("0\n");
        return 0;
    }
    for (int i = 1; i <= n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        if (!u || !v) rt = u | v;
        else e[u].push_back(v), e[v].push_back(u);
    }
    dfs1(rt, 0), dfs2(rt, 0);
    dfs3(rt, 0, 0);
    for (int i = 1; i <= n; i++) {
        printf("%d\n", ans[i]);
    }
    return 0;
}