zerojudge i212. 三則運算

題目在 https://zerojudge.tw/ShowProblem?problemid=i212

這題是在計算非常大的加減乘法, 大到100萬位的運算. 對於加法和減法來說, 幾位都不是問題. 用a021. 大數運算的方法都可以搞定, 我就不再寫了, 有興趣可以看我的另一篇zerojudge a021. 大數運算

困難的地方是在乘法用直式來算的話要計算太多次了, 我很懷疑這是給高中生來做的題目, 如果高中生可以自己想出快速的解法, 那他應該是個數學天才. 我也是看了數學家的方法才AC的.

看討論有人懷疑FFT也不好做, 要用NTT的方法, 差點打擊了我的信心. 最後我用FFT來解是可以成功的. 詳細的數學式我就不寫了, 網路上有很多FFT的介紹, 但是還是有很多錯誤的公式, 所以我把我看過覺得公式比較正確的連結附上. 想好好看懂的人就別走冤枉路了.

整個計算的思路如下.

  1. 大數相乘 和多項式係數表示法的 多項式相乘的過程一致.
  2. 多項式係數表示法的多項式相乘可以用多項式的點值表示法來相乘, 計算將減少很多.
  3. 用FFT可以將多項式的係數轉換成點值.
  4. 兩個多項式的點值相乘後, 可以用IFFT將點值還原成相乘後的多項式的係數, 但是IFFT也可以用FFT來做, 程式不用再多寫一個IFFT.
  5. FFT轉換後的係數若大於9, 要進位到高位, 才會和大數相乘後的積一致.

以下相乘過程大致與思路一致,

請參考 多项式101-点值表示法与系数表示法学习笔记

const double pi=acos(-1.0);

struct node{
	double x,y;
	node (double xx=0,double yy=0)
	{
		x=xx;y=yy;
	}
};

void init(int n)
{
	for (int i=0;i<n;i++)
	{
		node_omega[i]=node(cos(2.0*pi*i/n),sin(2.0*pi*i/n));
		node_a_omega[i]=node(cos(2.0*pi*i/n),-sin(2.0*pi*i/n));
	}
}

void MUL(char *a_str, char *b_str)
{
	//計算FFT需要多大才可以涵蓋所有的點數
	n=strlen(a_str)-1;
	m=strlen(b_str)-1;
	fn=1;
	while (fn<=m+n) fn<<=1;

	//取得需要的記憶體
	node_a=new node[fn];
	node_b=new node[fn];
	node_omega=new node[fn];
	node_a_omega=new node[fn];
	num = new int[fn];

	//先準備好所有的W
	init(fn);

	//將字串轉數字並放入實部
	for (int i=0;i<=n;i++) node_a[n-i].x=(double)(a_str[i]-'0');
	for (int i=0;i<=m;i++) node_b[m-i].x=(double)(b_str[i]-'0');

	//被乘數做FFT
	FFT(fn,node_a,node_omega);

	//乘數做FFT
	FFT(fn,node_b,node_omega);

	//被乘數FFT * 乘數FFT
	for (int i=0;i<=fn;i++) node_a[i]=node_a[i]*node_b[i];

	//答案做IFFT, 但是用FFT來做
	FFT(fn,node_a,node_a_omega);

	//將答案四捨五入到整數
	for (int i=0;i<=fn;i++) num[i]=(int)(node_a[i].x/fn+0.5);

	//將答案大於9的進位到下一位
	for (int i=0;i<=fn;i++)
	{
		num[i+1]+=num[i]/10;
		num[i]%=10;
	}

	//找到高位第一個非零的數字
	int len=m+n+1; while (num[len]==0 && len>0) len--;
	//印出相乘的結果
	for (int i=len;i>=0;i--) printf("%d",num[i]);

	//釋放記憶體
	delete node_a;
	delete node_b;
	delete node_omega;
	delete node_a_omega;
	delete num;
}
C++

上面的FFT怎麼做呢? 基本上就是把蝴蝶圖的計算流程寫出來. 維基百科 庫利-圖基快速傅立葉變換演算法已經寫得很清楚了.


node *node_a,*node_b,*node_omega,*node_a_omega;
int n,m;
int fn;
int *num;


