FFT模板

发布时间 2023-09-24 17:21:26 作者: Tartaglia

FFT递归实现

#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;

// Capacity of the global arrays a[] and b[]. FFT() touches indices up to len-1,
// where len is the smallest power of two strictly greater than n+m. The old bound
// 2e6+10 held ~2*10^6 input coefficients but not the padded length: as soon as
// n+m >= 2^20, len = 2^21 = 2097152 > 2000010 and the transform wrote out of bounds.
// This bound supports any n+m < 2^21.
const int N = (1 << 21) + 10;
// pi to double precision, computed as arccos(-1).
const double PI = std::acos(-1.0);

// Minimal complex number: `a` is the real part, `b` the imaginary part.
struct comp {
    double a;
    double b;
};

// Complex addition, component by component.
comp operator+(comp lhs, comp rhs) {
    const double re = lhs.a + rhs.a;
    const double im = lhs.b + rhs.b;
    return {re, im};
}

// Complex subtraction, component by component.
comp operator-(comp lhs, comp rhs) {
    const double re = lhs.a - rhs.a;
    const double im = lhs.b - rhs.b;
    return {re, im};
}

// Complex multiplication: (p + qi)(r + si) = (pr - qs) + (ps + qr)i.
comp operator*(comp lhs, comp rhs) {
    const double re = lhs.a * rhs.a - lhs.b * rhs.b;
    const double im = lhs.a * rhs.b + lhs.b * rhs.a;
    return {re, im};
}

// Buffers for the two input polynomials: main() reads coefficients into the real
// parts, and FFT() overwrites them in place with point values (and back).
comp a[N];
comp b[N];

// Recursive radix-2 Cooley-Tukey FFT, performed in place on a[0 .. len-1].
//   a    : input sequence; overwritten with its transform.
//   len  : number of points; must be a power of two (main() pads to one).
//   flag : +1 for the forward transform, -1 for the inverse transform.
//          The inverse result is NOT divided by len here; the caller does that.
// The even/odd halves live in heap-backed std::vector buffers rather than
// variable-length stack arrays. VLAs are not standard C++, and at len = 2^21 the
// top-level call alone would need 32 MB of stack, overflowing typical limits.
void FFT(comp a[], int len, int flag){
    if (len <= 1) return;
    const int half = len / 2;

    // Split into even-indexed and odd-indexed terms.
    std::vector<comp> A1(half), A2(half);
    for (int i = 0; i < half; i++) {
        A1[i] = a[i * 2];
        A2[i] = a[i * 2 + 1];
    }
    FFT(A1.data(), half, flag);
    FFT(A2.data(), half, flag);

    // w1 = cos(2*pi/len) + i*flag*sin(2*pi/len): the len-th root of unity,
    // conjugated when flag = -1. wk holds w1^i for the current i.
    comp wk = {1, 0};
    const comp w1 = {cos(2 * PI / len), flag * sin(2 * PI / len)};

    for (int i = 0; i < half; i++) {
        const comp t = A2[i] * wk;  // shared by both butterfly outputs
        a[i] = A1[i] + t;
        a[i + half] = A1[i] - t;
        wk = wk * w1;               // advance to the next power of w1
    }
}
    
// Reads two polynomials and prints the coefficients of their product.
// Input : n m, then the n+1 coefficients a[0..n], then the m+1 coefficients b[0..m].
// Output: the n+m+1 coefficients of the product, space-separated.
int main (){
    int n, m;
    if (scanf("%d%d", &n, &m) != 2) return 0;  // missing/malformed header: nothing to do

    for (int i = 0; i <= n; i++) scanf("%lf", &a[i].a);
    for (int i = 0; i <= m; i++) scanf("%lf", &b[i].a);

    // Transform length: smallest power of two strictly greater than n+m, so all
    // n+m+1 product coefficients fit without cyclic wrap-around.
    int len = 1;
    while (len <= n + m) len <<= 1;

    FFT(a, len, 1);   // point values of a
    FFT(b, len, 1);   // point values of b
    for (int i = 0; i < len; i++) a[i] = a[i] * b[i];  // point values of a*b

    FFT(a, len, -1);  // back to coefficients (still scaled by len)

    // Divide out len and round to the nearest integer. std::llround handles
    // negative values correctly; the former (int)(x + 0.5) truncated toward
    // zero, turning e.g. -2.0 into -1.
    for (int i = 0; i <= n + m; i++) printf("%lld ", std::llround(a[i].a / len));

    return 0;
}

