【大联盟】20230713 T1 方向矩阵(rect) 题解 CF1666A 【Admissible Map】

发布时间 2023-07-25 11:10:32作者: zhaohaikun

题目描述

here

题解

赛时得分:60/100。

想到了正解,但调不出来,就改写暴力了。。。

首先,我们把问题转化成每个点都入度为 \(1\)

我们考虑合法子串只有两种形式:

注意到 UD,要么同时出现,要么同时不出现,因为如果存在 U,就说明 U 所在这一行得到度数减少了,一定需要上一行 D 来弥补。

  1. 不存在 UD。答案形如 RLRL...RLRL,这是好统计的。

  2. 存在 UD。考虑第一个 U 一定会匹配第一个未被 LR 得到入度的点。这是因为,首先,由于有 UD,则一定会有未被 LR 得到入度的点,那这个点肯定会被 U 相连,因为如果被 D 相连,那这个 D 之前肯定会有点未被 LR 得到入度,因为边数 < 点数,与我们定义的第一个未被 LR 得到入度的点相矛盾。

    现在,我们就得到了宽度 \(L\)。然后,我们考虑哈希来判断是否合法(赛时想到的是 bitset 巨难写……)。

    然后,我们考虑按 \(L\) 根号分治:

    1. 对于 \(L\le\sqrt{n}\),由于不超过 \(\sqrt{n}\) 种,所以我们预处理出哈希值,然后使用 unordered_map 求答案。

    2. 对于 \(L>\sqrt{n}\),由于答案不超过 \(\sqrt{n}\),所以我们可以暴力往后跳,判断答案。

时间复杂度 \(O(n\sqrt{n})\)

代码

由于 CF 上 \(n\le 2\times 10^4\),模拟赛虽然 \(n\le 10^5\),但听说数据很水,所以直接写了个 \(>\sqrt{n}\) 的部分就摆了。

#include <bits/stdc++.h>
#define SZ(x) (int) x.size() - 1
#define ms(x, y) memset(x, y, sizeof x)
#define all(x) x.begin(), x.end()
#define F(i, x, y) for (int i = (x); i <= (y); ++i)
#define DF(i, x, y) for (int i = (x); i >= (y); --i)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
template <typename T> inline void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> inline void chkmin(T &x, T y) { x = min(x, y); }
template <typename T> inline void read(T &x) {
	x = 0; int f = 1; char ch = getchar();
	for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
	for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
	x *= f;
}
const int N = 1e5 + 10, base = 312349, MOD = 1000000021;
int n, lst, pw[N], sum[4][N], ss[N];
ll ans;
int power(int x, int y = MOD - 2) {
	int ans = 1;
	for (; y; x = (ll) x * x % MOD, y >>= 1)
		if (y & 1) ans = (ll) ans * x % MOD;
	return ans;
}
void add(int &x, int y) { if ((x += y) >= MOD) x -= MOD; }
signed main() {
	// freopen("rect.in", "r", stdin);
	// freopen("rect.out", "w", stdout);
	string st; cin >> st;
	n = st.size(); st = ' ' + st;
	int invbase = power(base);
	pw[0] = 1;
	F(i, 1, n) {
		pw[i] = (ll) pw[i - 1] * base % MOD;
		ss[i] = (ss[i - 1] + pw[i]) % MOD;
		F(j, 0, 3) sum[j][i] = sum[j][i - 1];
		if (st[i] == 'U') add(sum[0][i], pw[i]);
		if (st[i] == 'D') add(sum[1][i], pw[i]);
		if (st[i] == 'L') add(sum[2][i], pw[i]);
		if (st[i] == 'R') add(sum[3][i], pw[i]);
	}
	for (int i = 2; i <= n; i += 2) {
		if (st[i] == 'L' && st[i - 1] == 'R') lst++;
		else lst = 0;
		ans += lst;
	}
	lst = 0;
	for (int i = 3; i <= n; i += 2) {
		if (st[i] == 'L' && st[i - 1] == 'R') lst++;
		else lst = 0;
		ans += lst;
	}
	int pos = 1, pp = 1;
	F(i, 1, n) {
		chkmax(pos, i), chkmax(pp, i);
		while (pos <= n && st[pos] != 'U') pos++;
		int tp = - 1;
		if (st[i + 1] == 'L') {
			tp = pp;
			pp = i;
		}
		while (pp <= n && (st[pp + 1] == 'L' || (pp != i && st[pp - 1] == 'R'))) pp++;
		if (pos > n) break;
		if (pos == i) continue;
		if (pp >= pos) continue;
		int len = pos - pp;
		// if (len > B) {
			int inv = power(invbase, len), pw = power(base, len);
			for (int l = i, r = i + len - 1; r <= n; l += len, r += len) {
				if (st[l] == 'L' || st[r] == 'R') break;
				int val = (((ll) (sum[0][r] - sum[0][i - 1]) * inv + (ll) (sum[1][r] - sum[1][i - 1]) * pw + (ll) (sum[2][r] - sum[2][i - 1]) * invbase + (ll) (sum[3][r] - sum[3][i - 1]) * base) % MOD + MOD) % MOD;
				if (r >= pos && val == (ss[r] - ss[i - 1] + MOD) % MOD) ans++;
			}
		// }
		if (~tp) pp = tp;
	}
	cout << ans;
	return 0;
}