<?xml version="1.0" encoding="UTF-8"?><rss version="2.0"
	xmlns:content="http://purl.org/rss/1.0/modules/content/"
	xmlns:wfw="http://wellformedweb.org/CommentAPI/"
	xmlns:dc="http://purl.org/dc/elements/1.1/"
	xmlns:atom="http://www.w3.org/2005/Atom"
	xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
	xmlns:slash="http://purl.org/rss/1.0/modules/slash/"
	>

<channel>
	<title>JAX - まったり勉強ノート</title>
	<atom:link href="https://www.mattari-benkyo-note.com/tag/jax/feed/" rel="self" type="application/rss+xml" />
	<link>https://www.mattari-benkyo-note.com</link>
	<description>shuの日々の勉強まとめ</description>
	<lastBuildDate>Tue, 16 Nov 2021 23:00:29 +0000</lastBuildDate>
	<language>ja</language>
	<sy:updatePeriod>
	hourly	</sy:updatePeriod>
	<sy:updateFrequency>
	1	</sy:updateFrequency>
	<generator>https://wordpress.org/?v=6.8.3</generator>
<site xmlns="com-wordpress:feed-additions:1">189243286</site>	<item>
		<title>JAXとPyTorch、どっちが速いのか検証してみた</title>
		<link>https://www.mattari-benkyo-note.com/2021/11/17/ssw-jax-vs-torch/</link>
					<comments>https://www.mattari-benkyo-note.com/2021/11/17/ssw-jax-vs-torch/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Tue, 16 Nov 2021 23:00:28 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[JAX]]></category>
		<category><![CDATA[pytorch]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=212</guid>

					<description><![CDATA[<p>高速化が趣味＆仕事なので、最近よく目にするJAXの速度が気になってました。このため、今回は日ごろ使っているPyTorchと比較したので、その結果のまとめを紹介します。 結論 結果だけ知りたい方が多いだろうと思ったので先に [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/17/ssw-jax-vs-torch/">JAXとPyTorch、どっちが速いのか検証してみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>高速化が趣味＆仕事なので、最近よく目にするJAXの速度が気になってました。このため、今回は日ごろ使っているPyTorchと比較したので、その結果のまとめを紹介します。</p>



<h2 class="wp-block-heading">結論</h2>



<p>結果だけ知りたい方が多いだろうと思ったので先に結論から書くと、私のPyTorch力では力及ばず、今回の検証では</p>



<p class="has-text-align-center"><strong>JAXのほうがPyTorchの2.2倍速い</strong></p>



<p>という結果でした。ここから詳しく評価について説明します。</p>



<h2 class="wp-block-heading">評価方法</h2>



<p>今回、JAXとPyTorchを比較するにあたり、この前紹介したSmooth Smith Watermanのコードを利用しました。Smooth Smith Watermanについて知りたいという方は以下の記事をご覧ください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="pTiJItncc4"><a href="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/">JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/embed/#?secret=pTiJItncc4" data-secret="pTiJItncc4" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>この記事で紹介したJAXコードは論文の著者が頑張って高速化した結果なため、十分最適化された結果であるという認識です。このため、今回はPyTorchのコードを私が作成し、測定を行いました。</p>



<p>今回の検証コードはここに置いてあります。</p>



<p><a href="https://github.com/shu65/blog-jax-notebook/blob/main/Smooth_Smith_Waterman_PyTorch_vs_JAX.ipynb">https://github.com/shu65/blog-jax-notebook/blob/main/Smooth_Smith_Waterman_PyTorch_vs_JAX.ipynb</a></p>



<p>今回は3パターン実装したので、それぞれについて順番に紹介します。</p>



<p>実行はGoogle Colab上で行いました。この際、使用したGPUやライブラリのバージョンは以下の通りです。</p>



<ul class="wp-block-list"><li>GPU: K80</li><li>CUDA: 11.2</li><li>PyTorch: 1.10.0</li><li>JAX: 0.2.21</li></ul>



<p>また、Smooth Smith Watermanは2つの配列の最大長を100, 120とした64個の配列のペアを入力に与えて測定しました。今回は少し測定誤差が入ることも考慮して10回平均で比較します。</p>



<h2 class="wp-block-heading">JAXのコードをそのままPyTorchにする</h2>



<p>JAXのコードで利用されているアルゴリズムはPyTorchでも十分速くなるようにみえました。このため、まずはそのまま適用してみました。PyTorchのコードとしては以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="ssw-torch-v0" data-lang="Python"><code>class SwTorch(nn.Module):
    def __init__(self, unroll=2, NINF=-1e30, device=&quot;cpu&quot;):
        super(SwTorch, self).__init__()
        self.unroll = unroll
        self.NINF = torch.tensor(NINF, device=device)
        self.device = device

    def _make_mask(self, score_matrix, lengths):
        a,b = score_matrix.shape
        real_a = lengths[0]
        real_b = lengths[1]
        mask = (torch.arange(a, device=self.device) &lt; real_a)[:,None] & (torch.arange(b, device=self.device) &lt; real_b)[None,:]
        return mask

    def _rotate(self, score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar = torch.flip(torch.arange(a, device=self.device), [0])[:, None]
        br = torch.arange(b, device=self.device)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = torch.full([n,m], self.NINF, dtype=score_matrix.dtype, device=self.device)
        rotated_score_matrix[i, j] = score_matrix
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _step(self, prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        h1_T = self._get_prev_gap_cell_score(
            gap_cell_condition,
            torch.nn.functional.pad(h1[:-1], [1,0], value=self.NINF),
            torch.nn.functional.pad(h1[1:], [0,1], value=self.NINF),
        )
      
        a = h2 + rotated_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_score_matrix
        h0 = torch.stack([a, g0, g1, s], -1)
        h0 = self._soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _rotate_in_reverse(self, rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(self, y, axis):
        y = torch.maximum(y,self.NINF)
        return torch.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(self, y, axis, mask):
        y = torch.maximum(y,self.NINF)
        if axis is None:
          return torch.max(y) + torch.log(torch.sum(mask * torch.exp(y - torch.max(y))))
        else:
          return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, axis, keepdims=True)[0]), axis=axis))

    def _soft_maximum(self, x, temp, axis=None):
        return temp*self._logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(self, x, temp, mask, axis=None):
        return temp*self._logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(self, cond, true, false): 
        return cond*true + (1-cond)*false

    def forward(self, score_matrix, lengths, gap=0, temp=1.0):
      mask = self._make_mask(score_matrix, lengths)
      masked_score_matrix = score_matrix + self.NINF * (~mask)
      rotated_score_matrix, reverse_idx = self._rotate(masked_score_matrix)

      a,b = score_matrix.shape
      n,m = rotated_score_matrix.shape
      gap_cell_condition = (torch.arange(n, device=self.device)+a%2)%2
      prev = (torch.full((m,), self.NINF, device=self.device), torch.full((m,), self.NINF, device=self.device))
      rotated_hij = [None for _ in range(n)]
      for i in range(n):
          prev, h = self._step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
          rotated_hij[i] = h
      rotated_hij = torch.stack(rotated_hij)
      hij = self._rotate_in_reverse(rotated_hij, reverse_idx)
      score = self._soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
      return score


class BatchSwTorch(nn.Module):
    def __init__(self, unroll=2, NINF=-1e30, device=&quot;cpu&quot;):
        super(BatchSwTorch, self).__init__()
        self.device = device
        self.sw = SwTorch(unroll=unroll, NINF=NINF, device=device)

    def forward(self, batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        ret = torch.empty((n_batches,), dtype=batch_score_matrix.dtype, device=self.device)
        for i in range(n_batches):
          ret[i] = self.sw(batch_score_matrix[i], batch_lengths[i], gap=gap, temp=temp) 
        return ret</code></pre></div>



<p>ちなみに最初はシンプルなコードと比較しようと思ったので、この時点ではまだ<code>torch.jit</code>は使っていません。このコードの結果は以下の通りです。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td>平均実行時間 (sec)</td></tr><tr><td>numpy</td><td>34.5</td></tr><tr><td>JAX jit版</td><td>0.0142</td></tr><tr><td>JAXのコードをそのままPyTorchにする</td><td>7.89</td></tr></tbody></table><figcaption> JAXのコードをそのままPyTorchにした場合の結果</figcaption></figure>



<p>見ての通り、JAXが圧倒的。PyTorchもnumpyに比べて速くなってはいるのでGPUを使っている効果が出ていると考えられますが、それ以上にJAXが速い。PyTorchと比較してJAXのほうが556倍も速いという結果でした。JAXのほうがバグっているのか？とも一瞬思ったのですが、ちゃんと正しい答えを出力しているし、nsysでプロファイル結果を取ってみた限りそれっぽい時間で1回の計算が終わっているので、測定ミスでもなさそうでした。</p>



<p>というわけで、圧倒的にPyTorchがこのままでは遅いので、高速化したバージョンを作成して評価したので、次で紹介します。</p>



<h2 class="wp-block-heading">Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する</h2>



<p>PyTorchを普段から使っていると気にならない部分ではありますが、CUDAの高速化のつもりで考えると、 Batchの軸を行列の一番内側にもってくるほうがCUDA的には速くなりそうな気がします。また、JAXはJITを使っているのでPyTorchもJIT使うほうがいいだろうということでJITを使いました。</p>



<p>これに伴ってコードは以下のように変更しました。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="ssw-torch-v1" data-lang="Python"><code>from typing import Tuple


def _make_batch_mask(batch_score_matrix, batch_lengths):
    a, b, batch_size = batch_score_matrix.shape
    real_a = batch_lengths[:, 0]
    real_b = batch_lengths[:, 1]
    mask_a = torch.arange(a, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) &lt; real_a[None, :]
    mask_b = torch.arange(b, device=batch_score_matrix.device)[:, None].repeat(1, batch_size) &lt; real_b[None, :]
    mask = mask_a[:, None] & mask_b[None, :]
    return mask


def _logsumexp(y: torch.Tensor, axis: int, NINF: torch.Tensor) -&gt; torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.logsumexp(y, dim=axis)


def _logsumexp_with_mask(y: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -&gt; torch.Tensor:
    y = torch.maximum(y, NINF)
    return torch.max(y, axis)[0] + torch.log(torch.sum(mask * torch.exp(y - torch.max(y, dim=axis, keepdim=True)[0]), dim=axis))


def _soft_maximum(x: torch.Tensor, temp: torch.Tensor, axis: int, NINF: torch.Tensor) -&gt; torch.Tensor:
    return temp*_logsumexp(x/temp, axis=axis, NINF=NINF)


def _soft_maximum_with_mask(x: torch.Tensor, temp: torch.Tensor, axis: int, mask: torch.Tensor, NINF: torch.Tensor) -&gt; torch.Tensor:
    return temp*_logsumexp_with_mask(x/temp, axis=axis, mask=mask, NINF=NINF)


def _rotate(batch_score_matrix: torch.Tensor, NINF: torch.Tensor, rotated_batch_score_matrix: torch.Tensor) -&gt; Tuple[torch.Tensor, torch.Tenso\
r, torch.Tensor]:
    a, b, batch_size = batch_score_matrix.shape
    n,m = (a+b-1),(a+b)//2
    ar = torch.flip(torch.arange(a, device=batch_score_matrix.device), [0])[:, None]
    br = torch.arange(b, device=batch_score_matrix.device)[None,:]
    i,j = (br-ar)+(a-1),(ar+br)//2      
    rotated_batch_score_matrix[:, :, :] = NINF
    rotated_batch_score_matrix[i, j, :] = batch_score_matrix                                                                                         
    return rotated_batch_score_matrix, i, j


def _rotate_in_reverse(rotated_dp_matrix, i, j):                                                                                                  
    return rotated_dp_matrix[i, j]


def _get_prev_gap_cell_score(cond, true, false):
    return cond*true + (1-cond)*false

@torch.jit.script
def _step(h2, h1, gap_cell_condition, rotated_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
    prev_gap_cell_true[1:, :] = h1[:-1, :]
    prev_gap_cell_false[:-1, :] = h1[1:, :]
    h1_T = _get_prev_gap_cell_score(
        gap_cell_condition,
        prev_gap_cell_true,
        prev_gap_cell_false,
    )
    a = h2 + rotated_score_matrix
    g0 = h1 + gap
    g1 = h1_T + gap
    s = rotated_score_matrix
    h0 = torch.stack([a, g0, g1, s], -1)
    h0 = _soft_maximum(h0, temp, axis=-1, NINF=NINF)
    return h1, h0, h0


@torch.jit.script
def _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, NINF, prev_gap_cell_true, prev_gap_cell_false):
    n, _, _ = rotated_batch_score_matrix.shape
    rotated_hij = torch.empty((n, init_h1.shape[0], init_h1.shape[1]), dtype=init_h1.dtype, device=init_h1.device)
    h1 = init_h1
    h0 = init_h0
    h1[:, :] = NINF
    h0[:, :] = NINF
    for i in range(n):
        h1, h0, h = _step(h1, h0, gap_cell_condition=gap_cell_condition[i], rotated_score_matrix=rotated_batch_score_matrix[i], gap=gap, temp=temp, NINF=NINF, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false,)
        rotated_hij[i] = h                                                                                                 
    return rotated_hij

@torch.jit.script
def batch_sw_func(batch_score_matrix, batch_lengths, gap, temp, NINF, rotated_batch_score_matrix, init_h1, init_h0, prev_gap_cell_true, prev_gap_cell_false):
    transposed_batch_score_matrix = batch_score_matrix.permute(1, 2, 0)
    mask = _make_batch_mask(transposed_batch_score_matrix, batch_lengths)
    masked_batch_score_matrix = transposed_batch_score_matrix + NINF * (~mask)
    rotated_batch_score_matrix, reverse_idx_i, reverse_idx_j = _rotate(masked_batch_score_matrix, NINF=NINF, rotated_batch_score_matrix=rotated_batch_score_matrix)
    a, b, batch_size = transposed_batch_score_matrix.shape
    n, m, _ = rotated_batch_score_matrix.shape
    gap_cell_condition = (torch.arange(n, device=rotated_batch_score_matrix.device)+a%2)%2                                           
    rotated_hij = _step_loop(init_h1, init_h0, gap_cell_condition, rotated_batch_score_matrix, gap, temp, prev_gap_cell_true=prev_gap_cell_true, prev_gap_cell_false=prev_gap_cell_false, NINF=NINF)                                                                                                   
    hij = _rotate_in_reverse(rotated_hij, reverse_idx_i, reverse_idx_j)
    score = _soft_maximum_with_mask(hij.reshape(a*b,batch_size), temp=temp, mask=mask.reshape(a*b, batch_size), axis=0, NINF=NINF)
    return score
</code></pre></div>



<p>次に紹介する高速化の関係で一時領域も引数として与えていますが気にしないでください。後ほど説明します。</p>



<p>このコードの測定結果も加えると以下の通りです。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td> 平均実行時間 (sec) </td></tr><tr><td> numpy </td><td> 34.5 </td></tr><tr><td> JAX jit版 </td><td> 0.0142 </td></tr><tr><td> JAXのコードをそのままPyTorchにする </td><td> 7.89 </td></tr><tr><td> Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する </td><td>0.0655</td></tr></tbody></table><figcaption> Batchの軸を行列の一番内側にもってくる + torch.jitで高速化した結果</figcaption></figure>



<p>正直、この時点でJAXと並ぶだろうとやる前は思っていたのですが、JAXのほうがまだPyTorchの4.6倍速いという結果でした。JAX速い・・・。でも、できることはまだある！ということでもう一工夫やります。</p>



<h2 class="wp-block-heading">CUDA Graphsを使う</h2>



<p>PyTorchのコードのプロファイル結果を見るとかなり実行時間の短いCUDA Kernelが大量に実行されているという状態でした。このため、CUDA Kernelの実行のオーバーヘッドがかなり入っているのでは？と考えて、これを削減するCUDA Graphsを使ってみます。</p>



<p>CUDA Graphsが何かわからない方はこちらの記事をご覧ください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="f9Rflq7N83"><a href="https://www.mattari-benkyo-note.com/2021/10/23/pytorch-cuda-graphs/">PyTorch 1.10の新機能「CUDA Graphs」のパフォーマンスを測定してみる</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;PyTorch 1.10の新機能「CUDA Graphs」のパフォーマンスを測定してみる&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2021/10/23/pytorch-cuda-graphs/embed/#?secret=f9Rflq7N83" data-secret="f9Rflq7N83" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>さて、CUDA Graphsで実行するようのコードとしては一つ前のコードの<code>torch.jit.trace</code>でコンパイルしたものを利用します。CUDA Graphsで実行できるようにするために、一時領域を一部入力として入れていました。</p>



<p>この測定結果は以下の通りです。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td> 平均実行時間 (sec) </td></tr><tr><td> numpy </td><td> 34.5 </td></tr><tr><td> JAX jit版 </td><td> 0.0142 </td></tr><tr><td> JAXのコードをそのままPyTorchにする </td><td> 7.89 </td></tr><tr><td> Batchの軸を行列の一番内側にもってくる + torch.jitで高速化する </td><td>0.0655</td></tr><tr><td> CUDA Graphsを使う </td><td>0.0324</td></tr></tbody></table><figcaption> Batchの軸を行列の一番内側にもってくる + torch.jitで高速化した結果</figcaption></figure>



<p>なんと、JAXのほうがPyTorchのコードよりも2.2倍速いという結果でした。CUDA Graphsでも勝てないなんて・・・。</p>



<h2 class="wp-block-heading">現状のPyTorchのコードの敗因</h2>



<p>自分なりにnsysを使ってプロファイルとって結果を見た印象では主に以下の二つの原因があると考えています。</p>



<ol class="wp-block-list"><li>PyTorchのJITがJAXに比べてあまりfuseしてくれない</li><li>そもそも入力サイズがGPUで実行するには小さすぎる</li></ol>



<p>1についてはPyTorchとJAXの二つのコードを見てみるとJAXのコードのほうがJITされて1つのCUDA Karnelの実行時間が長い印象でした。JAX、PyTorchどちらもどの関数がfuseされたかをどうやってみるのかわからないので憶測になりますが、おそらくJAXのほうがより多くの処理を1つのCUDA Karnelまとめてくれていて、結果としてCUDA Karnelの実行数が減り、CUDA Karnelの実行オーバーヘッド小さくなったためJAXのほうが速くなった、ということを考えています。</p>



<p>2に関しては、そもそもGPUで実行するには入力サイズが小さすぎる印象です。実際配列の長さや配列ペアの数を大きくしてもPyTorchのコードはあまり実行時間が増加しないことを確認しています。じゃあ、入力サイズを大きくすればいいのでは？とも思ったのですが、今回のSmooth Smith Watermanではあと2，3倍くらいにはしてもよさそうですが、どこまでできるかは問題依存なため、ひとまずそのままにしておきました。なんとなく実際に使うときもあと数倍くらいは大きくできそうだけど、100倍とかはあまりつかわなさそうだなと思っています。ただ、この辺りはいろいろ意見が分かれそうかなと思っています。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回は個人的に前々から気になっていた「JAXって速いの？」という問いに答えるための検証の一環で行いました。結果はまさかのPyTorchと比べてこんなに差がでるとは、という感じした。ただ、PyTorchのJITはまだ使い慣れていない感があるので、何か高速化のアイディアが浮かんだら再チャレンジしたいと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/17/ssw-jax-vs-torch/">JAXとPyTorch、どっちが速いのか検証してみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2021/11/17/ssw-jax-vs-torch/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">212</post-id>	</item>
		<item>
		<title>JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</title>
		<link>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/</link>
					<comments>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sun, 07 Nov 2021 23:07:44 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<category><![CDATA[JAX]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=151</guid>

					<description><![CDATA[<p>最近微分可能な Smith Waterman アルゴリズムというものとJAXのコードが公開されました。今回はこれらを参考に、JAXの勉強がてら何パターンかSmith Watermanアルゴリズムを実装して測定してみたので [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/">JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>最近微分可能な Smith Waterman アルゴリズムというものとJAXのコードが公開されました。今回はこれらを参考に、JAXの勉強がてら何パターンかSmith Watermanアルゴリズムを実装して測定してみたので、その結果のまとめの紹介となります。</p>



<p>論文は以下のものです。</p>



<p>[1] <strong>Petti, S., Bhattacharya, N., Rao, R., Dauparas, J., Thomas, N., Zhou, J., … Ovchinnikov, S. (2021). End-to-end learning of multiple sequence alignments with differentiable Smith-Waterman. BioRxiv, 2021.10.23.465204. https://doi.org/10.1101/2021.10.23.465204</strong></p>



<p>また、著者の実装はこちらに公開されています。</p>



<p><a href="https://github.com/spetti/SMURF">https://github.com/spetti/SMURF</a></p>



<p>今回は主に私がJAXの勉強をしたかったということもあり、いくつか実装を作ってパフォーマンスを測定して、「JAXって速いの？」という疑問にある程度答えられればと思い、記事を書いています。今回の実装はすべてこちらにありますので参考にしてみてください。</p>



<p><a href="https://github.com/shu65/blog-jax-notebook/blob/main/JAX_Smooth_Smith_Waterman.ipynb">https://github.com/shu65/blog-jax-notebook/blob/main/JAX_Smooth_Smith_Waterman.ipynb</a></p>



<p>また、計算時間測定はすべてGoogle Cloab上のCPUで行っています。</p>



<h2 class="wp-block-heading">論文概要</h2>



<p>この論文では教師なし学習によるコンタクト予測において、前処理で使われるSmith Watermanアルゴリズムを微分可能なものに置き換えて、Smith Watermanアルゴリズムの中で使われるパラメータ（置換スコア）も含めて学習する手法 SMURFを提案した論文です。論文自体にはコンタクト予測の精度なども書かれていますが、微分可能なSmith Watermanの紹介をメインにしたいため、今回は割愛します。</p>



<h2 class="wp-block-heading"> 微分可能な Smith Waterman アルゴリズム「Smooth Smith Waterman」とは？ </h2>



<p>Smith Watermanアルゴリズムを微分可能にするためには、微分可能ではない関数を微分可能なものに置き換えて、近似することで実現します。まずは大本のSmith Watermanアルゴリズムの説明をしたあと、微分可能なものに変更する方法を紹介していきます。</p>



<h3 class="wp-block-heading"> Smith Watermanアルゴリズムとは</h3>



<p>Smith Watermanアルゴリズム は2つのDNAやタンパク質の配列の類似度、特にローカルアライメントのスコアと呼ばれる類似度を計算するアルゴリズムです。ローカルアラインメントとは2配列間の類似度の高い部分的な文字列を発見するときに使われます。これは以下のように行列の要素を計算する動的計画法 (Dynamic Programming, DP) により計算します。</p>



<p>$$ H_{i0} = H_{0j} = 0 \\ H_{ij} = \max\begin{cases} H_{i-1,j-1} + s(a_i,b_j), \\ H_{i-k,j} + g, \\ H_{i,j-l} +g, \\<br>0 \\<br>\end{cases}    $$</p>



<p>\(  s(a_i,b_j)  \) は 配列Aのi番目の文字と配列Bのj番目の文字の置換スコアと呼ばれるもので、同じ文字、もしくは類似度の高い文字のペアはプラス、類似度の低い文字のペアはマイナスにするのが一般的です。また、 \(  g \) はギャップペナルティと呼ばれるもので、1文字飛ばしのペナルティを表します。 </p>



<h3 class="wp-block-heading">Smith Watermanアルゴリズムを微分可能にする</h3>



<p>先ほど説明したとおり、Smith Watermanアルゴリズムではmax関数があります。この部分が微分可能ではないため、SmithWatermanアルゴリズムは微分可能ではありません。このため、このmax関数を微分可能な何等かの関数で置き換える必要があります。この論文ではmax関数を「logsumexp」で置き換えることで微分可能にします。</p>



<p> logsumexpはmax関数を滑らかに近似するための関数として使われる関数で、微分可能な関数です。このためmax関数を logsumexp に置き換えればSmith Watermanアルゴリズムの計算全体が微分可能になります。論文中ではこの微分可能なSmtth Watermanアルゴリズムを「Smooth Smith Waterman」と呼んでいます。<br>なぜlogsumexpがmax関数の近似になるかを詳しく知りたい方は、こちらのブログ記事がわかりやすかったのでお勧めです。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-hatena-blog wp-block-embed-hatena-blog"><div class="wp-block-embed__wrapper">
<iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted" title="Smooth maximumを作って遊ぼう - Corollaryは必然に。" src="https://hatenablog-parts.com/embed?url=https%3A%2F%2Fcorollary2525.hatenablog.com%2Fentry%2F2020%2F07%2F25%2F221823#?secret=g4M2CLQP1b" data-secret="g4M2CLQP1b" scrolling="no" frameborder="0"></iframe>
</div></figure>



<h3 class="wp-block-heading">numpyによるシンプルな Smooth Smith Waterman </h3>



<p>後ほどJAXの実装を示しますが、高速化したあとのJAXのコードは初見では分かりづらいため、先にシンプルなnumpyの実装を示します。この実装は著者の実装にあわせつつ、numpyとのパフォーマンス実装をするために以下のようにしています。</p>



<ol class="wp-block-list"><li>配列Aと配列Bの全文字ペアの置換スコアの行列 <code>score_matrix</code>（置換スコアの行列のサイズは|A|×|B|）と2つの配列の長さ<code>lengths</code>、その他のパラメータを入力とする</li><li>この記事では勾配を計算できないnumpyとの比較のために、著者実装では<code>score_matrix</code>の勾配を返すのに対して、今回の記事では2配列の最大スコアを返す。</li></ol>



<p>Smith Watermanアルゴリズムをご存じの方は戸惑うかもしれませんが、Smooth Smith Watermanアルゴリズムでは<code>score_matrix</code>の勾配を出力として返す関数になっています。このため、あらかじめ配列Aと配列Bの全文字ペアの置換スコアの行列を用意して入力にします。<br>このため、配列Aと配列Bは入力に出てきませんし、PAMやBLOSUMなどの置換スコアもでてきません。</p>



<p>このSmooth Smith Watermanアルゴリズムをシンプルにnumpyで実装すると以下の通りになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-np" data-lang="Python"><code>def sw_np(NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = np.maximum(y,NINF)
        return y.max(axis) + np.log(np.sum(np.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = np.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=np.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap

                h = np.stack([m, g0, g1, s], -1)
                hij[i + 1, j + 1] = _soft_maximum(h, temp=temp, axis=-1)
        hij = hij[1:, 1:]
        score = _soft_maximum(hij, temp=temp)
        return score
    return _sw</code></pre></div>



<p>こちらの実装で通常のSmith Watermanアルゴリズムと違う点は以下の2点です</p>



<ol class="wp-block-list"><li>DPの行列の要素更新のところでmax関数をlogsumexpで実装した<code>_soft_maximum()</code>という関数に置き換えている。</li><li>DPの行列の各要素を入れるところで最大値を取るところで0以下にならないようにmax関数の入力の一つとして0を入れるところを、置換スコア(<code>s</code>)を入れている。</li></ol>



<p>1が微分可能とするための改良した箇所です。一方、2に関しては私が読み飛ばしてしまった可能性がありますが、特に論文中に説明が見当たらなかった変更点です。なんとなくSmooth Smith Watermanアルゴリズムを使って深層学習のモデル更新をするときにうまく勾配が置換スコアに流れるようにするためでは？と思っているのですが、未確認な状態です。何かご存じの方がいれば教えていただければと思っています。</p>



<p>この実装を実行したときの計算時間を<code>%time</code>で測定すると以下の通りです。</p>



<pre class="wp-block-preformatted">CPU times: user 735 ms, sys: 4 ms, total: 739 ms
Wall time: 743 ms
</pre>



<h2 class="wp-block-heading"> JAXを使ったSmooth Smith アルゴリズム</h2>



<p>ここからnumpyの部分をJAXに置き換えてSmooth Smith Watermanアルゴリズムを実装し、徐々に改良していくという順番で説明していきます。まずはJAXをご存じない方のためにJAXを簡単に説明します。</p>



<h3 class="wp-block-heading">JAXってなに？</h3>



<p>JAXはPythonやnumpyの関数を微分可能なものにし、XLAというコンパイラを使ってGPUやTPUで実行で実行できるようにしたライブラリです。<br>JAXでは勾配が計算できることと、jitをはじめとした様々な高速化する仕組みが用意されているため、最近論文で利用しているケースが増えてきた印象です。特に今回紹介した論文のような、従来では微分可能でなかった計算を微分可能なものに置き換え、深層学習のモデル学習の中で利用するという手法の実装にJAXが使われるのをよく目にします。今回紹介したものの他には BRAXがあります。</p>



<p><a href="https://github.com/google/brax">https://github.com/google/brax</a></p>



<h3 class="wp-block-heading">単純なJAX実装</h3>



<p>JAXはnumpyの関数と同じAPIの関数があるので、まずはそれをそのまま利用してみます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v0" data-lang="Python"><code>def sw_v0(NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = jnp.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=jnp.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap
                h = jnp.stack([m, g0, g1, s], -1)
                hij = hij.at[i + 1, j + 1].set(_soft_maximum(h, -1))
        hij = hij[1:, 1:]
        score = _soft_maximum(hij)
        return score
    return _sw</code></pre></div>



<p>これも動くには動くのですが、あまりにも遅いため、まったく使い物になりません。このためJAXを使う際はもう少し真面目に高速に動くアルゴリズムで実装する必要があります。</p>



<h3 class="wp-block-heading">Striped Smith-Watermanベースの実装</h3>



<p>論文でも紹介されているStriped Smith-Watermanベースで実装してみます。 Smith-Waterman アルゴリズムをSIMDなどで並列化する方法として、依存関係のないDP行列の斜めのセルを同時に埋めていくという方法がしばしば取られます。詳しくはこちらをご覧ください。</p>



<p><strong>Farrar, M. (2007). Striped Smith-Waterman speeds database searches six times over other SIMD implementations. Bioinformatics (Oxford, England), 23(2), 156–161. https://doi.org/10.1093/bioinformatics/btl582</strong></p>



<p>これをJAXで実装するにあたり、著者はDP行列を回転させ、依存関係のない斜めに並んだセルを横1列に並べて計算するようにしています。</p>



<figure class="wp-block-image size-large"><img fetchpriority="high" decoding="async" width="1024" height="440" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1024x440.png" alt="" class="wp-image-175" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1024x440.png 1024w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-300x129.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-768x330.png 768w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1536x660.png 1536w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-2048x880.png 2048w" sizes="(max-width: 1024px) 100vw, 1024px" /><figcaption>DP行列の回転 ([1] Fig. 7)</figcaption></figure>



<p>こうすることで内側のforループをJAXのベクトルの計算で実行できるようにしています。個人的にはここがこの論文の最大の貢献な気がしています。具体的にJAXで実装すると以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v1" data-lang="Python"><code>def sw_v1(unroll=2, NINF=-1e30):
        
    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx
    
    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _step(prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        h1_T = jax.lax.cond(
            gap_cell_condition,
            lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
            lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
            h1,
        )

        a = h2 + rotated_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_score_matrix

        h0 = jnp.stack([a, g0, g1, s], -1)
        h0 = _soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape
        
        gap_cell_condition = (jnp.arange(n)+a%2)%2
        prev = (jnp.full(m, NINF), jnp.full(m, NINF))
        rotated_hij = []
        for i in range(n):
            prev, h = _step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
            rotated_hij.append(h)
        rotated_hij = jnp.stack(rotated_hij)
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp=temp)
        return score
    return _sw</code></pre></div>



<p>この実装では置換行列<code>score_matrix</code>を回転させて、DP行列のセルを埋めていき、そのあとDP行列元の方向に戻すということをしています。<br>回転させたときの注意点として、DP行列の列番号が偶数か奇数かでギャップペナルティのスコアを加算するセルの相対座標が変わります。このため、<code>jax.lax.cond()</code>を利用して使うセルを分岐しています。</p>



<p>この実装をそのまま実行したときとjitを利用したときの計算時間は以下の通りです。</p>



<pre class="wp-block-preformatted">jax default first call
CPU times: user 17.7 s, sys: 177 ms, total: 17.8 s
Wall time: 17.8 s
jax default second call
CPU times: user 17.6 s, sys: 153 ms, total: 17.7 s
Wall time: 17.7 s


jax jit first call
CPU times: user 2min 20s, sys: 715 ms, total: 2min 21s
Wall time: 2min 20s
jax jit second call
CPU times: user 1.98 ms, sys: 0 ns, total: 1.98 ms
Wall time: 1.81 ms


</pre>



<p>jitなしでそのまま実行するのはnumpyよりもかなり遅い印象です。またjitを使う場合も最初の呼び出しはコンパイルが走ることもあり、jitなしに比べるとさらに遅くなっています。さすがに1回目とはいえ、ここまで時間がかかると使いづらいと思われます。このため、まだ工夫する必要があります。</p>



<h3 class="wp-block-heading">外側のforループをjax.lax.scan()に置き換える</h3>



<p>1つ前の実装で遅い原因がどこか？というとforループです。これを速くする方法としてJAXのforループと類似する処理を実行するための関数を利用します。今回はforループ部分を <code>jax.lax.scan()</code> に置き換えます。</p>



<p>実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v2" data-lang="Python"><code>def sw_v2(unroll=2, NINF=-1e30):

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = jax.lax.cond(
                scan_xs[&quot;gap_cell_condition&quot;],
                lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
                lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
                h1,
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp, axis=None)
        return score
    return _sw</code></pre></div>



<p>この実装でforループがなくなりました。実行した結果は以下の通りです。</p>



<pre class="wp-block-preformatted">jax default first call
CPU times: user 739 ms, sys: 18 ms, total: 757 ms
Wall time: 758 ms
jax default second call
CPU times: user 666 ms, sys: 1.98 ms, total: 668 ms
Wall time: 671 ms

jax jit first call
CPU times: user 1 s, sys: 5.01 ms, total: 1.01 s
Wall time: 1.01 s
jax jit second call
CPU times: user 339 µs, sys: 989 µs, total: 1.33 ms
Wall time: 1.14 ms
</pre>



<p>先ほどに比べるとjitなしでも速くなりましたが、jitありの1回目の実行もかなり速くなった印象です。これなら十分使えるのではないか？と思っています。</p>



<h3 class="wp-block-heading">jax.lax.condの置き換え</h3>



<p>著者の実装では<code> jax.lax.cond()</code>を使わずに加算と乗算だけで実装されています。試しに同様の実装にしたバージョンも示します。具体的な実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v3" data-lang="Python"><code>def sw_v3(unroll=2, NINF=-1e30):

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = _get_prev_gap_cell_score(
                scan_xs[&quot;gap_cell_condition&quot;],
                jnp.pad(h1[:-1], [1,0], constant_values=(NINF,NINF)),
                jnp.pad(h1[1:], [0,1], constant_values=(NINF,NINF)),
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp, axis=None)
        return score
    return _sw</code></pre></div>



<p>この時のパフォーマンスは以下の通りです。</p>



<pre class="wp-block-preformatted">jax defaujax default first call
CPU times: user 599 ms, sys: 1.99 ms, total: 601 ms
Wall time: 608 ms
jax default second call
CPU times: user 599 ms, sys: 3.02 ms, total: 602 ms
Wall time: 607 ms

jax jit first call
CPU times: user 940 ms, sys: 2.01 ms, total: 942 ms
Wall time: 947 ms
jax jit second call
CPU times: user 4.9 ms, sys: 0 ns, total: 4.9 ms
Wall time: 3.41 ms</pre>



<p>jitなしの時は速くなっている印象ですが、jitありのときは少し遅くなっています。ただ、何度か実行してみると逆転することもあるようなので、誤差の範囲かもしれません。また、JAX特有のパフォーマンス測定のお作法をし忘れている可能性もあります。もしご存じの方があればコメントいただければと思います。</p>



<h3 class="wp-block-heading">Batch実行用の実装</h3>



<p>著者のSmooth Smith Watermanは2つの配列のペアを1つだけ実行するのではなく、複数のペアをまとめて実行することを想定されて実装してあります。ここでも同様に複数のペアをまとめて計算するのもやってみようと思います。</p>



<h4 class="wp-block-heading">簡単な実装</h4>



<p>複数のペアをまとめて実装する際、ペア毎に配列の長さが違っても動作するようにします。このため、置換スコアのうち必要な部分だけmaskするようにします。</p>



<p>実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v4" data-lang="Python"><code>def sw_v4(unroll=2, NINF=-1e30):
    
    def _make_mask(score_matrix, lengths):
        a,b = score_matrix.shape
        real_a, real_b = lengths
        mask = (jnp.arange(a) &lt; real_a)[:,None] * (jnp.arange(b) &lt; real_b)[None,:]
        return mask

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = _get_prev_gap_cell_score(
                scan_xs[&quot;gap_cell_condition&quot;],
                jnp.pad(h1[:-1], [1,0], constant_values=(NINF,NINF)),
                jnp.pad(h1[1:], [0,1], constant_values=(NINF,NINF)),
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(y, axis, mask):
        y = jnp.maximum(y,NINF)
        return y.max(axis) + jnp.log(jnp.sum(mask * jnp.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        return temp*_logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        masked_score_matrix = score_matrix + NINF * (1 - mask)
        rotated_score_matrix, reverse_idx = _rotate(masked_score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
        return score
    return _sw</code></pre></div>



<p>この実装をペアの数分、forループで計算していくようにします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="batch-sw-v0" data-lang="Python"><code>def batch_sw_v0(NINF=-1e30):
    def _batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        sw_func = jax.jit(sw_v4())
        ret = [sw_func(batch_score_matrix[i], batch_lengths[i], gap, temp) 
               for i in range(n_batches)]
        return jnp.array(ret)
    return _batch_sw</code></pre></div>



<p>これを実行すると計算時間は以下の通りでした。</p>



<pre class="wp-block-preformatted">batch jax default first call
CPU times: user 1.31 s, sys: 13 ms, total: 1.33 s
Wall time: 1.32 s
batch jax default second call
CPU times: user 1.3 s, sys: 5.02 ms, total: 1.3 s
Wall time: 1.3 s

batch jax default first call
CPU times: user 10min 43s, sys: 2.99 s, total: 10min 46s
Wall time: 10min 45s
batch jax default second call
CPU times: user 279 ms, sys: 2 ms, total: 281 ms
Wall time: 281 ms</pre>



<p>forループでそのまま実装すると、jitありのときはやはり1度目の実行に非常に時間がかかるようです。このため、この部分を速くします。</p>



<h4 class="wp-block-heading">forループをjax.vmap()で置き換える</h4>



<p>ここではforループを<code> jax.vmap()</code> で置き換えます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="batch-sw-v1" data-lang="Python"><code>def batch_sw_v1(unroll=2, NINF=-1e30):
    sw_func = sw_v4(unroll=unroll, NINF=NINF)
    batch_sw_func = jax.vmap(sw_func, (0, 0, None, None))
    return batch_sw_func</code></pre></div>



<p>この時の計算時間は以下の通りです。</p>



<pre class="wp-block-preformatted">batch jax default first call
CPU times: user 1.04 s, sys: 11 ms, total: 1.05 s
Wall time: 1.03 s
batch jax default second call
CPU times: user 1.04 s, sys: 7.97 ms, total: 1.04 s
Wall time: 1.01 s

batch jax default first call
CPU times: user 1.51 s, sys: 10 ms, total: 1.52 s
Wall time: 1.5 s
batch jax default second call
CPU times: user 120 ms, sys: 9 µs, total: 120 ms
Wall time: 97 ms</pre>



<p>先ほどと比べるとかなり高速化できました。ちなみにこれがほぼ著者の実装と同じものになります。</p>



<h2 class="wp-block-heading">結果まとめ</h2>



<p>ここまでの計算時間の結果をまとめると以下の通りです。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td>jitなし1回目</td><td>jitなし2回目</td><td>jitあり1回目</td><td>jitあり2回目</td></tr><tr><td>numpy</td><td>739 ms</td><td>&#8211;</td><td>&#8211;</td><td>&#8211;</td></tr><tr><td>Striped Smith-Watermanベースの実装</td><td>17.8 s</td><td>17.7 s</td><td>2min 21s</td><td>1.98 ms</td></tr><tr><td>外側のforループをjax.lax.scan()に置き換える</td><td>757 ms</td><td>668 ms</td><td>1.01 s</td><td>1.33 ms</td></tr><tr><td>jax.lax.condの置き換え</td><td>601 ms</td><td>602 ms</td><td>942 ms</td><td>4.9 ms</td></tr></tbody></table><figcaption>Smooth Smith Watermanの計算時間まとめ</figcaption></figure>



<figure class="wp-block-table"><table><tbody><tr><td></td><td> jitなし1回目 </td><td> jitなし2回目 </td><td> jitあり1回目 </td><td>jitあり2回目 </td></tr><tr><td>簡単な実装</td><td>1.33 s</td><td>1.3 s</td><td>10min 46s</td><td>281 ms</td></tr><tr><td>forループをjax.vmap()で置き換える</td><td>1.05 s</td><td>1.04 s</td><td>1.52 s</td><td>120 ms</td></tr></tbody></table><figcaption> Batch Smooth Smith Watermanの計算時間まとめ </figcaption></figure>



<p>各実装を比較するとforループのありなしでかなり実行時間やコンパイル時間が変化していることがわかります。このためnumpyの実装をそのままJAXにすればそれだけで速くなることはまずなさそうです。また、何も考えずに実装してjitを使うと、コンパイル時間が長すぎて使い物にならないというケースが多そうな気がしています。このため、JAXを使いこなすにはどのような計算は遅いかを理解して使うことが重要そうな印象です。</p>



<h2 class="wp-block-heading">おわりに</h2>



<p>今回、初のJAX使用だったため、パフォーマンス測定や高速化にはもっとやり方があるかもしれないと思っています。もしお気づきの点がありましたら気兼ねなくコメントいただければと思っています。</p>



<p>次はできればJAXとPyTorchのjitとどちらが速いのか試せればと思っています。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/">JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">151</post-id>	</item>
	</channel>
</rss>
