数論変換(NTT)を用いた多項式乗算 ① 〜離散フーリエ変換と多項式乗算〜

投稿日
🫠

はじめに

この記事は数論変換(NTT)を用いた多項式乗算について 1 から丁寧に解説することを目標としています。この記事を書こうと思ったきっかけは学校で「円周率を 1000 桁計算する」というレポートが出たからです。いわゆる多倍長演算のレポートなのですが、乗算の速度をなんとかして上げようと思い、高速フーリエ変換や数論変換について学びました。結局多倍長演算には NTT は使わなかったのですが、学んだ内容をどこかに残しておきたかったのでここに書きます。僕自身、数学はあまり得意では無いのでもしかしたら間違っているかもしれませんが、そのときは Twitter や GitHub で教えてもらえると嬉しいです。

第 1 回では離散フーリエ変換と多項式乗算の関係について解説します。

多項式乗算

多項式乗算とはその名の通り、多項式同士の乗算です。例えば 2 つの多項式f(x)=1+2xf(x) = 1 + 2xg(x)=3+x+4x2g(x) = 3 + x + 4x^2が与えられたとき、この 2 式の乗算結果は(fg)(x)=3+7x+6x2+8x3(f \cdot g)(x) = 3 + 7x + 6x^2 + 8x^3となります。プログラム上で表現する場合はそれぞれの項を配列に格納することで表現できます(配列のインデックスがxxの指数に対応します)。入力の配列の長さをそれぞれn,mn, mとすると、出力の配列の長さはn+m1n + m - 1となります。2 重 for を用いた非常に単純な実装を示します。

fn mul(f: &Vec<i64>, g: &Vec<i64>) -> Vec<i64> {
    let mut ans = vec![0; f.len() + g.len() - 1];

    for (i, ff) in f.iter().enumerate() {
        for (j, gg) in g.iter().enumerate() {
            ans[i + j] += ff * gg;
        }
    }

    ans
}

この 2 重 for を用いた実装は単純ですが計算量はO(nm)O(nm)となり、入力が大きい場合は計算に時間がかかります。どうにかしてこれを高速化できないでしょうか?ここで使用するのが高速フーリエ変換です。

高速フーリエ変換

高速フーリエ変換(FFT)は離散フーリエ変換を高速で行うことができるアルゴリズムです。計算量はO(nlogn)O(n \log n)となりO(nm)O(nm)に比べて非常に高速です。FFT をうまく使用することで多項式乗算を行うことができます。

離散フーリエ変換

FFT について解説する前にまずは離散フーリエ変換(DFT)について学びましょう。Wikipediaによると、

離散フーリエ変換とは、複素関数f(x)f(x)を複素関数f^(t)\hat f(t)に写す写像であって…(略)

だそうです。なお、多項式乗算をするために DFT を深く知る必要はないので細かい意味などは解説しません。DFT の定義は以下の通りです。

f^(t)=i=0N1f(ζNi)ti=f(ζN0)t0+f(ζN1)t1+f(ζN2)t2++f(ζNN2)tN2+f(ζNN1)tN1\begin{align*} \hat{f}(t) &= \sum^{N - 1}_{i = 0} f(\zeta^i_N) t^i \\ &= f(\zeta_N^0)t^0 + f(\zeta_N^1)t^1 + f(\zeta_N^2)t^2 + \cdots + f(\zeta_N^{N-2})t^{N-2} + f(\zeta_N^{N-1})t^{N-1} \end{align*}

見慣れないζN\zeta_Nというものが出てきました。これは「1 の原始NN乗根」と呼ばれるもので、NN乗して初めて 1 になる数のことです。複素数の範囲で表すと以下のようになります。

ζN=cos2πN+isin2πN=e2πNi\zeta_N = \cos \frac{2 \pi}{N} + i \sin \frac{2 \pi}{N} = e^{\frac{2 \pi}{N}i}