node operator +(const node &a,const node &b){return node (a.x+b.x,a.y+b.y);}
node operator -(const node &a,const node &b){return node (a.x-b.x,a.y-b.y);}
node operator *(const node &a,const node &b){return node (a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

void FFT(int n,node *a,node *w)
{
	int i,j=0,k;
	//將係數a重排到對的位置
	for (i=0;i<n;i++)
	{
		if (i>j) swap(a[i],a[j]);
		for (int l=n>>1;(j^=l)<l;l>>=1);
	}
	//做蝴蝶圖的相乘
	for (i=2;i<=n;i<<=1)
	{
		int m=i>>1;
		for (j=0;j<n;j+=i)
			for (k=0;k<m;k++)
			{
				node z=a[j+k+m]*w[n/i*k];
				a[j+k+m]=a[j+k]-z;
				a[j+k]=a[j+k]+z;
			}
	}
}
C++

完整程式如下.


#include <cstdio>
#include <cstring>
#include <iostream>
#include <cmath>

using namespace std;

string Add(string a, string b){
	int a_size=a.size();
	int b_size=b.size();

	if(a_size<b_size){
		return Add(b,a);
	}

	int a_index=a_size-1;
	int b_index=b_size-1;

	//answer先加a
	string answer=a;

	//answer加b, 先不管進位
	while(b_index>=0){
		answer[a_index--]+=(b[b_index--]-'0');
	}

	//處理answer進位, answer[0]進位不處理
	a_index=a_size-1;
	while(a_index>0){
		if(answer[a_index]>'9'){
			answer[a_index]-=10;
			answer[a_index-1]++;
		}
		a_index--;
	}

	//處理answer[0]進位
	if(answer[0]>'9'){
		answer[0]-=10;
		answer='1'+answer;
	}

	return answer;
}

string Sub(string a, string b){

	int a_size=a.size();
	int b_size=b.size();

	//處理a<b
	if(a_size<b_size){
		string answer="-"+Sub(b,a);
		return answer;
	}else if(a_size==b_size && a<b){
		string answer="-"+Sub(b,a);
		return answer;
	}

	int a_index=a_size-1;
	int b_index=b_size-1;

	string answer=a;

	//answer減b, 先不管借位
    while(b_index>=0){
		answer[a_index--]-=(b[b_index--]-'0');
	}

    a_index=a_size-1;
	//處理answer借位
	while(a_index>0){
		if(answer[a_index]<'0'){
			answer[a_index]+=10;
			answer[a_index-1]--;
		}
		a_index--;
	}
	//由左向右去找第一個不是0的數字
	a_index=0;
	while(answer[a_index]=='0' && a_index<answer.size()-1){
		a_index++;
	}
	return answer.substr(a_index);

}

const double pi=acos(-1.0);

struct node{
	double x,y;
	node (double xx=0,double yy=0)
	{
		x=xx;y=yy;
	}
};


node *node_a,*node_b,*node_omega,*node_a_omega;
int n,m;
int fn;
int *num;


node operator +(const node &a,const node &b){return node (a.x+b.x,a.y+b.y);}
node operator -(const node &a,const node &b){return node (a.x-b.x,a.y-b.y);}
node operator *(const node &a,const node &b){return node (a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

void init(int n)
{
	for (int i=0;i<n;i++)
	{
		node_omega[i]=node(cos(2.0*pi*i/n),sin(2.0*pi*i/n));
		node_a_omega[i]=node(cos(2.0*pi*i/n),-sin(2.0*pi*i/n));
	}
}

void FFT(int n,node *a,node *w)
{
	int i,j=0,k;
	//將係數a重排到對的位置
	for (i=0;i<n;i++)
	{
		if (i>j) swap(a[i],a[j]);
		for (int l=n>>1;(j^=l)<l;l>>=1);
	}
	//做蝴蝶圖的相乘
	for (i=2;i<=n;i<<=1)
	{
		int m=i>>1;
		for (j=0;j<n;j+=i)
			for (k=0;k<m;k++)
			{
				node z=a[j+k+m]*w[n/i*k];
				a[j+k+m]=a[j+k]-z;
				a[j+k]=a[j+k]+z;
			}
	}
}


void MUL(char *a_str, char *b_str)
{
	//計算FFT需要多大才可以涵蓋所有的點數
	n=strlen(a_str)-1;
	m=strlen(b_str)-1;
	fn=1;
	while (fn<=m+n) fn<<=1;

	//取得需要的記憶體
	node_a=new node[fn];
	node_b=new node[fn];
	node_omega=new node[fn];
	node_a_omega=new node[fn];
	num = new int[fn];

	//先準備好所有的e
	init(fn);

	//將字串轉數字並放入實部
	for (int i=0;i<=n;i++) node_a[n-i].x=(double)(a_str[i]-'0');
	for (int i=0;i<=m;i++) node_b[m-i].x=(double)(b_str[i]-'0');

	//被乘數做FFT
	FFT(fn,node_a,node_omega);

	//乘數做FFT
	FFT(fn,node_b,node_omega);

	//被乘數FFT * 乘數FFT
	for (int i=0;i<=fn;i++) node_a[i]=node_a[i]*node_b[i];

	//答案做IFFT, 但是用FFT來做

	FFT(fn,node_a,node_a_omega);

	//將答案四捨五入到整數
	for (int i=0;i<=fn;i++) num[i]=(int)(node_a[i].x/fn+0.5);

	//將答案大於9的進位到下一位
	for (int i=0;i<=fn;i++)
	{
		num[i+1]+=num[i]/10;
		num[i]%=10;
	}

	//找到高位第一個非零的數字
	int len=m+n+1; while (num[len]==0 && len>0) len--;
	//印出相乘的結果
	for (int i=len;i>=0;i--) printf("%d",num[i]);

	//釋放記憶體
	delete node_a;
	delete node_b;
	delete node_omega;
	delete node_a_omega;
	delete num;
}

char a_str[1000002];
char b_str[1000002];

int main(){

	char op;
	scanf("%s %c %s",a_str,&op,b_str);

	string answer;
	switch(op){
		case '+':answer=Add(string(a_str),string(b_str));
				break;
		case '-':answer=Sub(string(a_str),string(b_str));
				break;
		case '*':MUL(a_str,b_str);
				return 0;
	}

	printf("%s\n",answer.c_str());
}
C++