JAXとPyTorch、どっちが速いのか検証してみた

高速化が趣味&仕事なので、最近よく目にするJAXの速度が気になってました。このため、今回は日ごろ使っているPyTorchと比較したので、その結果のまとめを紹介します。

結論

結果だけ知りたい方が多いだろうと思ったので先に結論から書くと、私のPyTorch力では力及ばず、今回の検証では

JAXのほうがPyTorchの2.2倍速い

という結果でした。ここから詳しく評価について説明します。

評価方法

今回、JAXとPyTorchを比較するにあたり、この前紹介したSmooth Smith Watermanのコードを利用しました。Smooth Smith Watermanについて知りたいという方は以下の記事をご覧ください。

この記事で紹介したJAXコードは論文の著者が頑張って高速化した結果なため、十分最適化された結果であるという認識です。このため、今回はPyTorchのコードを私が作成し、測定を行いました。

今回の検証コードはここに置いてあります。

https://github.com/shu65/blog-jax-notebook/blob/main/Smooth_Smith_Waterman_PyTorch_vs_JAX.ipynb

今回は3パターン実装したので、それぞれについて順番に紹介します。

実行はGoogle Colab上で行いました。この際、使用したGPUやライブラリのバージョンは以下の通りです。

  • GPU: K80
  • CUDA: 11.2
  • PyTorch: 1.10.0
  • JAX: 0.2.21

また、Smooth Smith Watermanは2つの配列の最大長を100, 120とした64個の配列のペアを入力に与えて測定しました。今回は少し測定誤差が入ることも考慮して10回平均で比較します。

JAXのコードをそのままPyTorchにする

JAXのコードで利用されているアルゴリズムはPyTorchでも十分速くなるようにみえました。このため、まずはそのまま適用してみました。PyTorchのコードとしては以下の通りです。

class SwTorch(nn.Module):
    def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
        super(SwTorch, self).__init__()
        self.unroll = unroll
        self.NINF = torch.tensor(NINF, device=device)
        self.device = device

    def _make_mask(self, score_matrix, lengths):
        a,b = score_matrix.shape
        real_a = lengths[0]
        real_b = lengths[1]
        mask = (torch.arange(a, device=self.device) < real_a)[:,None] & (torch.arange(b, device=self.device) < real_b)[None,:]
        return mask

    def _rotate(self, score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar = torch.flip(torch.arange(a, device=self.device), [0])[:, None]
        br = torch.arange(b, device=self.device)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = torch.full([n,m], self.NINF, dtype=score_matrix.dtype, device=self.device)
        rotated_score_matrix[i, j] = score_matrix
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _step(self, prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        h1_T = self._get_prev_gap_cell_score(
            gap_cell_condition,
            torch.nn.functional.pad(h1[:-1], [1,0], value=self.NINF),
            torch.nn.functional.pad(h1[1:], [0,1], value=self.NINF),
        )
      
        a = h2 + rotated_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_score_matrix
        h0 = torch.stack([a, g0, g1, s], -1)
        h0 = self._soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _rotate_in_reverse(self, rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(self, y, axis):
        y = torch.maximum(y,self.NINF)
        return torch.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(self, y, axis, mask):
        y = torch.maximum(y,self.NINF)
        if axis is None:
          return torch.max(y) + torch.log(torch.sum(mask * torch.exp(y - torch.max(y))))
        else:
          return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, axis, keepdims=True)[0]), axis=axis))

    def _soft_maximum(self, x, temp, axis=None):
        return temp*self._logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(self, x, temp, mask, axis=None):
        return temp*self._logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(self, cond, true, false): 
        return cond*true + (1-cond)*false

    def forward(self, score_matrix, lengths, gap=0, temp=1.0):
      mask = self._make_mask(score_matrix, lengths)
      masked_score_matrix = score_matrix + self.NINF * (~mask)
      rotated_score_matrix, reverse_idx = self._rotate(masked_score_matrix)

      a,b = score_matrix.shape
      n,m = rotated_score_matrix.shape
      gap_cell_condition = (torch.arange(n, device=self.device)+a%2)%2
      prev = (torch.full((m,), self.NINF, device=self.device), torch.full((m,), self.NINF, device=self.device))
      rotated_hij = [None for _ in range(n)]
      for i in range(n):
          prev, h = self._step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
          rotated_hij[i] = h
      rotated_hij = torch.stack(rotated_hij)
      hij = self._rotate_in_reverse(rotated_hij, reverse_idx)
      score = self._soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
      return score


