PyTorch 1.10の新機能「CUDA Graphs」のパフォーマンスを測定してみる

はじめに

10/21にPyTorch 1.10がリリースされ、今回も面白そうな機能が追加されました。個人的には楽しみにしていた「CUDA Graphs」のAPIのベータ版が追加されたということで早速試してみました。今回はこの試した結果の記事になります。

CUDA Graphsとは?

CUDA GraphsはCUDA 10で追加されたCUDAの機能の一つで、複数のCUDA Kernelの実行にかかるオーバーヘッドを減らすための機能です。

基本的には依存関係表すことができるグラフにCUDA Kernelを登録して、依存関係を考慮して順番にCUDA Kernelを実行するという仕組みです。このCUDA Graphsを通して実行すると普通にCUDA Knernelを実行するのに比べてCUDA Kernelの実行オーバーヘッドを減らすことができます。

詳しくはNVIDIA Developer Blogに記事があるのでご覧ください。

PyTorchでCUDA Graphsを使う

PyTorchでCUDA Graphsを使うには主に以下の2つのステップを踏みます。

  1. CUDA GraphsのStream Captureの機能を使ってグラフを構築
  2. 構築したグラフを実行

それぞれについて順番に説明します。

また、ディープラーニングにおいてすべてのレイヤーがグラフに登録できるものでなかった場合、ネットワークの一部部分だけグラフを構築する方法も用意されています。こちらは今回は触れません。詳しく知りたい方は以下のドキュメントをご覧ください。

https://pytorch.org/docs/master/notes/cuda.html#partial-network-capture

CUDA GraphsのStream Captureの機能を使ってグラフを構築

PyTorchではCUDA Graphsのグラフ構築の一つにStream Captureベースの方法が提供されています。これはtorch.cuda.graph() 以下の実行された関数を自動的にグラフに登録するというものです。
注意点としてはグラフ構築の前のwarmupでは別streamで実行したほうが良いらしいです。詳しくは参考資料の公式ドキュメントをご覧ください。

warmupも含めたグラフ構築は以下の通りです。

static_input = torch.empty((5,), device="cuda")
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Captures the graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2

これで入力をstatic_input、出力をstatic_outputとし、入力を2倍にする計算のグラフgが準備できました。

構築したグラフを実行

構築されたグラフg を実行する際には入力データをstatic_input に上書きして、replay()を実行します。

static_input.copy_(torch.full((5,), 3, device="cuda"))
print("input of cuda graph", static_input)
g.replay()
# static_output holds the results
print("output of cuda graph", static_output) 

出力は以下の通りです。

input of cuda graph tensor([3., 3., 3., 3., 3.], device='cuda:0')
output of cuda graph tensor([6., 6., 6., 6., 6.], device='cuda:0')

注意事項

CUDA Graphsは簡単に使えそうですが、入力のtensorのshapeが変えられないなど制約がいくつかあります。詳しくはこちらをご覧ください。

https://pytorch.org/docs/master/notes/cuda.html#constraints

パフォーマンスの評価

使い方がわかったところで、どれくらい速くなるのか?ということが気になったので測定してみました。測定したときのnotebookは以下のところに置いておきます。

https://github.com/shu65/blog-pytorch-notebooks/blob/main/pytorch_CUDA_Graphs.ipynb

今回は気になった2つのパターンで評価しました。

  • GELU
  • シンプルなLinearとDropoutのモデルの学習

評価環境は以下の通り。

  • 実行環境:Google Colab
    • PyTorch: 1.10.0
    • CUDA: 11.1
    • GPU: K80 (たまたま取れた)

GELU

簡単な例として以下ようなGELUをCUDA Graphsで実行してみます。

def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

また、この際、入力のtensorで小さい例と大きい例の2種類を使って測定してみます。

それぞれのtensorのshapeとしては以下の通りです。

  • 小さいtensor: (1, 3, 224, 224)
  • 大きいtensor: (32, 3, 224, 224)

上記のサイズのtensorそれぞれを10000回実行して平均計算時間を測定しました。結果は以下の通りです。

平均計算時間 (sec.)defaultを1とした時の速度向上率
default7.09e-051.00
CUDA Graphs6.49e-051.09
GELUの小さいtensorの評価結果
平均計算時間 (sec.)defaultを1とした時の速度向上率
default1.32e-031.00
CUDA Graphs1.34e-030.99
GELUの大きいtensorの評価結果

評価結果としては個人的には思った通りの結果という印象で、CUDA Kernelのオーバーヘッドの割合が大きい、小さいtensorの時は効果がある程度出ているが、大きいtensorの時はオーバーヘッドの割合が小さいため、ほぼ変わらないという結果になりました。

シンプルなLinearとDropoutのモデルの学習

PyTorchでCUDA Graphsの真価を発揮するのは学習のタイミングかと思いますので、公式ドキュメントにあった例の評価をしてみます。CUDA Graphsに登録する関数train_step()とモデル、各種入力は以下の通りです。

def training_step(model, loss_fn, optimizer, data, target):
    y_pred = model(data)
    loss = loss_fn(y_pred, target)
    loss.backward()
    optimizer.step()

N, D_in, H, D_out = 32, 128, 256, 16
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.Dropout(p=0.2),
    torch.nn.Linear(H, D_out),
    torch.nn.Dropout(p=0.1)
).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Placeholders used for capture
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')

ちなみにGoogle Colabで実行しようとしたとき、公式ドキュメントの入力サイズそのままだとcuBLASの内部でエラーが発生して実行できなかったため、サイズを小さくしてあります。

これらを10イテレーション分実行したときの評価結果は以下の通りです。

1イテレーションあたりの平均計算時間 (sec.) defaultを1とした時の速度向上率
default1.11e-031.00
CUDA Graphs4.71e-042.36
シンプルなLinearとDropoutのモデルの学習の評価結果

こちらは思ったよりも速度に差がでました。CUDA Graphsを利用できる場合は使うと効果的かもしれません。

おまけ

CUDA Graphsの制限を見ていて思いましたが、これならtorch.jit.tracetorch.jit.scriptも併用できるのでは?と思ってやってみました。以前、以下の記事で行ったように torch.jit.script + GELUを使用して評価しました。

評価結果は以下の通りです。

平均計算時間 (sec.)defaultを1とした時の速度向上率
default7.09e-051.00
CUDA Graphs6.49e-051.09
torch.jit.script3.89e-051.82
torch.jit.script + CUDA Graphs 3.56e-051.99
GELUの小さいtensorの評価結果
平均計算時間 (sec.)defaultを1とした時の速度向上率
default1.32e-031.00
CUDA Graphs1.34e-030.99
torch.jit.script 4.25e-043.11
torch.jit.script + CUDA Graphs 3.74e-04 3.53
GELUの大きいtensorの評価結果

torch.jit.script の効果が大きいですが、CUDA Graphsを使うことでさらに速くなることが確認できました。個人的には CUDA Graphs が使える状況なら torch.jit.tracetorch.jit.script も使えると思われるので併用してよいのではないかと思います。

終わりに

楽しみにしていたCUDA GraphsがPyTorchで使えるようになったということで、評価してみました。一部思った以上の効果を発揮したところもあるので、仕事でも使ってみてノウハウを貯めていこうと思います。

参考資料

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトは reCAPTCHA と Google によって保護されていますプライバシーポリシー利用規約 申し込み。

The reCAPTCHA verification period has expired. Please reload the page.