このζN\zeta_Nには重要な 3 つの性質があります。

  • ζNi=ζNN+i\zeta_N^i = \zeta_N^{N + i}

  • i=0N1ζNi(jk)={Nif jkmodN0otherwize\sum^{N - 1}_{i = 0} \zeta_N^{i(j - k)} = \left\{ \begin{array}{ll} N & \text{if}\ j \equiv k \mod N \\ 0 & \text{otherwize} \end{array} \right.

  • 上記はζN\zeta_NζN1\zeta_N^{-1}と置き換えても成り立つ

実はこのζN\zeta_Nは複素数でなくてもよく、この 3 つの性質を満たしていれば何でも良いのですが、それはまた後ほど。

DFT の話に戻りますが、もう一つ離散フーリエ逆変換(IDFT)というものもあります。定義は以下の通りです。

f(x)=1Ni=0N1f^(ζNi)xif(x) = \frac{1}{N} \sum^{N - 1}_{i = 0} \hat{f}(\zeta^{-i}_N) x^i

この IDFT を使うと DFT した式を復元することができます。また見たとおり DFT の定義と大して変わらないため、DFT が実装できれば IDFT もすぐに実装できます。

DFT と IDFT を使った多項式乗算

さて、ここでf(x)f(x)g(x)g(x)の積に DFT を適用してみましょう。

fg^(x)=i=0N1(fg)(ζNi)ti=i=0N1f(ζNi)g(ζNi)ti\begin{align*} \widehat{f \cdot g} (x) &= \sum^{N - 1}_{i = 0} (f \cdot g) (\zeta_N^{i}) t^i \\ &= \sum^{N - 1}_{i = 0} f (\zeta_N^{i}) g (\zeta_N^{i}) t^i \\ \end{align*}

このことからfg^(x)\widehat{f \cdot g}(x)f(x)f(x)g(x)g(x)それぞれに DFT を適用し、各項を乗算すれば求まることがわかります。IDFT を使えば元の(fg)(x)(f \cdot g)(x)を求めることもできます。

つまり多項式乗算は以下の通りに計算することで求めることができます。

  1. f(x),g(x)f(x), g(x)に DFT を適用し、f^(x),g^(x)\hat f(x), \hat g(x)を求める

  2. f^(x),g^(x)\hat f(x), \hat g(x)を各項で乗算し、fg^(x)\widehat{f \cdot g}(x)を求める

  3. fg^(x)\widehat{f \cdot g}(x)に IDFT を適用し、(fg)(x)(f \cdot g)(x)を求める

さて計算量について考えましょう。1, 3 の計算量は FFT を用いるとO(nlogn)O(n \log n)です。2 はただの for で実装できるので計算量はO(n)O(n)です。よって DFT、IDFT を用いた多項式乗算の計算量はO(nlogn)O(n \log n)になります。

Rust での DFT, IDFT の実装

Rust での DFT, IDFT の単純実装を示します。複素数計算のためにnumクレートを使用していますがご了承ください。

dftは DFT と IDFT を兼ねています。inversefalseのときは DFT、trueのときは IDFT を行います。FFT ではないので計算量はO(n2)O(n^2)のままです(むしろ定数倍が増えています)。

use num::Complex;
use std::f64::consts::PI;

fn main() {
    let f = vec![1, 2];
    let g = vec![3, 1, 4];

    let ans = mul_by_dft(&f, &g);
    println!("{:?}", ans);
}

fn mul_by_dft(f: &Vec<u32>, g: &Vec<u32>) -> Vec<u32> {
    let mut comp_f: Vec<Complex<f64>> = f.iter().map(|&x| Complex::new(x as f64, 0.0)).collect();
    let mut comp_g: Vec<Complex<f64>> = g.iter().map(|&x| Complex::new(x as f64, 0.0)).collect();

    let size = f.len() + g.len() - 1;
    comp_f.resize(size, Complex::new(0.0, 0.0));
    comp_g.resize(size, Complex::new(0.0, 0.0));

    let dft_f = dft(&comp_f, false);
    let dft_g = dft(&comp_g, false);

    let mut mul = vec![Complex::new(0.0, 0.0); size];
    for i in 0..size {
        mul[i] = dft_f[i] * dft_g[i];
    }

    let idft_mul = dft(&mul, true);

    idft_mul.iter().map(|x| x.re as u32).collect()
}

fn dft(f: &Vec<Complex<f64>>, inverse: bool) -> Vec<Complex<f64>> {
    let size = f.len();

    let mut res = vec![Complex::new(0.0, 0.0); size];

    for i in 0..size {
        let zeta = if !inverse {
            Complex::from_polar(1.0, 2.0 * PI * i as f64 / size as f64)
        } else {
            Complex::from_polar(1.0, -2.0 * PI * i as f64 / size as f64)
        };
        let mut now = Complex::new(1.0, 0.0);
        for j in 0..size {
            res[i] += f[j] * now;
            now *= zeta;
        }
    }

    if inverse {
        for i in 0..size {
            res[i] /= Complex::new(size as f64, 0.0);
        }
    }

    res
}

まとめ

  • 多項式乗算は普通に計算するとO(n2)O(n^2)かかる

  • 離散フーリエ変換を使うと多項式計算ができる

  • 離散フーリエ変換は高速フーリエ変換を使うとO(nlogn)O(n \log n)で計算できる

次回は高速フーリエ変換の概要と実装について解説します。

Read next