C++ - 多桁乗算(Toom-Cook 法 (3-way))!

Updated:


これまで、「標準(筆算)法」や「Karatsuba 法」による多桁同士の乗算アルゴリズムの C++ への実装を紹介しました。

今回は、「Karatsuba 法」の上位にある「Toom-Cook 法」アルゴリズムを C++ で実装してみました。

0. 前提条件

  • Linux Mint 14 Nadia (64bit) での作業を想定。
  • g++ (Ubuntu/Linaro 4.7.2-2ubuntu1) 4.7.2

1. Karatsuba 法について

(数式が多いので \(\TeX\) で記載)
ちなみに、同程度の桁数同士で、3分割して考える方法を「Toom-Cook 法(3-way)」と言う。(3分割×2分割や4分割×4分割等の考え方もある)

TOOMCOOK_1 TOOMCOOK_2 TOOMCOOK_3 TOOMCOOK_4

2. C++ ソース作成

例として、以下のようにソースを作成した。概要は以下のとおり。

  • 1個の配列で1桁を扱う。
  • 計算可能な桁数は 3 のべき乗桁としている。
    (3 のべき乗以外の桁数にすると、ロジックが複雑になるため)
  • 繰り上がり処理は、最後にまとめて行う(但し、非常に大きい乗算桁数では桁あふれを起こすので注意)
  • 配列数が9個(桁数が9桁)になったら、標準(筆算)法による乗算を行う。
    (9個でなくてもよい。桁あふれしない程度で設定する)
  • 計算に使用する被乗数・乗数は、手入力は困難なため、乱数を使用している。
  • 冒頭の // #define TEST は、乗算回数をカウントしたり、処理時間を計測するテストを行うため。
    テストを行うならコメントを解除する。
    (テストの処理時間計測では clock 関数を使用しているため、あまり精度はよくない)

File: multiply_toom_cook_3.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
/*********************************************
 * 多倍長乗算 ( by Toom-Cook 法 (3-way) )
 *   - 最下位の桁を配列の先頭とする考え方
 *********************************************/
#include <cstdlib>   // for rand()
#include <iostream>  // for cout
#include <math.h>    // for pow()
#include <stdio.h>   // for printf()

//#define TEST      // テスト ( 乗算回数・処理時間 ) するならコメント解除
#define D_MAX 729  // 計算可能な最大桁数 ( 3 のべき乗 )
#define D     729  // 実際に計算する桁数 ( D_MAX 以下 )

using namespace std;

/*
 * 計算クラス
 */
class Calc
{
    int A[D];  // 被乗数配列
    int B[D];  // 乗数配列
#ifdef TEST
    int cnt_mul;       // 乗算回数
    clock_t t1, t2;    // 計算開始CPU時刻、計算終了CPU時刻
    double tt;         // 計算時間
#endif

    public:
        Calc();                                             // コンストラクタ
        void calcToomCook();                                // 計算 ( Toom-Cook 法 )

    private:
        void multiplyNormal(int *, int *, int, int *);      // 乗算 ( 標準(筆算)法 )
        void multiplyToomCook3(int *, int *, int , int *);  // 乗算 ( Toom-Cook 法 (3-way) )
        void doCarry(int *, int);                           // 繰り上がり・借り処理
        void display(int *, int *, int *);                  // 結果出力
};

/*
 * コンストラクタ
 */
Calc::Calc()
{
    /* ====================================== *
     * テストなので、被乗数・乗数は乱数を使用 *
     * ====================================== */

    int i;  // LOOP インデックス

    // 被乗数・乗数桁数設定
    for (i = 0; i < D; i++) {
        A[i] = rand() % 10;
        B[i] = rand() % 10;
    }
}

/*
 * 計算 ( Toom-Cook 法 )
 */
