「解题报告」CF757G Can Bash Save the Day?

发布时间 2023-06-26 12:42:09作者: APJifengc

好好好。

求一个点到一个集合内点的距离和,这相当于在考虑若干条路径的长度,我们可以考虑使用点分治,建出点分树,这样每次查询时只需要对于这个点到根上的每个分治中心,找到所有经过这个中心的路径和即可。容易拆成每个点到分治中心的距离之和加上点数乘分治中心到查询点的距离,那么我们直接拿前缀和维护一下这个东西就行了。

带修我们可以直接上个数据结构维护,由于我们发现每次交换的是两个相邻的东西,如果我们做前缀和后,仅有一个位置是会改变的,那么我们考虑直接修改这一个位置即可。但是这样复杂度至少是 \(O(n \log^2 n)\),时间空间都不能接受。

能不能不对每个节点维护一个数据结构,而是直接对整个点分树进行维护?其实是可以的。我们首先把点分治改成边分治,因为点分树儿子太多不好处理。我们发现,边分树实际上就是一个二叉树结构,那么可以类比可持久化线段树来考虑,我们也可以实现一个可持久化边分树。当然由于我们要可持久化,我们就不能直接按照传统的跳父亲方法来修改查询,而是需要从根往下查询。这个我们可以先预处理出每个点到根的路径,然后修改与查询的时候按照这个路径走即可。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 400005;
int n, q;
vector<pair<int, int>> e[MAXN];
struct Graph {
#define forGraph(u, v) for (int i = fst[u], v = to[i]; i; i = nxt[i], v = to[i])
    int fst[MAXN], to[MAXN << 1], nxt[MAXN << 1], w[MAXN << 1], tot;
    int siz[MAXN];
    Graph() : tot(1) {}
    void add(int u, int v, int ww) { to[++tot] = v, nxt[tot] = fst[u], fst[u] = tot, w[tot] = ww; }
    void adde(int u, int v, int w) { add(u, v, w), add(v, u, w); }
    bool vis[MAXN << 1];
    int mn, ed, sz;
    void calcSize(int u, int pre) {
        siz[u] = 1;
        forGraph(u, v) if (v != pre && !vis[i]) {
            calcSize(v, u);
            siz[u] += siz[v];
            if (max(siz[v], sz - siz[v]) < mn) mn = max(siz[v], sz - siz[v]), ed = i;
        }
    }
    struct Node {
        int lc, rc;
        long long lcnt, rcnt, lsum, rsum;
    } t[MAXN * 33];
    int cnt;
    void calcRoot(int u, int s) {
        mn = INT_MAX, ed = 0, sz = s;
        calcSize(u, 0);
    }
    int lst[MAXN];
    int rt[MAXN];
    void dfs(int u, int pre, long long dis, int dir) {
        if (u <= n) {
            int now = ++cnt;
            if (!rt[u]) rt[u] = now;
            if (t[lst[u]].lc == -1) t[lst[u]].lc = now;
            else t[lst[u]].rc = now;
            if (!dir) t[now].lsum = dis, t[now].lcnt = 1, t[now].lc = -1;
            else t[now].rsum = dis, t[now].rcnt = 1, t[now].rc = -1;
            lst[u] = now;
        }
        forGraph(u, v) if (!vis[i] && v != pre) {
            dfs(v, u, dis + w[i], dir);
        }
    }
    void solve(int u, int s) {
        if (s == 1) {
            t[lst[u]].lc = t[lst[u]].rc = 0;
            return;
        }
        calcRoot(u, s);
        vis[ed] = vis[ed ^ 1] = 1;
        int x = to[ed ^ 1], y = to[ed];
        dfs(x, 0, 0, 0), dfs(y, 0, w[ed], 1);
        solve(x, s - siz[y]), solve(y, siz[y]);
    }
    void update(int &u, int d) {
        t[++cnt] = t[u], u = cnt;
        t[u].lsum += t[d].lsum, t[u].rsum += t[d].rsum, t[u].lcnt += t[d].lcnt, t[u].rcnt += t[d].rcnt;
        if (t[d].lcnt) update(t[u].lc, t[d].lc);
        if (t[d].rcnt) update(t[u].rc, t[d].rc);
    }
    long long query(int u, int x, int y) {
        if (t[u].lcnt) return query(t[u].lc, t[x].lc, t[y].lc) + (t[y].rsum - t[x].rsum) + (t[y].rcnt - t[x].rcnt) * t[u].lsum;
        if (t[u].rcnt) return query(t[u].rc, t[x].rc, t[y].rc) + (t[y].lsum - t[x].lsum) + (t[y].lcnt - t[x].lcnt) * t[u].rsum;
        return 0;
    }
} g;
int tot;
void dfs(int u, int pre) {
    int lst = 0;
    for (auto p : e[u]) if (p.first != pre) {
        int v = p.first, w = p.second;
        dfs(v, u);
        if (!lst) lst = u, g.adde(u, v, w);
        else g.adde(lst, ++tot, 0), g.adde(tot, v, w), lst = tot;
    }
}
int p[MAXN];
long long lstans;
int root[MAXN];
int main() {
    scanf("%d%d", &n, &q);
    tot = n;
    for (int i = 1; i <= n; i++) {
        scanf("%d", &p[i]);
    }
    for (int i = 1; i < n; i++) {
        int u, v, w; scanf("%d%d%d", &u, &v, &w);
        e[u].push_back({ v, w }), e[v].push_back({ u, w });
    }
    dfs(1, 0);
    g.solve(1, tot);
    for (int i = 1; i <= n; i++) {
        root[i] = root[i - 1];
        g.update(root[i], g.rt[p[i]]);
    }
    while (q--) {
        int op; scanf("%d", &op);
        if (op == 1) {
            int l, r, v; scanf("%d%d%d", &l, &r, &v);
            l ^= lstans % (1 << 30), r ^= lstans % (1 << 30), v ^= lstans % (1 << 30);
            printf("%lld\n", lstans = g.query(g.rt[v], root[l - 1], root[r]));
        } else {
            int x; scanf("%d", &x);
            x ^= lstans % (1 << 30);
            swap(p[x], p[x + 1]);
            root[x] = root[x - 1];
            g.update(root[x], g.rt[p[x]]);
        }
    }
    return 0;
}