FFT迭代实现

#include <iostream> 
#include <cmath>
using namespace std;

// Capacity of every global array (R, a, b); must exceed the padded FFT length.
const int N = 4000000 + 10;
// pi to double precision, obtained as arccos(-1).
const double PI = std::acos(-1.0);

// Bare-bones complex number used by the FFT: `a` = real part, `b` = imaginary part.
struct comp {
    double a;
    double b;
};

// Complex addition: accumulate rhs into the by-value copy of lhs.
comp operator+(comp x, comp y) {
    x.a += y.a;
    x.b += y.b;
    return x;
}

// Complex subtraction: subtract rhs from the by-value copy of lhs.
comp operator-(comp x, comp y) {
    x.a -= y.a;
    x.b -= y.b;
    return x;
}

// Complex multiplication: (p + qi)(r + si) = (pr - qs) + (ps + qr)i.
comp operator*(comp x, comp y) {
    comp r;
    r.a = x.a * y.a - x.b * y.b;
    r.b = x.a * y.b + x.b * y.a;
    return r;
}

// R[i]: bit-reversed index of i, (re)computed by change() for each transform length.
int R[N];
// Buffers for the two input polynomials: coefficients in the real parts,
// overwritten in place by FFT() with point values (and back).
comp a[N];
comp b[N];

// Bit-reversal permutation for the iterative FFT.
// Fills the global table R[0 .. n-1] so that R[i] is i with its log2(n) bits
// reversed, then moves A[i] to A[R[i]]. n must be a power of two.
void change(comp A[], int n){
    const int half = n / 2;
    for (int idx = 0; idx < n; ++idx) {
        // rev(idx) = rev(idx / 2) shifted down one bit, with idx's lowest bit
        // placed in the top position (worth n/2).
        int top = 0;
        if (idx & 1) top = half;
        R[idx] = R[idx / 2] / 2 + top;
    }
    for (int idx = 0; idx < n; ++idx) {
        const int partner = R[idx];
        if (idx < partner) swap(A[idx], A[partner]);  // each pair is swapped exactly once
    }
}

// Iterative radix-2 Cooley-Tukey FFT, in place on A[0 .. len-1].
//   len  : number of points; must be a power of two.
//   flag : +1 for the forward transform, -1 for the inverse (the caller divides by len).
// After the bit-reversal permutation, butterflies are merged bottom-up over
// blocks of width 2, 4, ..., len.
void FFT(comp A[], int len, int flag)
{
    change(A, len);  // bit-reversal permutation

    for (int width = 2; width <= len; width *= 2) {
        const int half = width / 2;
        const double angle = 2 * PI / width;
        // width-th root of unity, conjugated when flag = -1.
        const comp step = {cos(angle), flag * sin(angle)};

        for (int start = 0; start < len; start += width) {
            comp w = {1, 0};  // step^(k - start)
            for (int k = start; k < start + half; ++k) {
                const comp even = A[k];
                const comp odd = A[k + half] * w;
                A[k] = even + odd;
                A[k + half] = even - odd;
                w = w * step;
            }
        }
    }
}

// Reads two polynomials and prints the coefficients of their product.
// Input : n m, then the n+1 coefficients a[0..n], then the m+1 coefficients b[0..m].
// Output: the n+m+1 coefficients of the product, space-separated.
int main ()
{
    int n, m;
    if (scanf("%d%d", &n, &m) != 2) return 0;  // missing/malformed header: nothing to do

    for (int i = 0; i <= n; i++) scanf("%lf", &a[i].a);
    for (int i = 0; i <= m; i++) scanf("%lf", &b[i].a);

    // Transform length: smallest power of two strictly greater than n+m
    // (at least the number of product coefficients, so there is no wrap-around).
    int len = 1;
    while (len <= n + m) len <<= 1;

    FFT(a, len, 1);
    FFT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] = a[i] * b[i];  // pointwise product
    FFT(a, len, -1);                                    // inverse, still scaled by len

    // Divide out len and round to the nearest integer. std::llround is correct for
    // negative values; (int)(x + 0.5) truncated toward zero (-2.0 became -1).
    for (int i = 0; i <= n + m; i++) printf("%lld ", std::llround(a[i].a / len));
    return 0;
}