class BatchSwTorch(nn.Module):
    def __init__(self, unroll=2, NINF=-1e30, device="cpu"):
        super(BatchSwTorch, self).__init__()
        self.device = device
        self.sw = SwTorch(unroll=unroll, NINF=NINF, device=device)

    def forward(self, batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        ret = torch.empty((n_batches,), dtype=batch_score_matrix.dtype, device=self.device)
        for i in range(n_batches):
          ret[i] = self.sw(batch_score_matrix[i], batch_lengths[i], gap=gap, temp=temp) 
        return ret

ちなみに最初はシンプルなコードと比較しようと思ったので、この時点ではまだtorch.jitは使っていません。このコードの結果は以下の通りです。

平均実行時間 (sec)
numpy34.5
JAX jit版0.0142
JAXのコードをそのままPyTorchにする7.89
JAXのコードをそのままPyTorchにした場合の結果

見ての通り、JAXが圧倒的。PyTorchもnumpyに比べて速くなってはいるのでGPUを使っている効果が出ていると考えられますが、それ以上にJAXが速い。PyTorchと比較してJAXのほうが556倍も速いという結果でした。JAXのほうがバグっているのか?とも一瞬思ったのですが、ちゃんと正しい答えを出力しているし、nsysでプロファイル結果を取ってみた限りそれっぽい時間で1回の計算が終わっているので、測定ミスでもなさそうでした。

というわけで、圧倒的にPyTorchがこのままでは遅いので、高速化したバージョンを作成して評価したので、次で紹介します。

Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する

PyTorchを普段から使っていると気にならない部分ではありますが、CUDAの高速化のつもりで考えると、 Batchの軸を行列の一番内側にもってくるほうがCUDA的には速くなりそうな気がします。また、JAXはJITを使っているのでPyTorchもJIT使うほうがいいだろうということでJITを使いました。

これに伴ってコードは以下のように変更しました。

from typing import Tuple


def _make_batch_mask(batch_score_matrix, batch_lengths):
    a, b, batch_size = batch_score_matrix.shape
    real_a = batch_lengths[:, 0]
    real_b = batch_lengths[:, 1]
    mask_a = torch.arange(a, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_a[None, :]
    mask_b = torch.arange(b, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) < real_b[None, :]
    mask = mask_a[:, None] & mask_b[None, :]
    return mask


def _logsumexp(y: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.logsumexp(y, dim=axis)


def _logsumexp_with_mask(y: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, dim=axis, keepdim=True)[0]), dim=axis))


def _soft_maximum(x: torch.Tensor, temp: torch.Tensor, axis: int, NINF: torch.Tensor) -> torch.Tensor:
    return temp*_logsumexp(x/temp, axis=axis, NINF=NINF)


