https://www.luogu.com.cn/problem/P5343
给定两个块长集合,可以取的块长是同时出现在两个集合中的值;求把长度为 \(n\) 的块依次分成若干小块(顺序不同算不同方案)共有几种分块的方案,答案对 \(10^9 + 7\) 取模
考虑 dp,类似于背包计数
\[f_0 = 1, \quad f_i = \sum_{j:\ a_j \le i} f_{i - a_j} \quad (i \ge 1)
\]
发现 \(n\) 很大,但是 \(a_i\) 值域上界 \(x\) 很小,于是这个过程可以用矩阵乘法优化到 \(O(x^3 \log n)\)
#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <algorithm>
#include <cstring>
#include <functional>
// Wide signed integer used for the (very large) target length. Declared as a
// real type alias instead of a macro so the name follows normal scoping rules
// and cannot rewrite unrelated identifiers.
using i64 = long long;
/**
 * Reads the next decimal integer from stdin.
 *
 * Every non-digit character in front of the number is skipped; if any of the
 * skipped characters is '-', the result is negated. The first character after
 * the digit run is consumed as well.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
inline int read() {
    bool negative = false;
    // Keep the character in an int so EOF stays distinguishable from real data.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative |= (ch == '-');
        ch = getchar();
    }
    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -value : value;
}
// Side length shared by every MATRIX; the caller must set it before calling
// clear(), init() or operator*.
int m;
// Modulus applied to all matrix arithmetic.
const int mod = 1000000007;

/**
 * Square m x m matrix of residues modulo `mod`.
 *
 * The dimension is taken from the global `m` at the time each member function
 * runs, so all matrices combined in one product must have been sized with the
 * same `m`. Entries are expected to lie in [0, mod).
 */
class MATRIX {
public:
    std::vector<std::vector<int>> mat;

    /// Resizes the matrix to m x m and sets every entry to 0.
    void clear() {
        mat.assign(m, std::vector<int>(m, 0));
    }

    /// Resizes the matrix to m x m and makes it the identity matrix.
    void init() {
        clear();
        for (int i = 0; i < m; i++) {
            mat[i][i] = 1;
        }
    }

    /**
     * Matrix product (*this) * a, reduced modulo `mod`.
     *
     * Takes the operand by const reference so no copy is made, and is usable
     * on const matrices. The result is built in a fresh matrix, so `x * x` is
     * safe.
     */
    MATRIX operator * (const MATRIX& a) const {
        MATRIX b; b.clear();
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                // 64-bit accumulator: each product is below mod^2 (< 2^63).
                long long acc = 0;
                for (int k = 0; k < m; k++) {
                    acc = (acc + 1LL * mat[i][k] * a.mat[k][j]) % mod;
                }
                b.mat[i][j] = static_cast<int>(acc);
            }
        }
        return b;
    }
};
/**
 * Raises a square matrix to a non-negative integer power by binary
 * exponentiation (repeated squaring).
 *
 * @param a base matrix (taken by value; it is squared in place while iterating)
 * @param x exponent; any value <= 0 yields the identity matrix
 * @return a^x with all arithmetic done by MATRIX::operator*
 */
MATRIX qpow(MATRIX a, i64 x) {
    MATRIX res; res.init();
    // `x > 0` rather than `x`: a negative exponent would otherwise never
    // reach zero under right shift and loop forever.
    while (x > 0) {
        if (x & 1) res = res * a;
        x >>= 1;
        // Skip the squaring once no bits remain; its result would be unused.
        if (x > 0) a = a * a;
    }
    return res;
}
/**
 * Reads the target length k and two lists of block sizes, then prints the
 * number of ordered ways to split a length-k block into pieces whose sizes
 * occur in BOTH lists, modulo `mod`.
 *
 * With f[0] = 1 and f[i] = sum of f[i - s] over usable sizes s <= i, the first
 * `m` values (m = largest usable size) are computed directly and the rest by
 * raising the m x m transition matrix to a power.
 *
 * Input order: k, size of list 1, list 1, size of list 2, list 2.
 * Sizes outside 1..100 are ignored.
 */
int main() {
    const int kMaxSize = 100;

    i64 k; std::cin >> k;

    // in_first[v]: v occurs in the first list; in_both[v]: v occurs in both.
    bool in_first[kMaxSize + 1] = {false};
    bool in_both[kMaxSize + 1] = {false};

    int count = read();
    for (int i = 0; i < count; i++) {
        const int v = read();
        if (v >= 1 && v <= kMaxSize) in_first[v] = true;
    }
    count = read();
    for (int i = 0; i < count; i++) {
        const int v = read();
        if (v >= 1 && v <= kMaxSize && in_first[v]) in_both[v] = true;
    }

    // Usable block sizes in increasing order.
    std::vector<int> sizes;
    for (int v = 1; v <= kMaxSize; v++) {
        if (in_both[v]) sizes.push_back(v);
    }

    // No common size: only the empty split (k == 0) is possible. Handling this
    // here avoids building 0 x 0 matrices and indexing into them.
    if (sizes.empty()) {
        printf("%d\n", k == 0 ? 1 : 0);
        return 0;
    }

    m = sizes.back();

    // Row 0 of `a` holds f[0..m-1]; the remaining rows stay zero.
    MATRIX a; a.clear(); a.mat[0][0] = 1;
    for (int i = 1; i < m; i++) {
        for (int s : sizes) {
            if (s > i) break;  // sizes are sorted, so no later size fits either
            a.mat[0][i] += a.mat[0][i - s];
            if (a.mat[0][i] >= mod) a.mat[0][i] -= mod;
        }
    }
    if (k < m) {
        printf("%d\n", a.mat[0][k]); return 0;
    }

    // Transition matrix: maps (f[t], ..., f[t+m-1]) to (f[t+1], ..., f[t+m]).
    MATRIX b; b.clear();
    for (int s : sizes) b.mat[m - s][m - 1] = 1;          // new last entry f[t+m]
    for (int i = 1; i <= m - 1; i++) b.mat[i][i - 1] = 1;  // shift the window left

    // After k - m + 1 steps the last column of row 0 is f[k].
    MATRIX res = a * qpow(b, k - m + 1);
    printf("%d", res.mat[0][m - 1]);
    return 0;
}