NTT(快速数论变换)学习

发布时间 2023-09-19 21:54:36作者: Isakovsky

回顾:FFT

FFT(快速傅立叶变换)学习 - Isakovsky - 博客园 (cnblogs.com)

目的:将多项式的系数表示法形式转换为点值表示法形式,或者说,快速计算出多项式在若干个点上的值.

中心思想:适当地选取自变量,使得自变量两两互为相反数,求出的多项式值可重复利用,减少运算次数

例如上面那篇博客中,以八个点举例,取

$x_0=1$

$x_1=\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}i$

$x_2=i$

$x_3=-\frac{\sqrt{2}}{2}+\frac{\sqrt{2}}{2}i$

$x_4=-1$

$x_5=-\frac{\sqrt{2}}{2}-\frac{\sqrt{2}}{2}i$

$x_6=-i$

$x_7=\frac{\sqrt{2}}{2}-\frac{\sqrt{2}}{2}i$

以保证如下性质

$x_0=x_0^2=x_4^2$

$x_2=x_1^2=x_3^2$

$x_4=x_2^2=x_6^2$

$x_6=x_3^2=x_7^2$

问题:需要引入三角函数,浮点类型存储无理数会导致精度误差,而且浮点数运算会拖慢程序运行时间

有无其他的数学概念也存在这样的性质?

回顾:密码协议学习笔记(1):密码协议引论与密码学基础 - Isakovsky - 博客园 (cnblogs.com) 中关于原根的概念.

设素数$p$,以$p=998244353$,$\varphi(p)=p-1=998244352=2^{23}\cdot 7 \cdot 17$为例,存在一个在模$p$素数域下的原根$g=3$,假设还是需要求八个点的值,在模$p$域中取如下几个点:

$x_0=1$

$x_1=g^\frac{\varphi(p)}{8}=3^{124780544}=372528824$

$x_2=g^{2\cdot\frac{\varphi(p)}{8}}=3^{249561088}=911660635$

$x_3=g^{3\cdot\frac{\varphi(p)}{8}}=3^{374341632}=488723995$

$x_4=g^{4\cdot\frac{\varphi(p)}{8}}=3^{499122176}=998244352$

$x_5=g^{5\cdot\frac{\varphi(p)}{8}}=3^{623902720}=625715529$

$x_6=g^{6\cdot\frac{\varphi(p)}{8}}=3^{748683264}=86583718$

$x_7=g^{7\cdot\frac{\varphi(p)}{8}}=3^{873463808}=509520358$

也满足

$x_0=x_0^2=x_4^2$

$x_2=x_1^2=x_3^2$

$x_4=x_2^2=x_6^2$

$x_6=x_3^2=x_7^2$

的性质,从而也能进行类似FFT的变换.此过程不需要进行浮点运算,相比FFT快得多,而且,在最终运算结果的系数不超过$998244353$,项数不超过$2^{23}$时,该方法可以无误差地计算卷积.

进行从点值表示法到系数表示法的逆变换时,只需要用$g^{-1}=332748118$替代$g$(显然$g^{-1}$也是原根)作为底数算出自变量,然后在运算的最后每一项乘上自变量个数的逆元$n^{-1}$即可.

$p$的大小太小了怎么办?找到更大的质数$p'$,记$p'=n'\cdot k'$,$n'$为使得等式成立的,最大的$2$的整数次幂,这时,便可在系数不超过$p'$,项数不超过$n'$时无误差地计算卷积.

例如$p=27\cdot 2^{56}+1, g=5$,$p=29\cdot 2^{57}+1, g=3$

运算中存在负数怎么办?使用补码表示法,在保证系数的绝对值不超过$\lceil \frac{p-1}{2} \rceil$的前提下,用$[1,\lceil \frac{p-1}{2}\rceil]$表示正数.

代码:

def qpow(base,exp,mod): #快速幂
    ans = 1
    while(exp>0):
        if(exp%2 == 1):
            ans = ans*base%mod
        exp = exp//2
        base = base*base%mod
    return ans

def rev(a,mod):
    return qpow(a,mod-2,mod)

def FFT(arr,n,logn,g,mod,inv):
    arr1=[]
    for i in range(n): #蝴蝶操作
        id1=i
        id2=0
        for j in range(logn):
            id2=id2*2+id1%2
            id1=id1//2
        arr1.append(arr[id2]%mod)
    i=1
    while(i<n):  #待合并的多项式的次数
        if(inv==-1):
            wn=qpow(rev(g,mod),(mod-1)//(2*i),mod)
        else:
            wn=qpow(g,(mod-1)//(2*i),mod) 
        
        j=0
        while(j<n):  #枚举具体区间,j也就是区间右端点
            w=1
            for k in range(i): #合并
                x=arr1[j + k]
                y=(w*arr1[i + j + k])%mod
                #不把xy的值保存下来,会导致后面的计算改变了变量值而出错
                arr1[j + k] = (x+y)%mod
                arr1[i + j + k] = (x-y+mod)%mod
                w=w*wn
            j=j+i*2
        i=i*2

    if(inv==-1):
        for i in range(n):
            arr1[i]=(arr1[i]*rev(n,mod))%mod
    #不要忘了逆变换的时候除掉常数
    return arr1

mod=998244353
g=3

a=[1,2,3,4,5,6,7,8]
b=[8,7,6,5,4,3,2,1]
l1=max(len(a),len(b))
n=1
logn=0
while(n*2<=l1*2):
    n=n*2
    logn=logn+1
while(len(a)<n):
    a.append(0)
while(len(b)<n):
    b.append(0)

a=FFT(a,n,logn,g,mod,1)
b=FFT(b,n,logn,g,mod,1)
c=[]
for i in range(len(a)):
    c.append(a[i]*b[i]%mod)
c=FFT(c,n,logn,g,mod,-1)

print(c[:-1])