PyTorch 2.0の新機能「torch.compile」使ってみた

今回は3/16についに出たPyTorch 2.0の目玉機能である「torch.comple」について実際に動かしてみて計算時間を測定してみたので、そのまとめになります。

時間計測の部分で測定に使ったコードはここにあげてあります。

https://github.com/shu65/pytorch_2_compile_example/blob/main/torch_2_0_compile.ipynb

torch.compileとは?

torch.compileはPyTorch 2.0の新機能で、PyTorchの複数の機能を組み合わせて使い関数や深層学習のモデルを実行時に最適化して、その後の呼び出して高速に実行できるようにする機能です。

torch.compileの中身の詳しい説明はここにかかれています。

https://pytorch.org/get-started/pytorch-2.0/#technology-overview

簡単に説明するとtorch.compileの中身としては以下の3つで構成されています。

  1. Graph acquisition: 計算グラフの構築
  2. Graph lowering: PyTorchのオペレーションをバックエンドのデバイス(CPUやGPU)に特化した細かい命令に分解
  3. Graph compilation: バックエンドのデバイス特化の命令を呼び出し

これらのステップを経ることで、より効率よく計算リソースを使えるようにし、高速化を実現しています。

また、この機能のすばらしいところは使い方も非常に簡単であるというものがあります。以下にデコレータで使う方法とtorch.compileの関数を呼び出して使う方法を示します。

デコレータで使うやり方