def _soft_maximum_with_mask(x: torch.Tensor, temp: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -> torch.Tensor:
    return temp*_logsumexp_with_mask(x/temp, axis=axis, mask=mask, NINF=NINF)


def _rotate(batch_score_matrix: torch.Tensor, NINF: torch.Tensor, rotated_batch_score_matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso\
r, torch.Tensor]:
    a, b, batch_size = batch_score_matrix.shape
    n,m = (a+b-1),(a+b)//2
    ar = torch.flip(torch.arange(a, device=batch_score_matrix.device), [0])[:, None]
    br = torch.arange(b, device=batch_score_matrix.device)[None,:]
    i,j = (br-ar)+(a-1),(ar+br)//2      
    rotated_batch_score_matrix[:, :, :] = NINF
    rotated_batch_score_matrix[i, j, :] = batch_score_matrix                                                                                         
    return rotated_batch_score_matrix, i, j


def _rotate_in_reverse(rotated_dp_matrix, i, j):                                                                                                  
    return rotated_dp_matrix[i, j]


def _get_prev_gap_cell_score(cond, true, false):
    return cond*true + (1-cond)*false

@torch.jit.script
def _step(h2, h1, gap_cell_condition, rotated_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
    prev_gap_cell_true[1:, :] = h1[:-1, :]
    prev_gap_cell_false[:-1, :] = h1[1:, :]
    h1_T = _get_prev_gap_cell_score(
        gap_cell_condition,
        prev_gap_cell_true,
        prev_gap_cell_false,
    )
    a = h2 + rotated_score_matrix
    g0 = h1 + gap
    g1 = h1_T + gap
    s = rotated_score_matrix
    h0 = torch.stack([a, g0, g1, s], -1)
    h0 = _soft_maximum(h0, temp, axis=-1, NINF=NINF)
    return h1, h0, h0


@torch.jit.script
def _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
    n, _, _ = rotated_batch_score_matrix.shape
    rotated_hij = torch.empty((n, init_h1.shape[0], init_h1.shape[1]), dtype=init_h1.dtype, device=init_h1.device)
    h1 = init_h1
    h0 = init_h0
    h1[:, :] = NINF
    h0[:, :] = NINF
    for i in range(n):
        h1, h0, h = _step(h1, h0, gap_cell_condition=gap_cell_condition[i], rotated_score_matrix=rotated_batch_score_matrix[i], gap=gap, temp=temp, NINF=NINF, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false,)
        rotated_hij[i] = h                                                                                                 
    return rotated_hij

@torch.jit.script
def batch_sw_func(batch_score_matrix, batch_lengths, gap, temp, NINF, rotated_batch_score_matrix, init_h1, init_h0, prev_gap_cell_true, prev_gap_cell_false):
    transposed_batch_score_matrix = batch_score_matrix.permute(1, 2, 0)
    mask = _make_batch_mask(transposed_batch_score_matrix, batch_lengths)
    masked_batch_score_matrix = transposed_batch_score_matrix + NINF * (~mask)
    rotated_batch_score_matrix, reverse_idx_i, reverse_idx_j = _rotate(masked_batch_score_matrix, NINF=NINF, rotated_batch_score_matrix=rotated_batch_score_matrix)
    a, b, batch_size = transposed_batch_score_matrix.shape
    n, m, _ = rotated_batch_score_matrix.shape
    gap_cell_condition = (torch.arange(n, device=rotated_batch_score_matrix.device)+a%2)%2                                           
    rotated_hij = _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false, NINF=NINF)                                                                                                   
    hij = _rotate_in_reverse(rotated_hij, reverse_idx_i, reverse_idx_j)
    score = _soft_maximum_with_mask(hij.reshape(a*b,batch_size), temp=temp, mask=mask.reshape(a*b, batch_size), axis=0, NINF=NINF)
    return score

次に紹介する高速化の関係で一時領域も引数として与えていますが気にしないでください。後ほど説明します。

このコードの測定結果も加えると以下の通りです。

平均実行時間 (sec)
numpy 34.5
JAX jit版 0.0142
JAXのコードをそのままPyTorchにする 7.89
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する 0.0655
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化した結果

正直、この時点でJAXと並ぶだろうとやる前は思っていたのですが、JAXのほうがまだPyTorchの4.6倍速いという結果でした。JAX速い・・・。でも、できることはまだある!ということでもう一工夫やります。

CUDA Graphsを使う

PyTorchのコードのプロファイル結果を見るとかなり実行時間の短いCUDA Kernelが大量に実行されているという状態でした。このため、CUDA Kernelの実行のオーバーヘッドがかなり入っているのでは?と考えて、これを削減するCUDA Graphsを使ってみます。

CUDA Graphsが何かわからない方はこちらの記事をご覧ください。

さて、CUDA Graphsで実行するようのコードとしては一つ前のコードのtorch.jit.traceでコンパイルしたものを利用します。CUDA Graphsで実行できるようにするために、一時領域を一部入力として入れていました。

この測定結果は以下の通りです。

平均実行時間 (sec)
numpy 34.5
JAX jit版 0.0142
JAXのコードをそのままPyTorchにする 7.89
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する 0.0655
CUDA Graphsを使う 0.0324
Batchの軸を行列の一番内側にもってくる + torch.jitで高速化した結果

なんと、JAXのほうがPyTorchのコードよりも2.2倍速いという結果でした。CUDA Graphsでも勝てないなんて・・・。

現状のPyTorchのコードの敗因

自分なりにnsysを使ってプロファイルとって結果を見た印象では主に以下の二つの原因があると考えています。

  1. PyTorchのJITがJAXに比べてあまりfuseしてくれない
  2. そもそも入力サイズがGPUで実行するには小さすぎる

1についてはPyTorchとJAXの二つのコードを見てみるとJAXのコードのほうがJITされて1つのCUDA Karnelの実行時間が長い印象でした。JAX、PyTorchどちらもどの関数がfuseされたかをどうやってみるのかわからないので憶測になりますが、おそらくJAXのほうがより多くの処理を1つのCUDA Karnelまとめてくれていて、結果としてCUDA Karnelの実行数が減り、CUDA Karnelの実行オーバーヘッド小さくなったためJAXのほうが速くなった、ということを考えています。

2に関しては、そもそもGPUで実行するには入力サイズが小さすぎる印象です。実際配列の長さや配列ペアの数を大きくしてもPyTorchのコードはあまり実行時間が増加しないことを確認しています。じゃあ、入力サイズを大きくすればいいのでは?とも思ったのですが、今回のSmooth Smith Watermanではあと2,3倍くらいにはしてもよさそうですが、どこまでできるかは問題依存なため、ひとまずそのままにしておきました。なんとなく実際に使うときもあと数倍くらいは大きくできそうだけど、100倍とかはあまりつかわなさそうだなと思っています。ただ、この辺りはいろいろ意見が分かれそうかなと思っています。

終わりに

今回は個人的に前々から気になっていた「JAXって速いの?」という問いに答えるための検証の一環で行いました。結果はまさかのPyTorchと比べてこんなに差がでるとは、という感じした。ただ、PyTorchのJITはまだ使い慣れていない感があるので、何か高速化のアイディアが浮かんだら再チャレンジしたいと思います。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です