浅谈莫队

发布时间 2024-01-01 14:06:41作者: 单南松

莫队

[SDOI2009] HH的项链

这道题是卡莫队的,但是确实练习莫队的好题。

首先想一下暴力:
直接暴力枚举询问,然后再枚举区间,这样是O(n^2)的;

想一下优化:
如果说询问是按照 左端点递增 && 右端点递增 的;
那么我们就可以离线排序,用线性的时间扫过去所有询问,用桶记录一下就行,同时记录答案;

但是可能没有这种好情况...可能是两个端点不能同时递增的,最多保证一个递增。
那么在这种情况下要怎么优化呢?

我们可以折中一下:(莫队思想:离线 + 分块 + 双指针)
把询问的左端点按照分块的思想分成sqrt(n)块,然后每一块里面的询问[l,r]按照右端点递增排序!!!

!!!这里我们暂停一下!!!理清楚一个关键点:
在块中:我们询问区间右端点是递增的,但是我们区间左端点不是递增的!!!
在块间:则相反,左端点是递增的,而右端点不是递增的!!!
(正因为这样才需要指针的移动 和 回滚,才有了下面的奇偶优化!)

然后我们定义两个指针i(向r靠齐),j(向l靠齐)
然后开始遍历排好序的询问,每个询问[l,r]我们的i,j就移动,分别向l,r靠齐;
同时用一个桶记录一下数字出现的个数,并更新答案就行!
#include <bits/stdc++.h>

#define int unsigned
#define rint register int
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;
const int M = 3e6 + 5;

struct node
{
    int id, l, r;
} q[N];

int n, m, len = 1;
int a[N], ans[M];
int cnt[M];

int get(int i){return i / len;}

bool cmp(node a, node b)
{
    int l = get(a.l), r = get(b.l);
    if (l != r) return l < r;
    return a.r > b.r;
}

void add(int w, int &res)
{
    if (++cnt[w] == 1) res++;
}

void del(int w, int &res)
{
    if (--cnt[w] == 0) res--;
}

int read() 
{
    register int x = 0,f = 1;
	register char ch;
    ch = getchar();
    while(ch > '9' || ch < '0'){if(ch == '-') f = -f;ch = getchar();}
    while(ch <= '9' && ch >= '0'){x = x * 10 + ch - 48;ch = getchar();}
    return x * f;
}


signed main()
{
    n = read();
    
    for (rint i = 1; i <= n; i++)
    {
		a[i] = read();
	}
    
    m = read();
    unsigned lll = 1;
    len = max(lll, (int)sqrt((double)n * n / m));
    
	for (rint i = 1; i <= m; i++)
    {
        int l = read(), r = read();
        q[i] = {i, l, r};
    }
    
    sort(q + 1, q + m + 1, cmp);
    
    int res = 0;
	for (rint k = 1, i = 1, j = 0; k <= m; k++)
    {
        int id = q[k].id, l = q[k].l, r = q[k].r;
        while (j < r) add(a[++j], res);
        while (j > r) del(a[j--], res);
        while (i < l) del(a[i++], res);
        while (i > l) add(a[--i], res);
        ans[id] = res;
    }
    
    for (rint i = 1; i <= m; i++)
    {
		cout << ans[i] << endl;
	}
        
    return 0;
}

[国家集训队] 数颜色

由于增加修改操作,考虑给每一个修改操作按先后顺序编号

将一维改为二维,纵坐标为时间 $time$ ,横坐标为位置 $l,r$

当时间为 $k$ 时,表示已经经过了前 $k$ 个修改操作

若序列长度为 $n$ ,分块后每段长度为 $a$ ,共有 $\frac{n}{a}$ 块, $m$ 次询问, $t$ 次修改(根据数据范围,取合适的 $len$ )

$i$ 表示左指针, $j$ 表示右指针, $time$ 表示时间戳

排序标准为:
1. 比较左端点所在的块,从左到右排序
2. 若左端点所在块相同,则比较右端点所在块,从从左到右排序
3. 若右端点所在块也相同,则比较时间 $times$ ,从左到右排序

