洛谷 P5343 【XR-1】分块

发布时间:2023-03-22 21:11:52 作者:zrzring

https://www.luogu.com.cn/problem/P5343

给定两个可以取的块长集合(实际可用的块长是二者的交集),求把长度为 \(n\) 的序列从左到右划分成若干块的方案数(块长的排列顺序不同算作不同方案),答案对 \(10^9 + 7\) 取模

考虑 dp,类似于背包计数

\[f_0 = 1, \quad f_i = \sum_{j:\, a_j \le i} f_{i - a_j} \quad (i \ge 1)\]

发现 \(n\) 很大,但是 \(a_i\) 值域上界 \(x\) 很小,于是这个线性递推可以用矩阵快速幂优化到 \(O(x^3 \log n)\)

#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <algorithm>
#include <cstring>
#include <functional>

// 64-bit signed integer used for the (very large) target length.
// A real type alias instead of a macro: it is an ordinary type name handled by the compiler,
// not a textual substitution applied to every `i64` token in the file.
using i64 = long long;

/**
 * Reads the next decimal integer from standard input.
 *
 * Every non-digit character in front of the number is skipped; if any of the skipped
 * characters is '-', the result is negated. The first character after the digits is
 * consumed as well.
 *
 * @return the parsed value, or 0 when end of input is reached before any digit.
 */
inline int read() {
	bool sym = 0; int res = 0;
	// getchar() returns int so that EOF can be told apart from every real character.
	int ch = getchar();
	// Stop skipping at EOF: otherwise getchar() keeps returning EOF and the loop never ends.
	while (ch != EOF && (ch < '0' || ch > '9')) sym |= (ch == '-'), ch = getchar();
	while (ch >= '0' && ch <= '9') res = res * 10 + (ch - '0'), ch = getchar();
	return sym ? -res : res;
}

// m   : side length shared by every MATRIX; must be set before clear()/init()/operator* is used.
// mod : modulus applied to every product computed by MATRIX::operator*.
int m, mod = 1e9 + 7;

/**
 * Square m x m matrix of residues modulo `mod`.
 *
 * The dimension is not stored in the object: all members read the global `m`, so both
 * operands of a product must have been sized (via clear() or init()) for the current `m`.
 */
class MATRIX {
public:
	std::vector<std::vector<int>> mat;

	/** Resizes the matrix to m x m and sets every entry to 0. */
	void clear() {
		mat.assign(m, std::vector<int>(m, 0));
	}

	/** Resizes the matrix to m x m and makes it the identity matrix. */
	void init() {
		clear();
		for (int i = 0; i < m; i++) mat[i][i] = 1;
	}

	/**
	 * Matrix product (*this) * a with every entry reduced modulo `mod`.
	 *
	 * The operand is taken by const reference so no m x m copy is made per product, and the
	 * operator is const so it can be applied to const matrices.
	 * Entries of both operands are expected to lie in [0, mod).
	 */
	MATRIX operator * (const MATRIX &a) const {
		MATRIX b; b.clear();
		for (int i = 0; i < m; i++) {
			for (int k = 0; k < m; k++) {
				const long long lhs = mat[i][k];
				if (lhs == 0) continue; // a zero factor contributes nothing to row i
				for (int j = 0; j < m; j++) {
					// running sum < mod and product < mod^2, so this stays inside long long
					b.mat[i][j] = static_cast<int>((b.mat[i][j] + lhs * a.mat[k][j]) % mod);
				}
			}
		}
		return b;
	}
};

/**
 * Raises a square matrix to a non-negative power by binary exponentiation.
 *
 * @param a base matrix, sized for the current global dimension `m`
 * @param x exponent; values <= 0 yield the identity matrix
 * @return a^x with entries reduced as done by MATRIX::operator*
 */
MATRIX qpow(MATRIX a, i64 x) {
	MATRIX res; res.init();
	while (x > 0) {
		if (x & 1) res = res * a;
		x >>= 1;
		// Square only while more bits remain; the final squaring would never be used.
		if (x > 0) a = a * a;
	}
	return res;
}

/**
 * Reads the target length k followed by two lists of block lengths (each list is a count
 * followed by that many values), and prints f(k) modulo `mod`, where
 *   f(0) = 1,  f(i) = sum of f(i - len) over every usable length len <= i,
 * and a length is usable when it lies in [1, 100] and occurs in both lists.
 *
 * f(0..m-1) is computed directly (m = largest usable length); larger k is reached by
 * multiplying that row vector with the (k - m + 1)-th power of the m x m transition matrix.
 */
int main() {
	const int kMaxLen = 100; // only lengths 1..kMaxLen are ever considered
	i64 k; std::cin >> k;

	// liked[who][len] is true when list `who` contains block length `len`.
	bool liked[2][kMaxLen + 1] = {};
	for (int who = 0; who < 2; who++) {
		int cnt = read();
		for (int i = 0; i < cnt; i++) {
			int len = read();
			if (len >= 1 && len <= kMaxLen) liked[who][len] = true;
		}
	}

	// Lengths present in both lists, in increasing order.
	std::vector<int> lens;
	for (int len = 1; len <= kMaxLen; len++) {
		if (liked[0][len] && liked[1][len]) lens.push_back(len);
	}

	if (k < 0) {
		printf("%d\n", 0); return 0;
	}
	if (lens.empty()) {
		// No usable length: f(0) = 1 and f(k) = 0 for every k > 0.
		// Building a 0 x 0 matrix and indexing it would be out of bounds.
		printf("%d\n", k == 0 ? 1 : 0); return 0;
	}

	m = lens.back(); // matrix dimension = largest usable block length

	// Row 0 of `a` holds f(0..m-1).
	MATRIX a; a.clear(); a.mat[0][0] = 1;
	for (int i = 1; i < m; i++) {
		for (int len : lens) {
			if (len > i) break; // lens is sorted, so no later length fits either
			a.mat[0][i] += a.mat[0][i - len];
			if (a.mat[0][i] >= mod) a.mat[0][i] -= mod;
		}
	}
	if (k < m) {
		printf("%d\n", a.mat[0][k]); return 0;
	}

	// Transition matrix: maps (f(t), ..., f(t+m-1)) to (f(t+1), ..., f(t+m)).
	MATRIX b; b.clear();
	for (int len : lens) b.mat[m - len][m - 1] = 1;        // last column builds f(t+m)
	for (int i = 1; i <= m - 1; i++) b.mat[i][i - 1] = 1;  // the other columns shift left

	MATRIX res = a * qpow(b, k - m + 1);
	printf("%d", res.mat[0][m - 1]);
	return 0;
}