まずデコレータで使う方法です。これは以下のようになります (このチュートリアルの例:https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo2(torch.randn(10, 10), torch.randn(10, 10))

torch.jit.scriptを使ったことがある方は、それと同じ感覚で使えるというと使い方がイメージしやすいかもしれません。

torch.compileの関数を呼び出して使うやり方

torch.compileの関数を呼び出してコンパイルする場合は以下のようにやります。(このチュートリアルの例:https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
opt_mod(torch.randn(10, 100))

こちらもtorch.jit.scriptのときと同じような使い方だと思います。

torch.compileによるパフォーマンスの評価

次にtorch.compileを実際に使ってみたときの計算時間を計測したので、その紹介です。今回は以下の二つのGPUで測定しました。

  1. T4
  2. V100

T4はTuringなので公式のドキュメントでtorch.compileのサポートが書かれてないものになっています。ただ、やってみたら少し早くなったので、測定結果を載せています。GitHubにあげたコードはT4で測定したほうです。

また、CUDAのバージョンはどちらのケースも12.0利用し、測定に使ったモデルはチュートリアルにあったtorchvisionのResNet18を使用しました。

また、torch.compileにはモードが以下の3つあります。

  1. デフォルト
  2. reduce-overhead
  3. max-autotune

これらと何もしてない場合も含めて合計4つパターンの測定をしています。

具体的な測定方法が分かりやすいようにコードの一部を紹介します(torch.compleのデフォルトの場合)。

import time 

import torch
import torchvision.models as models
import torch._dynamo

batch_size = 64
n_warmup_iters = 10
n_iters = 500

x = torch.randn(batch_size, 3, 224, 224).cuda()

def get_mode():
    return models.resnet18()

torch._dynamo.reset()

model = get_mode().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# compile
compiled_model = torch.compile(model)
for i in range(n_warmup_iters):
    optimizer.zero_grad()
    torch.cuda.synchronize()
    start = time.time()
    out = compiled_model(x)
    torch.cuda.synchronize()
    forward_elapsed_time = time.time() - start
    torch.cuda.synchronize()
    start = time.time()
    out.sum().backward()
    backward_elapsed_time = time.time() - start
    print(f"with compile {i} iter forward: {forward_elapsed_time/1000:.3e} msec., backward: {backward_elapsed_time/1000:.3e} msec.")
    optimizer.step()

print("-"*10)

torch.cuda.synchronize()
start = time.time()
for i in range(n_iters):
    optimizer.zero_grad()
    out = compiled_model(x)
    out.sum().backward()
    optimizer.step()
torch.cuda.synchronize()
elapsed_time = time.time() - start

print(f"with compile total:{elapsed_time:.3e} sec. {batch_size*n_iters/elapsed_time:.3e} imgs/sec.")

最初に、モデルの入力とモデルを作ったあと、コンパイルする場合はtorch.compile(model)でコンパイルします。このときコンパイルのモードを変える場合は引数のmodeにモードの名前を渡します。

その後、最初の数回はforward、backwardの呼び出し時にコンパイルなどのオーバーヘッドが入って遅いので、あらかじめ何度か呼びます。そして最後に実際に時間を計測します。今回は10回あらかじめforwardとbackwardを呼んでおいて、その後500回イテレーションを回したときの時間を測定しています。バッチサイズに関しては変化させると高速化率が変化することはわかっていますが、今回固定で64で実行しています。

T4, V100ともに同様の方法でtorch.compileのありなし等を測定しています。

では、時間計測の結果です。500回イテレーションを回したときの実際の計算時間を順番に示していきます。まずはT4の場合です。

計算時間 (sec.)torch.compileなしからの高速化率
torch.compileなし78.681.00
torch.compile (default)73.371.07
torch.compile (reduce-overhead)77.521.01
torch.compile (max-autotune)73.351.07
T4を使ったResNet18の結果

T4はtorch.compileのサポートが書かれてない世代のGPUなので、効果が全くでないのかと思ったのですが、そんなことはなかったです。ただ、10%は満たない高速化にとどまっているという印象です。ちなみにT4を使ったケースではtorch.compileのmodeをmax-autotuneに変えると以下のようにサポートされてないGPUであると警告がでてきます。

[2023-03-17 18:31:06,314] torch._inductor.utils: [WARNING] not enough cuda cores to use max_autotune mode

次にV100のResNet18の結果です。

計算時間 (sec.)torch.compileなしからの高速化率
torch.compileなし26.61.00
torch.compile (default)24.71.08
torch.compile (reduce-overhead)24.21.10
torch.compile (max-autotune)24.11.10
V100を使ったResNet18の結果

V100のほうはtorch.compileのサポートされていると書かれているGPUです。実際、V100はtorch.compileのmodeをmax-autotuneに変えると確かにより速くなり、高速化率も最大値は10%台に入っています。

現状のtorch.compileの注意点

最後にtorch.compileの注意したほうがよさそうな点を書いておきます。

まず、公式で書かれいたものの紹介です。基本的な注意点はこのドキュメントに書いてあります。

https://pytorch.org/get-started/pytorch-2.0/#pytorch-2x-faster-more-pythonic-and-as-dynamic-as-ever

重要なものとして、現在提供されているtorch.compileの機能を最大限活かせるのはCPU、NVIDIAのVoltaとAmpere世代のGPUのみになっています。他のGPUでは使おうとすると警告が出てきます。ただ、私が試した範囲では警告がでるだけで現状では使えないわけではなさそうです。

また、私が使ったときに感じた注意点としては

  1. おそらくforwardとbackwardで別々にコンパイルが走るので、forward、backwardの両方とも最初は遅い
  2. 実行が遅いのは最初の1回目だけでなく、最初の数回の呼び出しが遅いケースがある
  3. Google ColabなどでCellの実行を一度止めて再度実行しようとするとエラーがでて、ランタイムの再起動をしないと復帰できないケースがある

1と2は時間計測をしようとしたときにはまったポイントです。まず、1に関してです。torch.compileの直後の呼び出しはコンパイルが走るので、遅いというのはドキュメントにも書かれています。ただ、forwadだけがおそいのかな?と思ってました。ただ、torch.compileの説明をちゃんと読めば想像できると思いますが、backwardも最初の実行のときは遅いです。なので、時間を計測するときは、forwardとbackwardの両方が遅いことを考慮して測定する必要があります。

次に2です。これに関しては私が見逃してなければドキュメントに明示的に説明が書いてあるわけではないのですが、チュートリアルの時間計測の結果や実際に測定してみるとどうやら遅いのは最初の1回目の呼び出しだけではなく、そのあと数回遅いケースが存在しているようです。このため、計算時間の測定の際、最初に数回呼び出してから測定しないとtorch.compileを使ったときよりも遅いみたいな誤った結果になるので注意してください。

最後に3です。これは何度かはまったのですが、どこかにキャッシュか何か残っているのか変なところで止めるとコード的には問題ないはずなのに、エラーがでるようになるときがあります。調べても解決方法が分からなかったので、エラーがでるようになったらランタイムごと再起動するということを何度かやりました。Google Colabでやるときは注意してください。

終わりに

今回はtorch.compileについて使ってみたのでまとめを書きました。去年発表があったときから楽しみにしていましたが、期待通りのものとなっていました。なにより使い方が非常に簡単なことには驚きました。

今回はT4とV100の測定結果でしたが、A100だとどうなるのかも今度測定しようかなと思っています。

この記事がみなさんのお役に立てば幸いです。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトは reCAPTCHA と Google によって保護されていますプライバシーポリシー利用規約 申し込み。

The reCAPTCHA verification period has expired. Please reload the page.