NTT笔记

发布时间 2023-05-04 21:19:23作者: wang-holmes

NTT 笔记

前言:
这个算法是与FFT 类似的,本片不会再从头讲起,建议先去补补课《FFT 笔记》
本文只会讲一下互相关联的地方与一些不同的地方。

建议:在电脑前放好演算纸和笔。

注:本篇文章是我这个小蒟弱写的,真正的dalao请看个玩笑便好,不必争论对错(但是欢迎指出文章存在的小错误)。

NTT 有什么用

与FFT 一样,快速数论变换(NTT)可以在\(O(n\log n)\)完成两个多项式相乘问题。

为什么要用 NTT

在一些\(O(n^2)\)会爆,但是数据很大有需要一个数的时候,显然FFT就不太适用了,就需要用到NTT。

下面补充了一些数论知识,以便你能看懂下文。

补充芝士
1.剩余系:所有整数模正整数n得到的结果组成的集合称为n的剩余系,n的剩余系即小于n的非负整数的集合,记为 \(Z_n\)

2.简化剩余系:在n的剩余系中与n互质的元素的集合,称为n的简化剩余系,记为 \(Z_n^*\)

3.欧拉函数:n的简化剩余系中元素的个数,称为欧拉函数,记为\(\phi(n)\)

4.阶:$g,n $互质,令 \(g^x\%n=1\)成立的最小的正整数 \(x\),称为 \(g\)\(n\)的阶。

5.原根:对于互质的两个正整数\(g\)\(n\),如果\(g\)\(n\)的阶为\(\phi(n)\),则称 \(g\)\(n\)的原根。换句话说,即对于 \(1 \leq j < \phi(n)\),使得 \(g^{j} \neq 1 (mod \ n)\),但 \(g^{\phi(n)}=1 (mod \ n)\)则称g为n的原根。

使用NTT时有些限制,如果多项式乘法不要求取模,则我们要找足够大的质数\(m\),并且 \(m−1=k\times2^n\),要保证 \(2^n\)大于等于多项式的次数界,\(m\)还要大于多项式的系数.

如果多项式乘法是带模乘法,则只能用NTT,不能使用FFT。此时,若\(m\)为质数,则要求\(m-1=k*2^n\),要保证\(2^n\)要大于等于多项式的次数界, 若\(m\)不为质数,则需要用中国剩余定理来做。

NTT

可以发现一些数的简化剩余系在乘法运算下构成的群与FFT当中的单位复根有相似性质。

我们为什么当初FFT要用单位根进行代入?因为单位复根满足一些特殊性质的同时,它还满足当且仅当\(i=j\)\(\omega_n^i=\omega_n^j\)成立。

如果\(m\)存在原根\(x\),则\(m\)的简化剩余系 \(Z_m^*=\{x^j (mod \ m) |1\leq j \leq \phi(n)\}\)
\(\phi(m)=k*2^n\),令\(N=2^n\)\(x_N^i=x^{\phi(m)*i/N} (mod \ m)\)\(i \leq N\)
则易知\(x_N^i\)​满足:

1.消去引理

\[x_{Nd}^{jd}​=x_N^j​(mod\ m) \]

根据定义易证。

2.折半引理

\[(x_N^{j+\frac{N}{2}})^2=x_N^{2j+N}=(x_N^j)^2\times x_N^N=(x_N^j)^2 \ (mod \ m) \]

根据消去引理易证。

3.求和引理
\(m>2\)\(k\)不是\(N\)的倍数时

\[\sum\limits_{j=0}^{N-1} (x_N^k)^j=0 \ (mod \ m) \]

根据等比数列求和公式易证。

后记&版题

本篇文章写得还算轻松,因为有很多东西可以借鉴《FFT笔记》

本篇文章借鉴了一些资料:
《快速数论变换(NTT)及蝴蝶操作构造详解》——永远在你身后
快速数论变换NTT——hefenghhhh

下面贴个板子:

code:

#include<bits/stdc++.h>
using namespace std;

#define ll long long
#define rp(i,o,p) for(ll i=o;i<=p;++i)
#define pr(i,o,p) for(ll i=o;i>=p;--i)

const ll MAXN=1e7+5,P=998244353;

ll n=1;
char s1[MAXN],s2[MAXN];
ll ans[MAXN],inv3,omg[MAXN],inv[MAXN];
ll la,lb;
ll a[MAXN],b[MAXN];

ll qpow(ll a,ll b)
{
    ll re=1;
    while(b)
    {
        if(b&1)
            re*=a,re%=P;
        a*=a,a%=P;
        b>>=1;
    }
    return re;
}

void init()
{
    inv3=qpow(3,P-2);
    for(ll i=1;i<=n;i<<=1)
    {
        omg[i]=qpow(3,(P-1)/i);
        inv[i]=qpow(inv3,(P-1)/i);
    }
}

void ntt(ll *a,ll *omg)
{
    for(ll i=0,j=0;i<n;++i)
    {
        if(i>j) swap(a[i],a[j]);
        for(ll l = n >> 1; (j^=l) < l; l >>= 1);
    }
    for(ll l=2;l<=n;l<<=1)
    {
        ll m=l>>1;
        for(ll *p=a;p!=a+n;p+=l)
        {
            ll w=1,wn=omg[l];
            rp(i,0,m-1)
            {
                ll t=w*p[i+m]%P;
                p[i+m]=(p[i]-t+P)%P;
                p[i]=(p[i]+t)%P;
                w=w*wn%P;
            }
        }
    }
}

int main()
{
    scanf("%s%s",s1,s2);
    la=strlen(s1),lb=strlen(s2);
    while(n<la+lb) n<<=1;
    rp(i,0,la-1)
        a[i]=s1[la-i-1]-'0';
    rp(i,0,lb-1)
        b[i]=s2[lb-i-1]-'0';
    
    init();
    ntt(a,omg);
    ntt(b,omg);

    rp(i,0,n-1)
        a[i]=a[i]*b[i]%P;
    
    ntt(a,inv);

    ll invn=qpow(n,P-2);
    rp(i,0,n-1)
    {
        ans[i]+=a[i]*invn%P;
        ans[i+1]+=ans[i]/10;
        ans[i]%=10;
    }
    ll i;
    for(i=la+lb-1;i>=0&&!ans[i];--i)
        if(i==0)
            putchar('0'),i=-1;
    while(i>=0)
        putchar('0'+ans[i--]);
    puts("");

    return 0;
}