void Calc::calcToomCook()
{
    int a[D_MAX];       // 被乗数配列
    int b[D_MAX];       // 乗数配列
    int z[D_MAX * 2];  // 計算結果用配列
    int i;              // LOOPインデックス

#ifdef TEST
    t1 = clock();  // 計算開始時刻
    for (int l = 0; l < 1000; l++) {
        cnt_mul = 0;  // 乗算回数リセット
#endif
    // 配列初期設定 ( コンストラクタで設定した配列を設定 )
    for (i = 0; i < D; i++) {
        a[i] = A[i];
        b[i] = B[i];
    }

    // 最大桁に満たない部分は 0 を設定
    for (i = D; i < D_MAX; i++) {
        a[i] = 0;
        b[i] = 0;
    }

    // 乗算 ( Toom-Cook 法 (3-way) )
    multiplyToomCook3(a, b, D_MAX, z);

    // 繰り上がり・借り処理
    doCarry(z, D_MAX * 2);
#ifdef TEST
    }
    t2 = clock();  // 計算終了時刻
    tt = (double)(t2 - t1) / CLOCKS_PER_SEC;  // ==== 計算時間
#endif

    // 結果出力
    display(a, b, z);
}

/*
 * 乗算 ( 標準(筆算)法 )
 */
void Calc::multiplyNormal(int *a, int *b, int tLen, int *z)
{
    int i, j;  // ループインデックス

    // 計算結果初期化
    for(i = 0; i < tLen * 2; i++) z[i] = 0;

    // 各配列を各桁とみなして乗算
    for (j = 0; j < tLen; j++) {
        for (i = 0; i < tLen; i++) {
            z[j + i] += a[i] * b[j];
#ifdef TEST
            cnt_mul++;  // 乗算カウント
#endif
        }
    }
}

/*
 * 乗算 ( Toom-Cook 法 (3-way) )
 *   結果用配列は以下のように配置し、
 *     +----+----+----+----+----+----+----+----+----+----+
 *     |   c0    |   c2    |   c4    |   c1    |   c3    |
 *     +----+----+----+----+----+----+----+----+----+----+
 *   最後に、c1, c3 を所定の位置に加算する。
 *     +----+----+----+----+----+----+
 *     |   c0    |   c2    |   c4    |
 *     +----+----+----+----+----+----+
 *          +----+----+----+----+
 *          |   c1    |   c3    |
 *          +----+----+----+----+
 *   その他、計算に必要な配列ポインタを詳細に設定。
 */
