FFT递归实现
#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Capacity of the global coefficient buffers a/b. FFT reads and writes
// a[0..len-1], where len is the smallest power of two greater than n + m.
// For n + m up to 2e6 that is 2^21 = 2097152, which exceeds 2e6 + 10, so the
// buffers must be larger (same size as the iterative implementation).
const int N = 4e6 + 10;
const double PI = acos(-1.0);
// Minimal complex number: `a` holds the real part, `b` the imaginary part.
struct comp{
    double a;
    double b;
};

// Complex addition, component by component.
comp operator +(comp lhs, comp rhs){
    comp sum;
    sum.a = lhs.a + rhs.a;
    sum.b = lhs.b + rhs.b;
    return sum;
}

// Complex subtraction, component by component.
comp operator -(comp lhs, comp rhs){
    comp diff;
    diff.a = lhs.a - rhs.a;
    diff.b = lhs.b - rhs.b;
    return diff;
}

// Complex multiplication: (p + qi)(r + si) = (pr - qs) + (ps + qr)i.
comp operator *(comp lhs, comp rhs){
    const double re = lhs.a * rhs.a - lhs.b * rhs.b;
    const double im = lhs.a * rhs.b + lhs.b * rhs.a;
    return {re, im};
}
comp a[N];  // first polynomial's coefficients; main overwrites it with its transform, then the product
comp b[N];  // second polynomial's coefficients; main overwrites it with its transform
// Recursive radix-2 Cooley-Tukey FFT, performed in place on a[0..len-1].
// flag = 1 computes the forward transform; flag = -1 uses the conjugate roots
// (inverse transform) WITHOUT the 1/len scaling -- callers divide by len.
// Precondition: len is a power of two (for odd len > 1 elements would be dropped).
// The even/odd halves live in heap-allocated std::vector buffers. Variable-length
// arrays are a non-standard extension, and at len around 2^21 they would need
// tens of MB of stack across the recursion and overflow it.
void FFT(comp a[], int len, int flag){
    if(len <= 1) return;  // also stops infinite recursion if len == 0
    const int half = len / 2;
    std::vector<comp> A1(half), A2(half);  // even-indexed / odd-indexed coefficients
    for(int i = 0; i < half; i++){
        A1[i] = a[i * 2];
        A2[i] = a[i * 2 + 1];
    }
    FFT(A1.data(), half, flag);
    FFT(A2.data(), half, flag);
    comp wk = {1, 0};  // running twiddle factor w1^i
    // Principal len-th root of unity (its conjugate when flag == -1).
    const comp w1 = {cos(2 * PI / len), flag * sin(2 * PI / len)};
    for(int i = 0; i < half; i++){
        const comp t = A2[i] * wk;  // butterfly product, computed once
        a[i] = A1[i] + t;
        a[i + half] = A1[i] - t;
        wk = wk * w1;  // advance to the next power of w1
    }
}
int main (){
int n,m;
scanf("%d%d",&n,&m);
for(int i=0; i<=n; i++) scanf("%lf",&a[i].a);
for(int i=0; i<=m; i++) scanf("%lf",&b[i].a);
int len=1; while(len <= n+m) len <<= 1; //求将单位圆等分的次数len
FFT(a, len, 1); //求a的点值
FFT(b, len, 1); //求b的点值
for(int i=0; i < len; i++) a[i]=a[i]*b[i]; //求a*b的点值
FFT(a, len, -1); //根据点值求系数
for(int i=0; i <= n+m; i++) printf("%d ",(int)(a[i].a/len + 0.5));
return 0;
}
FFT迭代实现
#include <iostream>
#include <cmath>
using namespace std;
// Capacity of the global buffers R, a and b; must be at least the padded FFT length.
const int N = 4000000 + 10;
// pi, obtained as arccos(-1).
const double PI = std::acos(-1.0);
// Complex number stored as real part `a` and imaginary part `b`.
struct comp{
    double a, b;
};
// (x.a + x.b i) + (y.a + y.b i)
comp operator +(comp x, comp y){
    return comp{x.a + y.a, x.b + y.b};
}
// (x.a + x.b i) - (y.a + y.b i)
comp operator -(comp x, comp y){
    return comp{x.a - y.a, x.b - y.b};
}
// (x.a + x.b i) * (y.a + y.b i)
comp operator *(comp x, comp y){
    return comp{x.a * y.a - x.b * y.b,
                x.a * y.b + x.b * y.a};
}
int R[N];   // bit-reversal index table, rebuilt by change() on every call
comp a[N];  // first polynomial; main overwrites it with its transform, then the product
comp b[N];  // second polynomial; main overwrites it with its transform
// Bit-reversal permutation of A[0..n-1] (n assumed to be a power of two).
// R[i] is i with its log2(n) bits reversed: it is derived from the reversal
// of i/2 shifted right once, plus the top bit n/2 when i is odd.
void change(comp A[], int n){
    const int half = n / 2;
    for(int i = 0; i < n; ++i){
        int rev = R[i >> 1] >> 1;
        if(i & 1) rev += half;
        R[i] = rev;
    }
    // Swap each pair only once (when i < its partner).
    for(int i = 0; i < n; ++i){
        const int j = R[i];
        if(i < j) swap(A[i], A[j]);
    }
}
// Iterative radix-2 FFT, in place on A[0..len-1] (len a power of two).
// flag = 1: forward transform; flag = -1: conjugate roots (inverse transform)
// without the 1/len scaling, which the caller applies.
void FFT(comp A[], int len, int flag)
{
    change(A, len);  // reorder into bit-reversed positions first
    for(int step = 2; step <= len; step *= 2){  // width of the blocks being merged
        const int half = step >> 1;
        const double theta = 2 * PI / step;
        const comp w1 = {cos(theta), flag * sin(theta)};  // step-th root of unity
        for(int base = 0; base < len; base += step){  // each block of this width
            comp wk = {1, 0};
            for(int j = 0; j < half; ++j){  // butterflies within the block
                comp &lo = A[base + j];
                comp &hi = A[base + j + half];
                const comp x = lo;
                const comp y = hi * wk;
                lo = x + y;
                hi = x - y;
                wk = wk * w1;
            }
        }
    }
}
int main ()
{
int n,m;
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&a[i].a);
for(int i=0;i<=m;i++) scanf("%lf",&b[i].a);
int len = 1; while(len <= n + m) len <<= 1;
FFT(a, len, 1);
FFT(b, len, 1);//len是多项式项数
for(int i=0; i < len; i++) a[i]=a[i]*b[i];
FFT(a, len, -1);
for(int i=0; i<=n+m; i++) printf("%d ",(int)(a[i].a/len + 0.5));
return 0;
}