指针 $time$ 移动次数: $\frac{n^2t}{a^2}$
指针 $i$ 移动次数: $am+n$
指针 $j$ 移动次数: $am+\frac{n^2}{a}$

#include <bits/stdc++.h>

#define rint register int
#define int long long
#define endl '\n'

using namespace std;

const int N = 1e6 + 5;
const int M = 1e7 + 5;

int n, m;
int times, idx, len;
int a[N], ans[N], cnt[M];

struct node
{
    int l, r, t, id;
} q[N];

struct Modify
{
    int p, c;
} f[N];

int get(int i){return i / len;}

bool cmp(node a, node b)
{
    int al = get(a.l), bl = get(b.l);
    int ar = get(a.r), br = get(b.r);
    if (al != bl) return al < bl;
    if (ar != br) return ar < br;
    return a.t < b.t;
}

void add(int w, int &res)
{
    if (++cnt[w] == 1) res++;
}

void del(int w, int &res)
{
    if (--cnt[w] == 0) res--;
}

signed main()
{
    cin >> n >> m;
    
    for (rint i = 1; i <= n; i++)
    {
		cin >> a[i];
	}
        
    for (rint i = 1; i <= m; i++)
    {
        char op[2];
        int a, b;
        scanf("%s%lld%lld", op, &a, &b);
        if (*op == 'Q') idx++, q[idx] = {a, b, times, idx};
        else f[++times] = {a, b};
    }
    
    len = cbrt((double)n * max(times, 1ll)) + 1;
    
    sort(q + 1, q + idx + 1, cmp);
    
    for (int k = 1, i = 1, j = 0, t = 0, res = 0; k <= idx; k++)
    {
        int l = q[k].l, r = q[k].r, tim = q[k].t, id = q[k].id;
        while (t < tim)
        {
            t++;
            if (f[t].p >= i && f[t].p <= j)
            {
                add(f[t].c, res);
                del(a[f[t].p], res);
            }
            swap(a[f[t].p], f[t].c);
        }
        while (t > tim)
        {
            if (f[t].p >= i && f[t].p <= j)
            {
                add(f[t].c, res);
                del(a[f[t].p], res);
            }
            swap(a[f[t].p], f[t].c);
            t--;
        }
        while (i < l) del(a[i++], res);
        while (i > l) add(a[--i], res);
        while (j < r) add(a[++j], res);
        while (j > r) del(a[j--], res);
        ans[id] = res;
    }
    
    for (rint i = 1; i <= idx; i++)
    {
		cout << ans[i] << endl;
	}
        
    return 0;
}

AcWing 2523. 历史研究

回滚莫队又称不删除莫队

用于维护一些不删除属性的操作

例如最大值,加入一个数后只需取一次max,删除一个数却很难维护