void Calc::multiplyToomCook3(int *a, int *b, int tLen, int *z)
{
    // ==== 変数宣言
    int *a0 = &a[0];                // 被乗数/右側配列ポインタ
    int *a1 = &a[tLen / 3];         // 被乗数/中央配列ポインタ
    int *a2 = &a[tLen * 2/ 3];      // 被乗数/左側配列ポインタ
    int *b0 = &b[0];                // 乗数  /右側配列ポインタ
    int *b1 = &b[tLen / 3];         // 乗数  /中央配列ポインタ
    int *b2 = &b[tLen * 2/ 3];      // 乗数  /左側配列ポインタ
    int *c0 = &z[(tLen / 3) *  0];  // c0     用配列ポインタ
    int *c2 = &z[(tLen / 3) *  2];  // c2     用配列ポインタ
    int *c4 = &z[(tLen / 3) *  4];  // c4     用配列ポインタ
    int c1      [(tLen / 3) * 2];   // c1     用配列
    int c3      [(tLen / 3) * 2];   // c3     用配列
    int a_m2    [tLen / 3];         // a( -2) 用配列
    int a_m1    [tLen / 3];         // a( -1) 用配列
    int a_0     [tLen / 3];         // a(  0) 用配列
    int a_1     [tLen / 3];         // a(  1) 用配列
    int a_inf   [tLen / 3];         // a(inf) 用配列
    int b_m2    [tLen / 3];         // b( -2) 用配列
    int b_m1    [tLen / 3];         // b( -1) 用配列
    int b_0     [tLen / 3];         // b(  0) 用配列
    int b_1     [tLen / 3];         // b(  1) 用配列
    int b_inf   [tLen / 3];         // b(inf) 用配列
    int c_m2    [(tLen / 3) * 2];   // c( -2) 用配列
    int c_m1    [(tLen / 3) * 2];   // c( -1) 用配列
    int c_0     [(tLen / 3) * 2];   // c(  0) 用配列
    int c_1     [(tLen / 3) * 2];   // c(  1) 用配列
    int c_inf   [(tLen / 3) * 2];   // c(inf) 用配列
    int i;                          // LOOPインデックス

    // ==== 9 桁(配列 9 個)になった場合は標準乗算
    if (tLen <= 9) {
        multiplyNormal(a, b, tLen, z);
        return;
    }

    // ==== a(-2) = 4 * a2 - 2 * a1 + a0, b(1) = 4 * b2 - 2 * b1 + b0 (by シフト演算)
    for(i = 0; i < tLen / 3; i++) {
        a_m2[i] = (a2[i] << 2) - (a1[i] << 1) + a0[i];
        b_m2[i] = (b2[i] << 2) - (b1[i] << 1) + b0[i];
    }

    // ==== a(-1) = a2 - a1 + a0, b(1) = b2 - b1 + b0
    for(i = 0; i < tLen / 3; i++) {
        a_m1[i] = a2[i] - a1[i] + a0[i];
        b_m1[i] = b2[i] - b1[i] + b0[i];
    }

    // ==== a(0) = a0, b(0) = b0
    for(i = 0; i < tLen / 3; i++) {
        a_0[i] = a0[i];
        b_0[i] = b0[i];
    }

    // ==== a(1) = a2 + a1 + a0, b(1) = b2 + b1 + b0
    for(i = 0; i < tLen / 3; i++) {
        a_1[i] = a2[i] + a1[i] + a0[i];
        b_1[i] = b2[i] + b1[i] + b0[i];
    }

    // ==== a(inf) = a2, b(inf) = b2
    for(i = 0; i < tLen / 3; i++) {
        a_inf[i] = a2[i];
        b_inf[i] = b2[i];
    }

    // ==== c(-2) = a(-2) * b(-2)
    multiplyToomCook3(a_m2,  b_m2,  tLen / 3, c_m2 );

    // ==== c(-1) = a(-1) * b(-1)
    multiplyToomCook3(a_m1,  b_m1,  tLen / 3, c_m1 );

    // ==== c(0) = a(0) * b(0)
    multiplyToomCook3(a_0,   b_0,   tLen / 3, c_0  );

    // ==== c(1) = a(1) * b(1)
    multiplyToomCook3(a_1,   b_1,   tLen / 3, c_1  );

    // ==== c(inf) = a(inf) * b(inf)
    multiplyToomCook3(a_inf, b_inf, tLen / 3, c_inf);

    // ==== c4 = 6 * c(inf) / 6
    for(i = 0; i < (tLen / 3) * 2; i++)
        c4[i] = c_inf[i];

    // ==== c3 = -c(-2) + 3 * c(-1) - 3 * c(0) + c(1) + 12 * c(inf) / 6
    for(i = 0; i < (tLen / 3) * 2; i++) {
        c3[i]  = -c_m2[i];
        c3[i] += (c_m1[i] << 1) + c_m1[i];
        c3[i] -= (c_0[i] << 1) + c_0[i];
        c3[i] += c_1[i];
        c3[i] += (c_inf[i] << 3) + (c_inf[i] << 2);
        c3[i] /= 6;
    }

    // ==== c2 = 3 * c(-1) - 6 * c(0) + 3 * c(1) - 6 * c(inf) / 6
    for(i = 0; i < (tLen / 3) * 2; i++) {
        c2[i]  = (c_m1[i] << 1) + c_m1[i];
        c2[i] -= (c_0[i] << 2) + (c_0[i] << 1);
        c2[i] += (c_1[i] << 1) + c_1[i];
        c2[i] -= (c_inf[i] << 2) + (c_inf[i] << 1);
        c2[i] /= 6;
    }

    // ==== c1 = c(-2) - 6 * c(-1) + 3 * c(0) + 2 * c(1) - 12 * c(inf) / 6
    for(i = 0; i < (tLen / 3) * 2; i++) {
        c1[i]  = c_m2[i];
        c1[i] -= (c_m1[i] << 2) + (c_m1[i] << 1);
        c1[i] += (c_0[i] << 1) + c_0[i];
        c1[i] += (c_1[i] << 1);
        c1[i] -= (c_inf[i] << 3) + (c_inf[i] << 2);
        c1[i] /= 6;
    }

    // ==== c0 = 6 * c(0) / 6
    for(i = 0; i < (tLen / 3) * 2; i++)
        c0[i] = c_0[i];

    // ==== z = c4 * x^4 + c3 * x^3 + c2 * x^2 + c1 * x + c0
    //      (c0, c2, c4 は最初から所定の位置に格納されているので、
    //       c1, c3 のみ加算 )
    for(i = 0; i < (tLen / 3) * 2; i++) z[i + tLen / 3] += c1[i];
    for(i = 0; i < (tLen / 3) * 2; i++) z[i + (tLen / 3) * 3] += c3[i];
}

