【CUDA】GPU编程实现NTT算法

发布时间 2023-06-03 21:23:30作者: tudouuuuu

怎么有人选题迟了么得FFT啊。好久没更新博客了,来水一发!

参考资料:
NTT:https://oi-wiki.org/math/poly/ntt/
CUDA实现FFT并行计算:https://blog.csdn.net/Liadrinz/article/details/106695275

无任何优化,纯实现。

观察到只需要对蝶形变化部分和factor进行改动即可,于是可得:

__device__ int get_R(int x, int bit_size) {
    int ans = 0;
    for (int i = 0; i < bit_size; i++) {
        ans <<= 1;
        ans |= x & 1;
        x >>= 1;
    }
    return ans;
}

// 蝴蝶操作, 输出结果直接覆盖原存储单元的数据, factor是旋转因子
__device__ void butterfly(long long *a, long long *b, long long factor) {
    long long a1 = ((*a) + factor * (*b) % mod) % mod;
    long long b1 = ((*a) - factor * (*b) % mod + mod) % mod;
    *a = a1;
    *b = b1;
}

__device__ long long ksm(long long base, int times) {
    long long p = 1;
    while (times) {
        if (times & 1) p = (p * base) % mod;
        base = base * base % mod;
        times >>= 1;
    }
    return p;
}

__global__ void NTT(long long nums[], long long result[], int limit, int bit_size) {
    int tid = threadIdx.x + blockDim.x * blockIdx.x;
    if (tid >= limit) return;
    for (int k = 1; tid + k < limit; k <<= 1) {
        long long Wn = ksm(mod_g, (mod - 1) / (k << 1));
        if (!(tid & k)) {
            // printf("tid=%d, k=%d Wn=%lld\n", tid, k, ksm(Wn, tid % k));
            butterfly(&nums[get_R(tid, bit_size)], &nums[get_R(tid + k, bit_size)], ksm(Wn, tid % k));
            //printf("nums[%d] = %lld, nums[%d] = %lld\n", tid, nums[tid], tid+k, nums[tid+k]);
        }
        __syncthreads();
    }
    result[tid] = nums[tid];
}

但是相较于GPU编程的FFT算法,其需要存在一个快速幂的过程,因此其复杂度应为\(O(logn * log P)\),其中\(P\)为模。