设序列长度为$n$,每块长度为$a$,共有$\frac{n}{a}$块,$m$次询问

  1. 循环$1\leq i \leq \frac{n}{a}$,找到所有左端点在第$i$块的询问$l$,$r$
  2. 若$r$也在第$i$块,那么就暴力求,时间复杂度$O(am)$
  3. 否则,右端点用指针$j$维护,左端点每次回到第i块的右端,暴力求,时间复杂度$O(am+\frac{n^2}{a})$
  4. #include <bits/stdc++.h>
    
    #define rint register int
    #define int long long
    #define endl '\n'
    
    using namespace std;
    
    const int N = 1e6 + 5;
    
    int n, m, len;
    int w[N], cnt[N];
    int ans[N];
    vector<int> nums;
    
    struct node
    {
        int id, l, r;
    } q[N];
    
    int get(int i){return i / len;}
    
    bool cmp(node a, node b)
    {
        int l = get(a.l), r = get(b.l);
        if (l != r) return l < r;
        return a.r < b.r;
    }
    
    void add(int x, int &res)
    {
        cnt[x]++;
        res = max(res, nums[x] * cnt[x]);
    }
    
    signed main()
    {
        cin >> n >> m;
        
        len = sqrt(n);
    
        for (rint i = 1; i <= n; i++)
        {
    		cin >> w[i];
    		nums.push_back(w[i]);
    	}
            
        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        
        for (rint i = 1; i <= n; i++)
        {
            w[i] = lower_bound(nums.begin(), nums.end(), w[i]) - nums.begin();		
    	}
    
        for (rint i = 0; i < m; i++)
        {
            int l, r;
    		cin >> l >> r;
            q[i] = {i, l, r};
        }
    
        sort(q, q + m, cmp);
    
        for (rint x = 0; x < m;)
        {
            int y = x;
            while (y < m && get(q[x].l) == get(q[y].l)) y++;
            int right = get(q[x].l) * len + len - 1;
            while (x < y && q[x].r <= right)
            {
                int res = 0;
                int id = q[x].id, l = q[x].l, r = q[x].r;
                for (rint i = l; i <= r; i++) add(w[i], res);
                ans[id] = res;
                for (rint i = l; i <= r; i++) cnt[w[i]]--;
                x++;
            }
            int res = 0;
            int i = right + 1, j = right;
            while (x < y)
            {
                int id = q[x].id, l = q[x].l, r = q[x].r;
                while (j < r) add(w[++j], res);
                int backup = res;
                while (i > l) add(w[--i], res);
                ans[id] = res;
                while (i < right + 1) cnt[w[i++]]--;
                res = backup;
                x++;
            }
            memset(cnt, 0, sizeof cnt);
        }
    
        for (rint i = 0; i < m; i++)
        {
    		cout << ans[i] << endl;
    	}
    
        return 0;
    }  
    

    SP10707

    先将整棵树的欧拉序求出来,记录每个点第一次出现的位置 $first[i]$ 和最后一次出现的位置 $last[i]$,然后观察树中的路径 $[l,r](first[l]<first[r])$ 可以发现两种情况:

    1. 如果路径是一条从上往下的直链,则其所有点对应欧拉序中 $first[l]$ 到 $first[r]$ 中出现一次的点
    2. 否则其所有点对应欧拉序中 $first[l]$ 到 $last[r]$ 中出现一次的点加上 $lca(l,r)$

    理解一下会发现的确这样,然后问题就转化为普通莫队问题了

      #include <bits/stdc++.h>
      
      #define rint register int
      #define int long long
      #define endl '\n'
      #define queue queue__
      
      using namespace std;
      
      const int N = 2e6 + 5;
      const int M = 1e7 + 5;
      
      int n, m, len;
      int h[N], e[M], ne[M], idx;
      int w[N], seq[N], first[N], last[N], top;
      int queue[N], dep[N], fa[N][25], cnt[N];
      int ans[N];
      bool st[N];
      vector<int> nums;
      
      struct node
      {
          int id, l, r, p;
      } q[N];
      
      void add(int a, int b)
      {
          e[++idx] = b, ne[idx] = h[a], h[a] = idx;
      }
      
      int get(int i){return i / len;}
      
      bool cmp(node a, node b)
      {
          int l = get(a.l), r = get(b.l);
          if (l != r) return l < r;
          return a.r < b.r;
      }
      
      void dfs(int x, int father)
      {
          seq[++top] = x;
          first[x] = top;
          for (rint i = h[x]; i; i = ne[i])
          {
              int y = e[i];
              if (y != father)
              {
                  dfs(y, x);			
      		}
          }
          seq[++top] = x;
          last[x] = top;
      }
      
      void bfs(int s)
      {
          memset(dep, -1, sizeof dep);
          int hh = 0, tt = 0;
          queue[0] = s, dep[s] = 0;
          while (hh <= tt)
          {
              int x = queue[hh++];
              for (rint i = h[x]; i; i = ne[i])
              {
                  int y = e[i];
                  if (dep[y] == -1)
                  {
                      dep[y] = dep[x] + 1;
                      fa[y][0] = x;
      				queue[++tt] = y;
                      for (rint j = 1; j <= 20; j++)
                      {
                          fa[y][j] = fa[fa[y][j - 1]][j - 1];					
      				}
                  }
              }
          }
      }
      
      int lca(int a, int b)
      {
          if (dep[a] < dep[b]) swap(a, b);
          for (rint i = 20; i >= 0; i--)
              if (dep[fa[a][i]] >= dep[b])
                  a = fa[a][i];
          if (a == b) return a;
          for (rint i = 20; i >= 0; i--)
              if (fa[a][i] != fa[b][i])
                  a = fa[a][i], b = fa[b][i];
          return fa[a][0];
      }
      
      void change(int x, int &res)
      {
          st[x] ^= 1;
          if (st[x] == 0)
          {
              cnt[w[x]]--;
              if (!cnt[w[x]]) res--;
          }
          else
          {
              if (!cnt[w[x]]) res++;
              cnt[w[x]]++;
          }
      }
      
      signed main()
      {
          cin >> n >> m;
          
          for (rint i = 1; i <= n; i++)
          {
      		cin >> w[i];
      		nums.push_back(w[i]);
      	}
      	
          sort(nums.begin(), nums.end());
          nums.erase(unique(nums.begin(), nums.end()), nums.end());
          
          for (rint i = 1; i <= n; i++)
          {
              w[i] = lower_bound(nums.begin(), nums.end(), w[i]) - nums.begin();		
      	}
      
          for (rint i = 1; i < n; i++)
          {
              int a, b;
              cin >> a >> b;
              add(a, b);
      		add(b, a);
          }
          
          bfs(1);
      	dfs(1, 1);
          
      	len = sqrt(top);
      	
          for (rint i = 1; i <= m; i++)
          {
              int x, y;
              cin >> x >> y;
              if (first[x] > first[y]) swap(x, y);
              int p = lca(x, y);
              if (x == p) q[i] = {i, first[x], first[y]};
              else q[i] = {i, last[x], first[y], p};
          }
      
          sort(q + 1, q + m + 1, cmp);
      
          for (rint k = 1, i = 1, j = 0, res = 0; k <= m; k++)
          {
              int id = q[k].id, l = q[k].l, r = q[k].r, p = q[k].p;
              while (i < l) change(seq[i++], res);
              while (i > l) change(seq[--i], res);
              while (j < r) change(seq[++j], res);
              while (j > r) change(seq[j--], res);
              if (p) change(p, res);
              ans[id] = res;
              if (p) change(p, res);
          }
      
          for (rint i = 1; i <= m; i++)
          {
      		cout << ans[i] << endl;
      	}
      
          return 0;
      }  
      

      AcWing 2535. 二次离线莫队

      本题有若干个询问,每个询问都要求出某个区间中异或和在二进制表示中有 \(k\)\(1\) 的数对个数。

      我们规定,如果某两个数的异或和在二进制表示中有 \(k\)\(1\),我们就称这两个数是配对的,因此每个询问就变成了求某个区间中有多少数对是配对的。

      本题需要用到二次离线莫队来做,而二次离线莫队就是一共需要离线两次来做,在做莫队时,我们每次都会对一段区间去查询一个数,然后我们都会去对两个端点进行移动,然后在新的维护区间中去求我们的询问。

      而对于二次离线莫队,就是当我们每次更新完维护区间之后,对于区间的询问很难算,所以我们需要在每次更新完维护区间之后,再把当前询问单独拎出来再重新离线求当前询问的值。二次离线算法思维难度一般不高,但是代码实现中的细节非常多,且每道题都需要重新思考,可以说是非常恶心。

      要想使用莫队算法来求,那么每次我们都需要从上一个询问区间 \([l, r]\) 的信息快速得到当前询问区间 \([L, R]\) 的信息。以右端点 \(r\) 为例,当 \(r\) 右移后,我们需要将 \(w_{r+1}\) 加入到维护区间中,那么我们就需要考虑将 \(w_{r+1}\) 加入到维护区间中后,维护区间的信息该怎么去维护。当我们将 \(w_{r+1}\) 加入后,需要求一下它对于配对的数量有什么样的影响,显然配对的数量只会增加,至于增加的数量,就是 \([l, r]\) 中和 \(w_{r+1}\) 配对的数的个数,这一步可以用前缀和来求。

      设 $S_i$ 表示 $w_1 \sim w_i$ 中有多少个数和 $w_{r+1}$ 配对,此时对于 $[l, r]$ 中和 $w_{r+1}$ 配对的数的个数就是 $S_r-S_{l-1}$。接下来就是要求 $S_r$ 和 $S_{l-1}$,对于 $S_r$,其实是问 $w_1 \sim w_r$ 中有多少个数和 $w_{r+1}$ 配对,可以发现要询问的数就是区间的后一个数,而 $S_{l-1}$ 则没有这么好的性质,$w_{l-1}$ 和 $w_{r+1}$ 之间的距离是非常随机的,毫无规律可循,因此 $S_r$ 和 $S_{l-1}$ 其实是两类询问,分两种情况来考虑。

      首先对于 \(S_r\),由于这一部分是非常有规律的,所以可以提前预处理,设 \(f_i\) 表示 \(w_1 \sim w_i\) 中与 \(w_{i+1}\) 配对的数个数,而 \(S_r\) 显然就是 \(f_i\),因此我们需要快速的预处理出 \(f_i\),可以用一个 \(g_x\) 表示前 \(i\) 个数中有多少个数与 \(x\) 配对。当我们把 \(g_x\) 预处理出来,则 \(f_i = g_{w_{i+1}}\)

      因此我们现在就是需要求出 \(i\) 阶段的 \(g_x\),假设当前 \(g_x\) 表示前 \(i-1\) 个数中有多少个数与 \(x\) 配对,这里我们可以先预处理出 \(0 \sim 2^{14}-1\) 中所有有 \(k\)\(1\) 的数 \(y_i\),我们想从前 \(i-1\) 个数的 \(g_x\) 变成前 \(i\) 个数的 \(g_x\),相当于是加入了一个新的数 \(w_i\),此时我们只要找出所有和 \(w_i\) 配对的 \(x\),令 \(g_x+1\),最终就能得到前 \(i\) 个数的 \(g_x\),而我们要找的所有 \(x\) 则必须满足 \(w_i~xor~x=y_i\),这个条件等价于 \(x=w_i~xor~y_i\),因此我们可以枚举所有不同的 \(y_i\),通过 \(y_i~xor~w_i\) 计算出所有的 \(x\),因为 \(y_i\) 不同,所以得出的 \(x\) 也不同。

      综上所述,对于前 \(i-1\) 个数的 \(g_x\),我们枚举所有 \(y_i\),令 \(g_{y_i~xor~w_i}+1\),最终就能得到前 \(i\) 个数的 \(g_x\),然后再用 \(g_x\) 计算 \(f_i\) 即可。这样我们就能用 \(g_x\) 作为辅助递推得出所有的 \(f_i\)。这一部分预处理一共只需要做一次,由于 \(k\) 比较小,\(y_i\) 最多只有三千多个,因此预处理的计算量最多只有三千多万。

      接下来我们需要想办法求出 $S_{l-1}$,$S_r$ 通过我们刚才的分析,我们可以在做莫队的过程中在线求出来,但是 $S_{l-1}$ 并不能马上求出来,因此我们只能先将所有要求 $S_{l-1}$ 的问题先找出来,然后我们再离线把所有要求的 $S_{l-1}$ 求出来,最后我们才能求出 $S_r - S_{l-1}$。

      要求 \(S_{l-1}\),其实就是求 \(w_1 \sim w_{l-1}\) 中有多少个数是和 \(w_{r+1}\) 配对的,以此类推,在从 \(r\) 移动到 \(R\) 的过程中,我们会将 \(r+1,r+2,&hellip;,R\) 都加入到维护区间中,因此对于 \(\forall x \in [r+1,R]\),都要求出 \(w_1 \sim w_{l-1}\) 中和 \(w_x\) 配对的数的个数,可以发现这些问题都是问某个固定前缀中,某个区间的每个数和这个前缀中有多少个数配对。我们可以将这些询问全部找出来,然后按照从前往后的顺序计算所有询问,我们先算一下所有前缀是 \(1\) 的询问,再算一下所有前缀是 \(2\) 的询问,以此类推,直到我们算完所有前缀是 \(l-1\) 的询问后,我们就将所有的询问都处理完了。

      由于我们是从前往后做所有询问,因此每一次前缀中只会增加一个数,因此这里我们同样可以用一个 \(g_x\) 数组来作为辅助,表示的内容和上面相同,同样表示前 \(i\) 个数中与 \(x\) 配对的数的个数,因此对于 \(\forall x \in [r+1,R]\),我们想求的 \(w_1 \sim w_{l-1}\) 中和 \(x\) 配对的数的个数恰好就是 \(g_x\),直接按照上面更新 \(g_x\) 的思路依次往后求即可。而这一部分的询问数量应该取决于两个指针移动的次数,这个在基础莫队中就已经证明过是 \(O(\sqrt{n})\) 级别的,因此我们就能用一个 \(O(n\sqrt{n})\) 的离线做法求出所有的 \(S_{l-1}\),然后就能把这部分在莫队中无法解决的问题统一计算出来。

      到此我们就能将前一个询问到当前询问的增量 \(S_r-S_{l-1}\) 求出来,但是这并不是当前询问的答案,如果我们想求某一个询问的结果的话,还需要将前面求出来的所有增量累加成前缀和才是最终答案。

      注意,上面我们推导了 \(r\) 向右移动到 \(R\) 这一种情况,实际上两个指针一共有四种情况,而其他三种情况都按照上面同样的形式去分析即可,代码实现时需要根据每种情况的区别做一些更细致的处理,这里就不过多赘述,直接体现在代码中。

      #include <bits/stdc++.h>
      
      #define rint register int
      #define int long long
      #define endl '\n'
      
      using namespace std;
      
      const int N = 1e5 + 5;
      
      int n, m, k, len;
      int w[N], g[N], f[N];
      int ans[N];
      struct node
      {
          int id, l, r, t;
          int res;
      } q[N];
      vector<node> range[N];
      
      int get(int i){return i / len;}
      
      bool cmp(node a, node b)
      {
          int l = get(a.l), r = get(b.l);
          if (l != r) return l < r;
          return a.r < b.r;
      }
      
      bool count(int i)
      {
          int res = 0;
          for (rint j = 0; j < 14; j++)
              if (i >> j & 1)
                  res++;
          return res == k;
      }
      
      signed main()
      {
          cin >> n >> m >> k;
          
          for (rint i = 1; i <= n; i++)
          {
      		cin >> w[i];
      	}
      	
          for (rint i = 1; i <= m; i++)
          {
              int l, r;
              cin >> l >> r;
              q[i] = {i, l, r};
          }
      
          vector<int> nums;
          for (rint i = 0; i < (1 << 14); i++)
          {
              if (count(i))
              {
      			nums.push_back(i);
      		}
      	}
                  
          for (rint i = 1; i <= n; i++)
          {
              for (auto y : nums) g[w[i] ^ y]++;
              f[i] = g[w[i + 1]];
          }
      
          len = sqrt(n);
          sort(q + 1, q + m + 1, cmp);
      
          for (rint i = 1, L = 1, R = 0; i <= m; i++)
          {
              int l = q[i].l, r = q[i].r;
              
              if (R < r) range[L - 1].push_back({i, R + 1, r, -1});
              while (R < r) q[i].res += f[R++];
              
              if (R > r) range[L - 1].push_back({i, r + 1, R, 1});
              while (R > r) q[i].res -= f[--R];
              
              if (L < l) range[R].push_back({i, L, l - 1, -1});
              while (L < l) q[i].res += f[L - 1] + !k, L++;
              
              if (L > l) range[R].push_back({i, l, L - 1, 1});
              while (L > l) q[i].res -= f[L - 2] + !k, L--;
          }
      
          memset(g, 0, sizeof g);
          
          for (rint i = 1; i <= n; i++)
          {
              for (auto y : nums) g[w[i] ^ y]++;
              for (auto &rg : range[i])
              {
                  int id = rg.id, l = rg.l, r = rg.r, t = rg.t;
                  for (rint x = l; x <= r; x++)
                  {
                      q[id].res += t * g[w[x]];				
      			}
              }
          }
      
          for (rint i = 2; i <= m; i++)
          {
              q[i].res += q[i - 1].res;		
      	}
      
          for (rint i = 1; i <= m; i++)
          {
              ans[q[i].id] = q[i].res;
      	}
      	
          for (rint i = 1; i <= m; i++)
          {
      		cout << ans[i] << endl;
      	}
      
          return 0;
      }