/*
 * 繰り上がり・借り処理
 */
void Calc::doCarry(int *a, int tLen) {
    int cr;  // 繰り上がり
    int i;   // ループインデックス

    cr = 0;
    for(i = 0; i < tLen; i++) {
        a[i] += cr;
        if(a[i] < 0) {
            cr = -(-(a[i] + 1) / 10 + 1);
        } else {
            cr = a[i] / 10;
        }
        a[i] -= cr * 10;
    }

    // オーバーフロー時
    if (cr != 0) printf("[ OVERFLOW!! ] %d\n", cr);
}

/*
 * 結果出力
 */
void Calc::display(int *a, int *b, int *z)
{
    int i;  // LOOPインデックス

    // 上位桁の不要な 0 を削除するために、配列サイズを取得
    int aLen = D_MAX, bLen = D_MAX, zLen = D_MAX * 2;
    while (a[aLen - 1] == 0) if (a[aLen - 1] == 0) aLen--;
    while (b[bLen - 1] == 0) if (b[bLen - 1] == 0) bLen--;
    while (z[zLen - 1] == 0) if (z[zLen - 1] == 0) zLen--;

    // a 値
    printf("a =\n");
    for (i = aLen - 1; i >= 0; i--) {
        printf("%d", a[i]);
        if ((aLen - i) % 10 == 0) printf(" ");
        if ((aLen - i) % 50 == 0) printf("\n");
    }
    printf("\n");

    // b 値
    printf("b =\n");
    for (i = bLen - 1; i >= 0; i--) {
        printf("%d", b[i]);
        if ((bLen - i) % 10 == 0) printf(" ");
        if ((bLen - i) % 50 == 0) printf("\n");
    }
    printf("\n");

    // z 値
    printf("z =\n");
    for (i = zLen - 1; i >= 0; i--) {
        printf("%d", z[i]);
        if ((zLen - i) % 10 == 0) printf(" ");
        if ((zLen - i) % 50 == 0) printf("\n");
    }
    printf("\n\n");

#ifdef TEST
    printf("Counts of multiply / 1 loop = %d\n", cnt_mul);     // 乗算回数
    printf("Total time of all loops     = %f seconds\n", tt);  // 処理時間
#endif
}

/*
 * メイン処理
 */
int main()
{
    try
    {
        // 計算クラスインスタンス化
        Calc objCalc;

        // 乗算 ( Toom-Cook 法 )
        objCalc.calcToomCook();
    }
    catch (...) {
        cout << "例外発生!" << endl;
        return -1;
    }

    // 正常終了
    return 0;
}

3. C++ ソースコンパイル

-Wall は警告出力、-O2 最適化のオプション)

$ g++ -Wall -O2 -o multiply_toom_cook_3 multiply_toom_cook_3.cpp

何も出力されなければ成功。

4. 実行

