AtCoder Beginner Contest 335 G Discrete Logarithm Problems

发布时间 2024-01-13 19:18:54作者: zltzlt

洛谷传送门

AtCoder 传送门

考虑若我们对于每个 \(a_i\) 求出来了使得 \(g^{b_i} \equiv a_i \pmod P\)\(b_i\)(其中 \(g\)\(P\) 的原根),那么 \(a_i^k \equiv a_j \pmod P\) 等价于 \(kb_i \equiv b_j \pmod{P - 1}\),有解的充要条件是 \(\gcd(b_i, P - 1) \mid b_j\)

显然我们不可能对于每个 \(a_i\) 都求出来 \(b_i\)。注意到我们只关心 \(c_i = \gcd(b_i, P - 1)\),而 \(c_i\) 为满足 \(a_i^{c_i} \equiv 1 \pmod P\) 的最小正整数。若求出 \(c_i\) 则等价于统计 \(c_i \mid c_j\) 的对数。于是问题变成求出 \(c_i\)

因为我们一定有 \(a_i^{P - 1} \equiv 1 \pmod P\),所以 \(c_i\) 一定为 \(P - 1\) 的因数。所以我们初始令 \(c_i = P - 1\),然后对 \(P - 1\) 分解质因数,依次让 \(c_i\) 试除 \(P - 1\) 的每个质因子,判断除完后是否还有 \(a_i^{c_i} \equiv 1 \pmod P\) 即可。这部分复杂度大概是 \(O(n \log^2 P)\) 的。

问题还剩下统计 \(c_i \mid c_j\) 的对数。因为 \(c_i\)\(P - 1\) 的因数,所以我们可以做一遍 Dirichlet 后缀和求出 \(f_x\) 表示满足 \(x \mid c_i\)\(i\) 的个数。最后遍历 \(c_i\) 统计即可。

总时间复杂度大概是 \(O(n \log^2 P + m \log m \log P)\),其中 \(m\)\(P - 1\) 因数个数。

code
// Problem: G - Discrete Logarithm Problems
// Contest: AtCoder - AtCoder Beginner Contest 335 (Sponsored by Mynavi)
// URL: https://atcoder.jp/contests/abc335/tasks/abc335_g
// Memory Limit: 1024 MB
// Time Limit: 5000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef __int128 lll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;

ll n, m, a[maxn], tot, T, c[maxn], tt, f[maxn];
pii b[maxn];

inline ll qpow(ll b, ll p, const ll &mod) {
	ll res = 1;
	while (p) {
		if (p & 1) {
			res = (lll)res * b % mod;
		}
		b = (lll)b * b % mod;
		p >>= 1;
	}
	return res;
}

void solve() {
	scanf("%lld%lld", &n, &m);
	ll x = m - 1;
	for (ll i = 1; i * i <= x; ++i) {
		if (x % i) {
			continue;
		}
		c[++tt] = i;
		if (i * i != x) {
			c[++tt] = x / i;
		}
	}
	sort(c + 1, c + tt + 1);
	for (ll i = 2; i * i <= x; ++i) {
		if (x % i == 0) {
			ll cnt = 0;
			while (x % i == 0) {
				x /= i;
				++cnt;
			}
			b[++tot] = mkp(i, cnt);
		}
	}
	if (x > 1) {
		b[++tot] = mkp(x, 1);
	}
	for (int i = 1; i <= n; ++i) {
		scanf("%lld", &a[i]);
		ll x = m - 1;
		for (int j = 1; j <= tot; ++j) {
			for (int _ = 0; _ < b[j].scd; ++_) {
				if (qpow(a[i], x / b[j].fst, m) == 1) {
					x /= b[j].fst;
				}
			}
		}
		a[i] = x;
		++f[lower_bound(c + 1, c + tt + 1, a[i]) - c];
	}
	for (int i = 1; i <= tot; ++i) {
		for (int j = tt; j; --j) {
			if (c[j] % b[i].fst) {
				continue;
			}
			ll x = lower_bound(c + 1, c + tt + 1, c[j] / b[i].fst) - c;
			f[x] += f[j];
		}
	}
	ll ans = 0;
	for (int i = 1; i <= n; ++i) {
		ans += f[lower_bound(c + 1, c + tt + 1, a[i]) - c];
	}
	printf("%lld\n", ans);
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}