CF1797E 线段树 + 倍增 题解

发布时间 2023-04-20 18:43:14作者: Kowenxrz

Preface

有趣的一道 ds,赛后不看题解做出来了。

Solution

首先有一个性质:\(\varphi(x)\) 经过 \(\mathcal{O}(\log x)\) 次迭代后变为 \(1\)

证明:

\(x\) 为奇数,\(\varphi(x)=x\sum_{i=1}^{k}\frac{p_i-1}{p_i}\)\(p_i\) 为奇数,所以 \(p_i-1\) 为偶数,我们就能得到 \(\varphi(x)\) 也为偶数。

\(x\) 为偶数,则 \(1\)\(x-1\) 的所有偶数都与 \(x\) 不互质,所以 \(\varphi(x)\) 最大为 \(\frac{x}{2}\)

这样每次下降 \(\frac{x}{2}\)\(\mathcal{O}(\log x)\) 次后变为 \(1\)

证毕。

这样我们就可以每次倍增找两个数 \(a,b\) 经过 \(\varphi(x)\) 的迭代操作后第一次 \(a=b\) 的迭代次数,相当于是建了一棵高为 \(\log V\) 的树,\(x\) 的父亲为 \(\varphi(x)\),然后在上面找 lca,复杂度为 \(\mathcal{O}(\log \log V)\)

那么我们直接建一棵线段树,维护区间 lca、最大深度与深度之和。

对于修改操作,直接对于最大深度不为 \(1\) 的区间暴力修改,这一部分总共是均摊 \(\mathcal{O}(n\log n\log V)\) 的,再加上 pushup 里修改 lca 的 \(\mathcal{O}(\log\log V)\),复杂度为 \(\mathcal{O}(n\log n\log V\log \log V)\)

对于询问操作,合并所有点的 lca 操作为 \(\mathcal{O}(\log n\log\log V)\),答案即为区间深度和减 \(r-l+1\) 乘区间 lca 的深度,复杂度总共为 \(\mathcal{O}(m\log n\log\log V)\)

所以时间复杂度即为 \(\mathcal{O}(n\log n \log V \log \log V+m\log n \log\log V)\)

其实这个时间复杂度可以进一步优化成 \(\mathcal{O}(n\log n \log V+m\log n \log\log V)\),但是原时间复杂度已经足够通过此题,所以就没写。

Code:

#include<bits/stdc++.h>
using namespace std;
//#define int long long
#define pai 3.141592653589793238462643383279502884197169399375105820
#define pb emplace_back
#define fi first
#define se second
#define ls id << 1
#define rs id << 1 | 1
#define pc putchar
#define fre freopen("x.in", "r", stdin), freopen("x.out", "w", stdout)
#define ios ios::sync_with_stdio(false), cin.tie(0), cout.tie(0)
void upmin(int &x, int y) {x = x < y ? x : y;}
void upmax(int &x, int y) {x = x > y ? x : y;}
typedef unsigned long long ull;
typedef pair<int, int> pii;
/* --------------- fast io --------------- */ // begin
//省略fast io部分
/* --------------- fast io --------------- */ // end
const int N = 1e5 + 5, maxn = 5e6 + 5, mod = 1e9 + 7, inf = 1e9;
int n, m, va[N], dep[maxn], anc[maxn][6];
int cnt, phi[maxn], p[N << 2]; bool vis[maxn];
vector<int> e[maxn];  
set<pii> s;
void euler()
{
    phi[1] = 1;
    for (int i = 2; i < maxn; i++)
    {
        if (!vis[i]) p[++cnt] = i, phi[i] = i - 1;
        for (int j = 1; j <= cnt && 1ll * i * p[j] < maxn; j++)
        {
            vis[i * p[j]] = 1;
            if (!(i % p[j]))
            {
                phi[i * p[j]] = phi[i] * p[j];
                break;
            }
            phi[i * p[j]] = phi[i] * (p[j] - 1);
        }
    }
}
void dfs(int u, int fa)
{
	for (auto v : e[u])
	{
		if (v == fa) continue;
		dep[v] = dep[u] + 1, anc[v][0] = u;
		dfs(v, u);
	}
}
void built()
{
    for (int i = 1; i <= n; i++) 
    {
        int x = va[i];
        while (x != phi[x] && s.find(pii(x, phi[x])) == s.end()) 
        {
            e[x].pb(phi[x]), e[phi[x]].pb(x);
            s.insert(pii(x, phi[x])), x = phi[x];
        }
    }
	dep[1] = 1;
	dfs(1, 0);
	for (int j = 1; j <= 5; j++)
		for (int i = 1; i < maxn; i++)
			anc[i][j] = anc[anc[i][j - 1]][j - 1];
}
int LCA(int u, int v)
{
	if (dep[u] < dep[v]) swap(u, v);
	for (int i = 5; i >= 0; i--)
		if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
	if (u == v) return u;
	for (int i = 5; i >= 0; i--)
		if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
	return anc[u][0];
}
struct node {int lca, mx, sum;};
struct segt
{
    node t[N << 2];
    #define ls id << 1
    #define rs id << 1 | 1
    void pushup(int id)
    {
        t[id].lca = LCA(t[ls].lca, t[rs].lca);
        t[id].mx = max(t[ls].mx, t[rs].mx);
        t[id].sum = t[ls].sum + t[rs].sum;
    }
    void build(int id, int l, int r)
    {
        if (l == r)
        {
            t[id].lca = va[l];
            t[id].sum = t[id].mx = dep[va[l]];
            return ;
        }
        int mid = (l + r) >> 1;
        build(ls, l, mid), build(rs, mid + 1, r);
        pushup(id);
    }
    void update(int id, int l, int r, int a, int b)
    {
        if (t[id].mx == 1) return ;
        if (l == r) 
        {
            t[id].lca = phi[t[id].lca];
            t[id].sum = t[id].mx = dep[t[id].lca];
            return ;
        }
        int mid = (l + r) >> 1;
        if (a <= mid) update(ls, l, mid, a, b);
        if (b > mid) update(rs, mid + 1, r, a, b);
        pushup(id);
    }
    pii query(int id, int l, int r, int a, int b)
    {
        if (a <= l && b >= r) return pii(t[id].lca, t[id].sum);
        int mid = (l + r) >> 1;
        pii res = pii(0, 0);
        if (a <= mid) res = query(ls, l, mid, a, b);
        if (b > mid)
            if (res == pii(0, 0)) res = query(rs, mid + 1, r, a, b);
            else 
            {
                pii x = query(rs, mid + 1, r, a, b);
                res.fi = LCA(res.fi, x.fi);
                res.se = res.se + x.se;
            }
        return res;
    }
    #undef ls
    #undef rs
}t;
signed main()
{
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> va[i];
    euler(), built();
    t.build(1, 1, n);
    while (m--)
    {
        int op, l, r; cin >> op >> l >> r;
        if (op == 1) t.update(1, 1, n, l, r);
        else
        {
            pii x = t.query(1, 1, n, l, r);
            cout << x.se - (r - l + 1) * dep[x.fi] << endl; 
        }
    }
    return 0;
}