$ ./multiply_toom_cook_3
a =
2591143596 2346775392 0937646531 6568615864 8398725493 
4087130219 1596917334 0177683992 0813223745 9779754352 
8944932316 9629947339 9445091835 0388960192 8955061632 
3738412276 1433199925 0315713085 9378419329 3659409294 
2480119500 7197101549 2212451932 2704588466 4812771903 
8901977387 5802786582 6405523129 4325764724 6720143000 
1277681909 6692814478 7319151423 1290842279 3038342092 
1065637955 8549310037 0533764082 9436637253 9515438679 
7766720886 1618471741 7221329703 7458017793 3575966087 
0590224271 1768660254 4076013016 1293279873 5945959708 
9946080467 5263339808 7033368282 4246141390 0394299488 
0252206204 7509191362 6987956604 6039990313 8920242824 
0815106928 5729454200 4798463368 3859845729 7374431491 
6860803289 6677917104 5612777454 6962438673 4424366456 
2613584191 3929722712 030296373
b =
5640936430 9010018506 2074049213 9225420531 7120391806 
3626925169 5258572474 6042646617 5789516910 0220359949 
3394237741 5504199300 2689317010 0173030204 0908343875 
7276728921 5639684944 4001086363 5835539015 9644066815 
4976237121 6087234637 2062179441 7355273885 5857697267 
0889167023 0551830095 0630610627 7679851212 5513317149 
5213794389 4091568806 9338677524 8328634542 4070207416 
2850176096 3013514856 8183157295 7343340228 8454195001 
4934592104 9265001117 5102743998 1388994837 8389707468 
3149905673 0009246688 7647591722 9004589105 9497241813 
2372874109 8652730769 6483888898 1629437542 5694561003 
2950241948 0871495217 0922867240 6583613401 1096478656 
6501028381 4742690593 3678808349 1837517816 5164400885 
1539589914 3536751799 9106221809 6280914880 4755795751 
0306047932 6782530986 669712556
z =
1461647630 9696028535 6604503489 0007526041 7343634144 
9012424224 8203055051 5188762825 4647083693 9946556987 
1776705770 8951386873 3310520107 5767293276 2918257300 
9738034682 1299762223 5923897458 8934279861 7465049365 
8843967704 2679161515 9650724804 3745565830 2404529184 
8153750484 9136923387 1070699742 2427487328 8754504356 
0612755617 5184216545 5162560665 5429356048 4383287541 
4349677390 3284617894 6977135150 5866219438 8254075770 
4288514897 9756324873 9356425815 0617270838 9035401309 
6240946056 6087222538 2876110571 6041803886 8836295631 
6066052571 8502983404 6640486292 8470602933 3206771844 
7507565069 1263592396 6290321271 2071370365 4580134837 
4120524514 9596367407 7615557930 1753573529 9672002811 
5021333258 2795047488 9411440436 6051342406 6652384845 
5125046986 6299991328 7858252156 3928407926 0968969006 
3788059606 7048907697 3132905363 4969631212 6879889153 
7202723230 4021160144 7598686317 9619939624 8796424549 
6233736832 4323370236 2032503373 8424220248 9204990020 
4555937281 9315727632 6895093765 7046148316 3323171219 
9024915269 7062742069 7796088716 2920824597 2157070237 
6571127379 0423318203 7392124983 1535487696 9993725368 
6009296714 8076417516 5151775265 7836148991 1043480618 
3435401050 4007027779 1974470200 0895993617 3597893980 
3172290388 4700169184 9026298285 0863879579 5031583805 
5355823672 0136388135 2633831672 8029491896 2705418710 
5843171306 8493018874 8709677353 7863144237 7967352645 
5041214394 2329737718 2084198684 3461291462 7753947528 
5224644934 3464361873 3201519986 6919424192 2331637031 
3569475240 6078406677 2868569397 5160753082 4809755113 
99359388

「標準(筆算)法」や「Karatsuba 法」で桁数を指定して計算した結果と同じ結果になった。

5. 検証

上記のコードを利用して、乗算回数が何回になるか、計算にどれくらい時間がかかるかを検証してみた。
参考のため「Karatsuba 法」での乗算テスト結果も掲載している。計算桁数が異なるため単純には比較できないが、1万桁を超えると乗算回数は「Karatsuba 法」より微減し、計算に要する時間も少しずつ短くなっているのが分かる。

ちなみに、「Karatsuba 法」の計算量は \(O(n ^{1.585})\) で、「Toom-Cook 法(3-way)」の計算量は \(O(n ^{1.465})\) である。

MULTIPLY_TOOMCOOK_CPP


「Karatsuba 法」(=「Toom-Cook 法」を2分割×2分割で考える方法)に比べ、「Toom-Cook 法(3-way)」はそれほど大きな効果はありませんが、分割数を調整して試してみてみるのもよいかも知れません。

今回は1桁を1個の配列で扱ったが、乗算する桁数があらかじめ分かっているのなら、1個の配列で4桁を扱ったりすると速度が向上する。(但し、桁あふれに注意)
また、今回は繰り上がり処理を最後にまとめて行なっているが、乗算の都度行うとさらに速度が向上する。(但し、繰り上がりにより配列の個数が変動するので厄介)

このアルゴリズムをどこで使用できるかは、今のところ不明だが、知っておいて悪くないでしょう。
(ちなみに、少し前に当ブログでも紹介した Arctan 系公式を使用した円周率の計算でも、使用する場面は無い)

以上。





 

Sponsored Link

 

Comments