SIMD(AVX-512)を使って競プロの問題を解いてみる
この記事は Competitive Programming (2) Advent Calendar 2018 18日目の記事です。
定数倍高速化の方法としてSIMDを用いたものが存在します。SIMDを用いて競プロの問題を解いている例としては以下の記事がありました。
SIMDの命令セットにも色々種類があるようですが、今回はその中でもAVX-512を用いて競プロの問題を解いてみます。
AVX-512やそもそものSIMDについては以下の記事が参考になると思います。
SIMDプログラミング入門(AVX-512から始める編) - Qiita
※注意
・現時点でAVX-512を使えるコンテストサイトはないと思います(たぶん)。なので今回は手元の環境で実行してます。
・対象とした問題のテストケースが公開されていなかったので、自分で最大ケースを作って試してます。おそらく大丈夫だとは思うのですが、コードに誤りがあった場合は指摘してくださると嬉しいです。
問題(ABC106 D - AtCoder Express 2)
自明な解法
問題に書いてある通りのことをやればよいですね。こんな感じのコードになると思います。
#include <cstdio> using namespace std; int main() { const int m_max = 200000; const int q_max = 100000; int N, M, Q; scanf("%d%d%d", &N, &M, &Q); int L[m_max], R[m_max]; for (int i = 0; i < M; i++) { scanf("%d%d", &L[i], &R[i]); } int ans[q_max]; int p[q_max], q[q_max]; for (int i = 0; i < Q; i++) { scanf("%d%d", &p[i], &q[i]); } for (int i = 0; i < Q; i++) { ans[i] = 0; } for (int i = 0; i < Q; i++) { for (int j = 0; j < M; j++) { if (p[i] <= L[j] && R[j] <= q[i]) { ans[i]++; } } } for (int i = 0; i < Q; i++) { printf("%d\n", ans[i]); } }
これだと計算量は O(QM) で、QMが最大でになるのでちょっと厳しそうです。実際に実行してみましょう。
$ g++-7 -O3 -march=native AtCoderExpress2.cpp -o AtCoderExpress2 $ time ./AtCoderExpress2 < testcase.txt > output.txt real 1m5.404s user 1m5.388s sys 0m0.004s
TLは3秒なので、大幅にオーバーしてます。
SIMDで並列化
早速上記のコードを並列化してみます。二重forの外側と内側どちらを並列化するかですが、今回は内側のループを並列化することにします。(そっちを先に思いついたので。)
#include <cstdio> #include <cstdint> #include <immintrin.h> using namespace std; int main() { const int m_max = 200000; const int q_max = 100000; int N, M, Q; scanf("%d%d%d", &N, &M, &Q); alignas(64) int16_t L[m_max], R[m_max]; for (int i = 0; i < M; i++) { scanf("%hd%hd", &L[i], &R[i]); } int ans[q_max]; int16_t short p[q_max], q[q_max]; for (int i = 0; i < Q; i++) { scanf("%hd%hd", &p[i], &q[i]); } for (int i = 0; i < Q; i++) { ans[i] = 0; } const int M1 = M / 32 * 32; for (int i = 0; i < Q; i++) { const __m512i pp = _mm512_set1_epi16(p[i]); const __m512i qq = _mm512_set1_epi16(q[i]); for (int j = 0; j < M1; j += 32) { const __m512i LL = _mm512_load_si512(L + j); const __m512i RR = _mm512_load_si512(R + j); const __mmask32 cond1 = _mm512_cmple_epi16_mask(pp, LL); const __mmask32 cond2 = _mm512_cmple_epi16_mask(RR, qq); ans[i] += _mm_popcnt_u32(_kand_mask32(cond1, cond2)); } for (int j = M1; j < M; j++) { if (p[i] <= L[j] && R[j] <= q[i]) { ans[i]++; } } } for (int i = 0; i < Q; i++) { printf("%d\n", ans[i]); } }
$ g++-7 -O3 -march=native AtCoderExpress2AVX512.cpp -o AtCoderExpress2AVX512 $ time ./AtCoderExpress2AVX512 < testcase.txt > output_avx512.txt real 0m0.995s user 0m0.992s sys 0m0.000s $ diff output.txt output_avx512.txt $
めっちゃ速くなりましたね。1秒を切ることに成功しました。
まとめ
対象とした問題が単純だったというのはあるのですが、思ったよりも簡単にSIMDで高速化出来ました。問題がもっと複雑な場合でも高速化できるものはあると思うので、計算量を落とす方法は思いつかないけど定数倍高速化すれば通りそう、みたいな時は少し考えてみてもよいかもしれません。
今回使用したAVX-512の命令は単純なものだけでしたが、色々機能が追加されているらしいので、コンテストサイト(特にAtCoder)で使えるようになることを願いつつ終わりたいと思います。
オチ
実は元のコードとSIMDを用いたコードは処理が若干異なっています。後者の方に合わせるとコードは以下のようになります。ついでに変数の型も合わせました。
#include <cstdio> #include <cstdint> using namespace std; int main() { const int m_max = 200000; const int q_max = 100000; int N, M, Q; scanf("%d%d%d", &N, &M, &Q); int16_t L[m_max], R[m_max]; for (int i = 0; i < M; i++) { scanf("%hd%hd", &L[i], &R[i]); } int ans[q_max]; int16_t p[q_max], q[q_max]; for (int i = 0; i < Q; i++) { scanf("%hd%hd", &p[i], &q[i]); } for (int i = 0; i < Q; i++) { ans[i] = 0; } for (int i = 0; i < Q; i++) { for (int j = 0; j < M; j++) { ans[i] += (p[i] <= L[j] & R[j] <= q[i]); } } for (int i = 0; i < Q; i++) { printf("%d\n", ans[i]); } }
if文が消えて演算子が`&&` から `&`に変わりました。
この状態で実行してみます。
$ g++-7 -O3 -march=native AtCoderExpress2.cpp -o AtCoderExpress2 $ time ./AtCoderExpress2 < testcase.txt > output.txt real 0m1.085s user 0m1.080s sys 0m0.004s
SIMDを使ったのと同じくらいの実行時間になりました。これはどういうことかというと、条件分岐が消えたのも速くなった理由の一つではありそうですが、自動ベクトル化が効いたことが主な理由だと考えられます。コンパイラがコードを解析した結果、SIMDが使えると判断した際には(明示的に書かれていなくても)自動的に使うように最適化されることがあります。
実際にアセンブラを出力して確認してみると、AVX-512で使用されるzmmレジスタが使われていることなどから、AVX-512が使われている=SIMD命令が使われていることが確認できました。
AtCoderは128bitのSIMD幅の命令は使えるそうなので(AVX-512は512bit)単純に考えると4倍くらいの実行時間になって、それでもTL3秒に間に合いそうな気がしてきます。そのままだと通らなかったので適当にコードをいじって投げると通っちゃいました。
Submission #3820425 - AtCoder Beginner Contest 106
以上のことから、 SIMDを自分で明示的に書く前に、すでに自動ベクトル化がされていないか、されていない場合には自動ベクトル化を阻害するようなコードになっていないか等を確認するとよいかもしれません。