<?xml version="1.0" encoding="UTF-8"?><rss version="2.0"
	xmlns:content="http://purl.org/rss/1.0/modules/content/"
	xmlns:wfw="http://wellformedweb.org/CommentAPI/"
	xmlns:dc="http://purl.org/dc/elements/1.1/"
	xmlns:atom="http://www.w3.org/2005/Atom"
	xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
	xmlns:slash="http://purl.org/rss/1.0/modules/slash/"
	>

<channel>
	<title>Bioinformatics - まったり勉強ノート</title>
	<atom:link href="https://www.mattari-benkyo-note.com/tag/bioinformatics/feed/" rel="self" type="application/rss+xml" />
	<link>https://www.mattari-benkyo-note.com</link>
	<description>shuの日々の勉強まとめ</description>
	<lastBuildDate>Mon, 19 Dec 2022 22:27:06 +0000</lastBuildDate>
	<language>ja</language>
	<sy:updatePeriod>
	hourly	</sy:updatePeriod>
	<sy:updateFrequency>
	1	</sy:updateFrequency>
	<generator>https://wordpress.org/?v=6.8.3</generator>
<site xmlns="com-wordpress:feed-additions:1">189243286</site>	<item>
		<title>Kaggleの「Open Problems &#8211; Multimodal Single-Cell Integration」の振り返り</title>
		<link>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sat, 19 Nov 2022 00:05:55 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<category><![CDATA[kaggle]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1447</guid>

					<description><![CDATA[<p>今回は2022/11/15 (日本時間の2022/11/16の朝)まで行われていた「Open Problems &#8211; Multimodal Single-Cell Integration」に参加した際、どうして [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/">Kaggleの「Open Problems – Multimodal Single-Cell Integration」の振り返り</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>今回は2022/11/15 (日本時間の2022/11/16の朝)まで行われていた「<a href="https://www.kaggle.com/competitions/open-problems-multimodal/overview" target="_blank" rel="noopener" title="Open Problems - Multimodal Single-Cell Integration">Open Problems &#8211; Multimodal Single-Cell Integration</a>」に参加した際、どうして参加したのかや、参加中にやってよかったこと、課題などを忘れないようにまとめたので、せっかくなので記事にしました。</p>



<p>ちなみに、自分の手法に関してはこちらをご覧ください。</p>



<p><a href="https://www.kaggle.com/competitions/open-problems-multimodal/discussion/366961">https://www.kaggle.com/competitions/open-problems-multimodal/discussion/366961</a></p>



<h1 class="wp-block-heading">Open Problems &#8211; Multimodal Single-Cell Integration とは？</h1>



<p>コンペのサイトのoverviewに書かれていることをざっくり日本語に訳すと「骨髄幹細胞が血液細胞になるにつれて単一の細胞のDNAとRNA、タンパク質がどのように変化するかを予測するコンペ」ということになります。もう少し専門家の人がわかりやすいように説明すると以下の二つのsingle-cellのデータを使ったモデルを構築するコンペです。</p>



<ol class="wp-block-list">
<li>Chromatin accessibility(peak countをTF-IDFで変換したデータ)からRNAの発現量 (library-sizeでノーマライズされたcountデータ)の予測</li>



<li>RNAの発現量 (library-sizeでノーマライズされたcountデータ)からsurface protein levels (dsb でノーマライズされたデータ)の予測</li>
</ol>



<h1 class="wp-block-heading">そもそもなぜ参加したのか？</h1>



<p>参加してがんばろうと思ったきっかけをせっかくなので書いておくと</p>



<ol class="wp-block-list">
<li>Bioチームの同僚がこういうコンペがあるよとslackで教えてくれた (参加したきっかけ)</li>



<li>育休に入る前にやっておくべきことは？ということを会社の先輩パパさんに相談したら、「solo gold medalは取っておけ」と言われた (がんばったきっかけ)</li>
</ol>



<p>ということがあります。特に2つ目は「確かに！」と思いました。なので、今回が初ソロ参加にして、solo gold medalを取る最後のチャンスということで頑張りました。アドバイスをくれた先輩パパさんには感謝しかありません。</p>



<h1 class="wp-block-heading">やってよかったこと</h1>



<p>さて、ここからやってよかったことについて忘れないように書いておきます。</p>



<h2 class="wp-block-heading">まずはとにかく簡単な方法でいいのでsubmitする</h2>



<p>社内のkaggle強い人に前に言われた気がするので、まず意識したことがこれです。やってみて思ったのですが以下のような効果があることを実感しました。</p>



<ol class="wp-block-list">
<li>submitして順位やスコアがでるようになるとモチベーションが上がる。</li>



<li>何かベースラインがあると手法開発がしやすい</li>
</ol>



<p>特に1の効果がすごかった気がします。ちなみに私の場合、KaggleのCodeで公開されてたシンプルな手法をそのままコピペしてsubmitして一番最初のスコアを出しました。シンプルな手法だったため、最初は200位にも入れなかったと記憶していますが、それでもモチベーションはそれまでと比べてすごく上がりました。</p>



<h2 class="wp-block-heading">毎日決めた数の改良を試す</h2>



<p>これはコンペに限らず重要なことだと思いますが、とにかくコンスタントに改良を続けていき、最低限local CV scoreを出すということを意識してました。</p>



<p>今回のコンペでは3，4個の改良を毎日試すことを目標にしてやってました。私の場合、gitの1 branchが1改良になっていて、9/1からがんばり始めて11/15までに最終的には351個branchができていました。なので、通算4から5個くらいの改良を毎日試していたことになります。</p>



<p>このとき改良としてうまくいきそうなものはもちろんですが、思いついたタイミングではあまり筋が良くなさそうだけど他にやることがないというときは、筋が良くないアイディアもダメ元で試すようにしてました。</p>



<p>結果として、やってよかったと感じた理由としては、仮説を立てて試すとうまくいかなかったときの問題への理解度がすごい上がるため、とにかくいろいろ試すことで、そこからいいアイディアをひらめくということが多くあったからです。</p>



<p>今回のコンペで何度か大きくスコアを上げたタイミングがありましたが、総じて何かの失敗から気が付いたアイディアをもとにしたことが多かった印象です。ちなみにどれくらい失敗し続けたのかの体感ですが、1週間スコアが上がらないということがよくあったので、うまくいった改良というのはこれだけ試しても数えるくらいしかなかった印象です。</p>



<h2 class="wp-block-heading">ensembleはすぐに試さない</h2>



<p>これも社内のkaggle強い人に前に言われたことなのですが、ensembleをすぐに試さないというのはやってよかったなと思います。<br>今回はensembleを試さなくてよかったなと感じた理由としては、ensembleはやれば簡単にスコアが上がるのでいいように感じますが、ensembleは始めるとすごいいろいろ試せることがある一方、おそらくそこまで高いスコア上昇をしていなかったような印象があります。</p>



<p>このためensembleを頑張ることに時間使うよりも1つのモデルのスコアが上がるように頑張るという作戦でいたのですが、結果としてそれが良かった印象です。</p>



<h2 class="wp-block-heading">最後まで諦めない</h2>



<p>最後はこれです。特に今回はsolo gold medalが欲しかったので、一人で出ていたのですが、途中全然順位が上がらず何度ももうやめようかなぁと思うタイミングがありました。結果としてそのときダメもとで試した改良やそこから思いついたアイディアがうまくいってスコアを伸ばすということが何度もあったので諦めなくてよかったです。</p>



<h1 class="wp-block-heading">課題</h1>



<p>次にコンペの参加中、もしくは終わったあと振り返って感じた課題的な部分も列挙しておきます。</p>



<h2 class="wp-block-heading">モチベーションの維持が大変（特に一人のとき）</h2>



<p>「最後まで諦めない」のところに似たようなことを書きましたが、とにかくモチベーション維持が大変でした。今まではチームででていたのと、すごい応援してくれる上司がいてくれたりとこの部分はそこまで問題にならなかった印象でした。</p>



<p>今後一人ででるならこの部分はまだまだどうにかしないと最後まで戦うのは難しい気がしています。</p>



<h2 class="wp-block-heading">public scoreがどうしても気になってしまう</h2>



<p>今回のコンペはデータセットの説明を読んだ段階でpublic scoreとprivate scoreのギャップが激しそうだなぁということを思ってたので、それほどpublic scoreを気にしないほうがいいかも？ということを最初思ってました。ただ、それでも最後はpublic scoreを気にしてモデルの改良をしてしまっていました。</p>



<p>コンペのサイトのデータセットの説明のところに詳しくかかれていますが、今回public scoreは4人中一人の最終日一つ前までのデータを、private scoreは4人全員の最後の日のデータを予測するというものでした。この説明を見ると予測する対象が全然違うことがわかります。実際、public scoreの上位陣が軒並みprivate scoreでは順位を落としていました。なので、後から振り返るとやっぱり最初に思った通りpublic scoreをそれほど気にせずモデルの改良をしていて正解だったと思います。ただ、そうはいっても、public scoreの順位が気になってしまって、結局最後はpublic scoreを気にして最終submittionを決めていました。ただ、ふたを開けたらlocal CVがベストなものがprivate scoreも一番よかったので、public scoreを気にしすぎたなぁと思っていました。</p>



<p>ただ、この部分はふたを開けてみないとわからないところなので、難しいポイントな気がしています。</p>



<h2 class="wp-block-heading">論文を読んで最新研究をコンペの期間中に試すことが心理的に難しい</h2>



<p>期間の前半ならまだましですが、どうしても後半になればなるほど、心理的に追い込まれていきます。なので、試すアイディアがないと思いつつも、自分の全然読んだことない論文の手法を試すということが難しかったです。この結果、特に後半はアイディアが前に試したもののちょっと変更したものばかりになり、結果としてそれほど精度が向上しないということが起きました。ただ、この部分一人でやっていると意識しても難しい部分な気がしています。</p>



<p>このため、日ごろから論文を読んでいろいろ手法を勉強し、できれば実装を動かして感触を確かめておくことが重要かなと思いました。実際、今回のコンペではReactomeのpathwayデータを使ったのですが、これは論文の追試のついでにいろいろ試したことがあったからこそできたことだったと思っています。</p>



<p>ちなみにその時の記事はこちらにあります。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="2XpReJuJFD"><a href="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/">ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/embed/#?secret=1JxmvC3IrQ#?secret=2XpReJuJFD" data-secret="2XpReJuJFD" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<h2 class="wp-block-heading">ensembleの準備をせずにいると最後の土壇場で困る</h2>



<p>ensembleをすぐに試さないことを意識していた結果、実は今回期間最後のほうでensembleをしようとしたときにいろいろ準備できてないことに気が付いてすごい慌てました。具体的には以下の二つが締め切り1週間前の段階でできていませんでした。</p>



<ol class="wp-block-list">
<li>予測結果をどのように集計して最終的なファイルを作るか？</li>



<li>ensembleした結果をどのように評価するか？</li>
</ol>



<p>結局、1はギリギリできたのですが、2の評価方法はとても間に合いそうになかったのでpublic scoreを見てうまくいっているかいっていないか？を判断していました。ただ、これは評価としては微妙なので、次からはそんなに頑張らないにしても中盤くらいにはensembleの準備はしておこうと思います。</p>



<h1 class="wp-block-heading">最後に</h1>



<p>折角なので個人的な振り返りをblog記事にしました。仕事ではない状態でのkaggle参加は初めてだったのですが、思ったよりも疲れたのと、無事子供がコンペ締め切りの前日に生まれて、これから子育てもあるのででしばらくはコンペにでない気がします。ただ、次出たときに今回のコンペで何を思ったのか忘れないようにしておければと思っていたので、記事にできてよかったです。</p>



<p>これが他の人の参考になれば幸いです。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/">Kaggleの「Open Problems – Multimodal Single-Cell Integration」の振り返り</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1447</post-id>	</item>
		<item>
		<title>UniProtのWeb APIを使ってほしい遺伝子のタンパク質配列を取ってくる</title>
		<link>https://www.mattari-benkyo-note.com/2022/05/07/uniprot-web-api-python/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/05/07/uniprot-web-api-python/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sat, 07 May 2022 06:14:02 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1430</guid>

					<description><![CDATA[<p>はじめに UniProtで検索した配列がほしくなることが度々あるのですが、数が多いとブラウザで検索して配列をコピペするというのが面倒になってきます。このため、UniProtのWeb APIを使ってまとめて取ってくる方法を [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/07/uniprot-web-api-python/">UniProtのWeb APIを使ってほしい遺伝子のタンパク質配列を取ってくる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<h2 class="wp-block-heading">はじめに</h2>



<p>UniProtで検索した配列がほしくなることが度々あるのですが、数が多いとブラウザで検索して配列をコピペするというのが面倒になってきます。このため、UniProtのWeb APIを使ってまとめて取ってくる方法を調べたので、今回はそのまとめの記事になります。</p>



<p>今回はUniProtが提供しているWeb APIの中で、クエリを使ってほしいタンパク質の情報を持ってきてPythonで簡単に使えるようにpandasのDataFrameにするところまでを紹介します。</p>



<p>今回のコードはこちらにあります。</p>



<p><a href="https://github.com/shu65/uniprot_web_api_python_example/blob/main/UniProt_Result_to_Pandas_DataFrame.ipynb">https://github.com/shu65/uniprot_web_api_python_example/blob/main/UniProt_Result_to_Pandas_DataFrame.ipynb</a></p>



<h2 class="wp-block-heading">UniProt Web APIについて</h2>



<p>UniProtはタンパク質配列や機能に関するデータを集めた有名なデータベースです。このUniProtではブラウザで検索するのはもちろん、Web APIも提供されています。詳細はここに書かれています。</p>



<p><a href="https://www.uniprot.org/help/api">https://www.uniprot.org/help/api</a></p>



<p>クエリを使った方法はここにまとめられています。詳しくは次で紹介します。</p>



<p><a href="https://www.uniprot.org/help/api_queries">https://www.uniprot.org/help/api_queries</a></p>



<h2 class="wp-block-heading">UniProt Web APIの簡単な使い方</h2>



<p>説明がわかりやすくなるようにまず具体的な例を示します。ここではヒトのTP53をUniProtKB/Swiss-Protから検索してタブ区切りのフォーマットでidとgeneの名前、アミノ酸配列を受け取ってみます。これを実行するURLとしては以下のようになります。</p>



<pre class="wp-block-preformatted">https://www.uniprot.org/uniprot/?query=reviewed:yes+AND+organism:%22Homo%20sapiens%22+AND+gene_exact:TP53&amp;format=tab&amp;columns=id,genes(PREFERRED),sequence</pre>



<p>順番に与えているパラメータを見ていきます。</p>



<h3 class="wp-block-heading">query</h3>



<p>検索するqueryを文字列にして与えます。使えるフィールドは以下に一覧が載っています。</p>



<p><a href="https://www.uniprot.org/help/query-fields" target="_blank" rel="noreferrer noopener">https://www.uniprot.org/help/query-fields</a></p>



<p>普通は複数のフィールドを使ってand検索かor検索を使いたくなるかと思います。and検索、or検索をする場合は以下のようにします。</p>



<ul class="wp-block-list"><li>and検索例:  human かつ antigenのand検索<ol><li>https://www.uniprot.org/uniprot/?query=human%20antigen</li><li>https://www.uniprot.org/uniprot/?query=human%20AND%20antigen</li></ol></li><li>or検索例: human または mouseのor検索<ol><li>https://www.uniprot.org/uniprot/?query=human%20OR%20mouse</li></ol></li></ul>



<p>それ以外のandやor以外についてもいろいろ使えます。詳細はこちらに書かれています。</p>



<p><a href="https://www.uniprot.org/help/text-search">https://www.uniprot.org/help/text-search</a></p>



<h3 class="wp-block-heading">format</h3>



<p>結果のフォーマットを指定するパラメータです。今回示したタブ区切りの場合は<code>tab</code>を指定します。それ以外のフォーマットは以下のページの[Parameter]→[format]に書かれています。</p>



<p><a href="https://www.uniprot.org/help/api_queries">https://www.uniprot.org/help/api_queries</a></p>



<h3 class="wp-block-heading">columns</h3>



<p>デフォルトではUniProtで取得できる一部のデータしか取得できませんが、columnsを指定するとほしいデータを取得できます。上の例ではidとgeneの名前、アミノ酸配列を指定するために以下のようにしていました。</p>



<p><code>columns=id,genes(PREFERRED),sequence</code></p>



<p>カンマでほしいカラム名を区切ることで複数のカラムの情報を得ることができます。他に使えるカラムに関しては以下のページをご覧ください。</p>



<p><a href="https://www.uniprot.org/help/uniprotkb_column_names">https://www.uniprot.org/help/uniprotkb_column_names</a></p>



<p>このページの[Column names as displayed in URL]側を指定するようにします。</p>



<h3 class="wp-block-heading">それ以外のパラメータ</h3>



<p>それ以外に圧縮をするかどうかなどがパラメータとして指定できます。詳しくはこのページの[Parameter]をご覧ください</p>



<p><a href="https://www.uniprot.org/help/api_queries">https://www.uniprot.org/help/api_queries</a></p>



<h2 class="wp-block-heading">UniProtのWeb APIを使って検索した結果をPandasのデータフレームにする</h2>



<p>ここまでわかっていれば簡単にできてしまいますが、UniProtのWeb APIを使って上で示した結果をPandasの<code>DataFrame</code>にするところまでのコードを示します。</p>



<p>requestsを使ってUniProtから検索結果をタブ形式で取得します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="uniprot_api" data-lang="Python"><code>import requests

url=&quot;https://www.uniprot.org/uniprot/&quot;

params={
    &quot;query&quot;:&quot;reviewed:yes AND organism:\&quot;Homo sapiens\&quot; AND gene_exact:TP53&quot;,
    &quot;format&quot;:&quot;tab&quot;,
    &quot;columns&quot;:&quot;id,genes(PREFERRED),sequence&quot;,
}

response = requests.get(url=url, params=params)
response.raise_for_status()</code></pre></div>



<p>これを以下のようにpandasで読み込むだけで、pandasの<code>DataFrame</code>形式の結果を取得できます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="uniprot_df" data-lang="Python"><code>import io

import pandas as pd


df = pd.read_csv(io.StringIO(response.text), sep=&quot;\t&quot;)</code></pre></div>



<p>できたpandasの<code>DataFrame</code>をprintした結果は以下の通りです。</p>



<pre class="wp-block-preformatted">    Entry Gene names  (primary )  \
0  P04637                   TP53   

                                            Sequence  
0  MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLS...  </pre>



<p>Python以外でweb apiの結果を取得するまでの例はこちらにも書かれています。</p>



<p><a href="https://www.uniprot.org/help/api_idmapping">https://www.uniprot.org/help/api_idmapping</a></p>



<p>他の言語でやりたい方は参考にしてみてください。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回はUniProtのWeb APIを使って検索結果をpandasのDataFrameにするところまでの例を紹介しました。当初はほしいアミノ酸配列を取得するのにfastaファイルをがんばってパースしようかな？とおも思ったのですが、調べてみると簡単にWeb APIが使えそうなことがわかったので、試してみました。</p>



<p>機会があれば他の有名データベースのWeb APIもしらべてみようかなと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/07/uniprot-web-api-python/">UniProtのWeb APIを使ってほしい遺伝子のタンパク質配列を取ってくる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/05/07/uniprot-web-api-python/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1430</post-id>	</item>
		<item>
		<title>ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する</title>
		<link>https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Fri, 06 May 2022 07:32:35 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1418</guid>

					<description><![CDATA[<p>はじめに だいぶ前ですが以下の論文を読み、ReactomeからどのようにしてPathwayデータを取得するのか？というのが気になっていました。 Elmarakeby, H.A., Hwang, J., Arafeh, R [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/">ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<h2 class="wp-block-heading">はじめに</h2>



<p>だいぶ前ですが以下の論文を読み、ReactomeからどのようにしてPathwayデータを取得するのか？というのが気になっていました。</p>



<p>Elmarakeby, H.A., Hwang, J., Arafeh, R. et al. Biologically informed deep neural network for prostate cancer discovery. Nature 598, 348–352 (2021). https://doi.org/10.1038/s41586-021-03922-4</p>



<p>こちらの論文自体はReactomeというPathwayデータベースのデータを組み合わせてP-NETという深層学習モデルを構築して前立腺がんの患者のデータに対して適用した研究になります。</p>



<p>この深層学習のモデルを構築する際、ReactomeからPathwayに関するデータを取得して利用しているのですが、ReactomeからどうPathwayのデータを取得して、それをどう加工すればよさそうか？が論文を読んだだけでは分からず、コードを読んで調べたので本日の記事はそのまとめになります。</p>



<p>今回の記事で利用したコードはこちらに置いてあります。</p>



<p><a href="https://github.com/shu65/reactome-example/blob/main/Reactome_gene_pathway_hierarchy_relationship.ipynb">https://github.com/shu65/reactome-example/blob/main/Reactome_gene_pathway_hierarchy_relationship.ipynb</a></p>



<h2 class="wp-block-heading">P-Netで利用されているReactomeのデータ</h2>



<p>P-NetではPathwayの階層情報に基づいて深層学習のモデルを構築しています。この際、利用されているデータは以下のものになります。</p>



<ol class="wp-block-list"><li>Pathway毎に関連するGeneの集合</li><li>Pathwayの階層構造</li><li>Pathwayの名前と種</li></ol>



<p>これらをReactomeから取得する方法を紹介していきます。わかりやすいようにReactomeのリンクやページのスクリーンショットも合わせのせています。これらは2022/05/06時点のものになります。アクセスする時期によってはこれらが変わっている可能性もあるので注意してください。</p>



<h2 class="wp-block-heading">Reactomeから大本のデータを取得する</h2>



<p>Reactomeは各種データを様々なフォーマットでダウンロードできるようになっています。それらはここにまとまっています。</p>



<p><a href="https://reactome.org/download-data">https://reactome.org/download-data</a></p>



<p>今回利用するデータはそれぞれ以下のファイルに書かれた情報から抽出することができます。</p>



<h3 class="wp-block-heading">Pathway毎に関連するGeneの集合</h3>



<p>Pathway毎に関連するGeneの集合のデータは[Specialized data formats]→[Reactome Pathways Gene Set.]から落とすことができます。</p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img fetchpriority="high" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/gene_set.png" alt="" class="wp-image-1420" width="425" height="287" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/gene_set.png 789w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/gene_set-300x203.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/gene_set-768x519.png 768w" sizes="(max-width: 425px) 100vw, 425px" /><figcaption>Gene Setのリンク</figcaption></figure></div>



<p>こちらはGMTフォーマットで書かれたファイルになります。具体的には行ごとに1つのPathwayについて書かれており、カラムはPathwayの名前、Reactome Stable identifiers (ST_ID)が続き、それ以降は登場するGeneが並んでいます。</p>



<h3 class="wp-block-heading">Pathwayの階層構造</h3>



<p>Pathwayの階層構造のデータは[Pathways] → [Pathways hierarchy relationship]から落とすことができます。</p>



<p> </p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathways_hierarchy.png" alt="" class="wp-image-1421" width="392" height="195" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathways_hierarchy.png 829w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathways_hierarchy-300x149.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathways_hierarchy-768x382.png 768w" sizes="(max-width: 392px) 100vw, 392px" /><figcaption>Pathwayの階層構造</figcaption></figure></div>



<p>[Read more] のリンクに詳しいフォーマットが書かれていますが、タブ区切りで最初のカラムが親のPathwayのST_ID、2つ目が子のST_IDになります。</p>



<h3 class="wp-block-heading">Pathwayの名前と種</h3>



<p>Pathwayの名前と種に関するデータはPathwayの階層構造のデータは[Pathways] → [Complete List of Pathways]から落とすことができます。</p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathway_name.png" alt="" class="wp-image-1422" width="430" height="214" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathway_name.png 829w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathway_name-300x149.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/pathway_name-768x382.png 768w" sizes="auto, (max-width: 430px) 100vw, 430px" /><figcaption>Pathwayの名前と種</figcaption></figure></div>



<p>[Read more] のリンクに詳しいフォーマットが書かれていますが、タブ区切りで最初のカラムがPathwayのReactome Stable identifiers (ST_ID)、2つ目がPathwayの名前、3つ目が種になります。</p>



<h2 class="wp-block-heading">Reactomeから落としたデータをPythonで読み込み</h2>



<p>ここまででほしいデータがどこから落とすことができるか説明しました。ここからは実際に使う際に利用しやすいようにPythonで読み込むパーサーを書いたのでその紹介をします。</p>



<h3 class="wp-block-heading">Pathway毎に関連するGeneの集合のデータのパース</h3>



<p>Reactomeから落としてきたGMTファイルをパースするスクリプトは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="gmt_parser" data-lang="Python"><code>import pandas as pd


def read_reactome_gmt(file_path):
  data_dict_list = []
  with open(file_path) as f:
      for i, line in enumerate(f):
          values = line.strip().split(&quot;\t&quot;)
          st_id = values[1]
          genes = values[3:]
          for gene in genes:
              data_dict_list.append({&#39;st_id&#39;: st_id, &#39;gene&#39;: gene})
  df = pd.DataFrame(data_dict_list)
  return df</code></pre></div>



<p>以下のようにようにファイルのパスを渡すとpandasのデータフレームでファイルから読み込んだ結果を返すようにしてあります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="use_gmt_parser" data-lang="Python"><code>pathway_gene_df = read_reactome_gmt(&quot;ReactomePathways.gmt&quot;)
print(pathway_gene_df)</code></pre></div>



<p>出力は以下の通りです。</p>



<pre class="wp-block-preformatted">               st_id   gene
0       R-HSA-164843  HMGA1
1       R-HSA-164843   LIG4
2       R-HSA-164843  PSIP1
3       R-HSA-164843  XRCC4
4       R-HSA-164843  XRCC5
...              ...    ...
121252  R-HSA-192905     NP
121253  R-HSA-192905     NS
121254  R-HSA-192905     PA
121255  R-HSA-192905    PB1
121256  R-HSA-192905    PB2

[121257 rows x 2 columns]</pre>



<h3 class="wp-block-heading">Pathwayの階層構造の読み込み</h3>



<p>Reactomeから落としたPathwayの階層構造のデータを読み込むパーサーは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="pathway_hierarchy_relationship_parser" data-lang="Python"><code>import pandas as pd


def read_reactome_pathway_hierarchy_relationship(file_path):
  df = pd.read_csv(file_path, sep=&#39;\t&#39;)
  df.columns = [&#39;parent_st_id&#39;, &#39;child_st_id&#39;]
  return df</code></pre></div>



<p>Reactomeから落としたファイルパスを渡すと以下のようなpandasのデータフレームを返します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="use_pathway_hierarchy_relationship_parser" data-lang="Python"><code>pathway_hierarchy_df = read_reactome_pathway_hierarchy_relationship(&quot;ReactomePathwaysRelation.txt&quot;)
print(pathway_hierarchy_df)</code></pre></div>



<p>出力は以下の通りです。</p>



<pre class="wp-block-preformatted">       parent_st_id    child_st_id
0      R-BTA-109581  R-BTA-5357769
1      R-BTA-109581    R-BTA-75153
2      R-BTA-109582   R-BTA-140877
3      R-BTA-109582   R-BTA-202733
4      R-BTA-109582   R-BTA-418346
...             ...            ...
21521  R-XTR-983705  R-XTR-5690714
21522  R-XTR-983705   R-XTR-983695
21523  R-XTR-983712  R-XTR-2672351
21524  R-XTR-983712   R-XTR-936837
21525  R-XTR-991365   R-XTR-997272

[21526 rows x 2 columns]</pre>



<h3 class="wp-block-heading">Pathwayの名前と種の読み込み</h3>



<p>ReactomeのPathwayの名前と種のパーサーは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="complete_list_parser" data-lang="Python"><code>import pandas as pd


def read_reactome_complete_list(file_path):
  df = pd.read_csv(file_path, sep=&#39;\t&#39;)
  df.columns = [&#39;st_id&#39;, &#39;pathway_name&#39;, &#39;species&#39;]
  return df</code></pre></div>



<p>こちらも以下のようにしてファイルパスを指定するとpandasのデータフレームを返します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="use_complete_list_parser" data-lang="Python"><code>complete_list = read_reactome_complete_list(&quot;ReactomePathways.txt&quot;)
print(complete_list)</code></pre></div>



<pre class="wp-block-preformatted">complete_list = read_reactome_complete_list("ReactomePathways.txt")
print(complete_list)
               st_id                                       pathway_name  \
0      R-BTA-1971475  A tetrasaccharide linker sequence is required ...   
1      R-BTA-1369062              ABC transporters in lipid homeostasis   
2       R-BTA-382556             ABC-family proteins mediated transport   
3      R-BTA-9033807                       ABO blood group biosynthesis   
4       R-BTA-418592          ADP signalling through P2Y purinoceptor 1   
...              ...                                                ...   
21417   R-XTR-193639                           p75NTR signals via NF-kB   
21418   R-XTR-111995                               phospho-PLA2 pathway   
21419   R-XTR-191859                                     snRNP Assembly   
21420   R-XTR-379724                                tRNA Aminoacylation   
21421   R-XTR-199992                trans-Golgi Network Vesicle Budding   

                  species  
0              Bos taurus  
1              Bos taurus  
2              Bos taurus  
3              Bos taurus  
4              Bos taurus  
...                   ...  
21417  Xenopus tropicalis  
21418  Xenopus tropicalis  
21419  Xenopus tropicalis  
21420  Xenopus tropicalis  
21421  Xenopus tropicalis  

[21422 rows x 3 columns]</pre>



<h2 class="wp-block-heading">Reactomeのデータを組み合わせて使う</h2>



<p>これらのデータを組み合わせて利用する例として、ヒトのPathwayの中で一番上の親のPathwayとその一つ下のPathway、2つの階層のPathwayの子のリストを取得し、その子の一つのPathwayの名前とGeneのリストを取得するコードを示します。</p>



<p>まずはヒトのPathwayを抽出してPythonのNetworkXを利用して有向グラフを作り、一番上の親を<code>root</code>というノードにつなげたグラフを作ります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="human_pathway" data-lang="Python"><code>import networkx as nx

human_pathway_ids = complete_list[complete_list[&quot;species&quot;] == &#39;Homo sapiens&#39;][&quot;st_id&quot;]
human_pathway_hierarchy_df = pathway_hierarchy_df[pathway_hierarchy_df[&quot;parent_st_id&quot;].isin(human_pathway_ids) & pathway_hierarchy_df[&quot;child_st_id&quot;].isin(human_pathway_ids)]
human_pathway_graph = nx.from_pandas_edgelist(human_pathway_hierarchy_df, source=&quot;parent_st_id&quot;, target=&quot;child_st_id&quot;, create_using=nx.DiGraph())
root_pathways = [n for n, d in human_pathway_graph.in_degree() if d==0] 
root_edges = [(&quot;root&quot;, n) for n in root_pathways] 
human_pathway_graph.add_edges_from(root_edges)</code></pre></div>



<p>これでヒトのPathwayの階層構造を示した有向グラフができました。これを使えば<code>root</code>からの最小距離が2以下のnodeを列挙することで、一番上の親のPathwayとその一つ下のPathway、2階層のPathwayが取得できます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>selected_ids = nx.single_source_shortest_path_length(human_pathway_graph, source=&quot;root&quot;, cutoff=2)</code></pre></div>



<p><code>selected_ids</code>をprintするとこのような形になります。</p>



<pre class="wp-block-preformatted">length 1 {'R-HSA-109581': 2,
 'R-HSA-109582': 1,
 'R-HSA-112307': 2,
 'R-HSA-112315': 2,
 'R-HSA-112316': 1,
 'R-HSA-1181150': 2,
 'R-HSA-1187000': 2,
 'R-HSA-1266738': 1,
...</pre>



<p>ここから<code>successors()</code>を使って各Pathwayの子のPathwayを取得します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="print_children" data-lang="Python"><code>for pathway_id, shortest_path_length in selected_ids.items():
  if shortest_path_length &gt; 0:
    children = list(human_pathway_graph.successors(pathway_id))
    print(&quot;length&quot;, shortest_path_length, &quot;parent&quot;, pathway_id, &quot;children&quot;, children)</code></pre></div>



<p>ちなみに、<code>shortest_path_length=0</code>は<code>root</code>だけなのでskipしています。出力としては以下のようになります。</p>



<pre class="wp-block-preformatted">length 1 parent R-HSA-1852241 children ['R-HSA-1592230', 'R-HSA-5617833']
length 1 parent R-HSA-5357801 children ['R-HSA-109581', 'R-HSA-5218859']
length 1 parent R-HSA-1266738 children ['R-HSA-1181150', 'R-HSA-186712', 'R-HSA-381340', 'R-HSA-452723', 'R-HSA-525793', 'R-HSA-5619507', 'R-HSA-5682910', 'R-HSA-6805567', 'R-HSA-9616222', 'R-HSA-9675108', 'R-HSA-9690406']
length 1 parent R-HSA-4839726 children ['R-HSA-3247509']
length 1 parent R-HSA-9709957 children ['R-HSA-2187338', 'R-HSA-381753', 'R-HSA-9659379', 'R-HSA-9717189']
length 1 parent R-HSA-1474244 children ['R-HSA-1474228', 'R-HSA-1474290', 'R-HSA-1566948', 'R-HSA-1566977', 'R-HSA-216083', 'R-HSA-3000157', 'R-HSA-3000171', 'R-HSA-3000178', 'R-HSA-8941237']
length 1 parent R-HSA-9612973 children ['R-HSA-1632852', 'R-HSA-9613829', 'R-HSA-9615710']
length 1 parent R-HSA-397014 children ['R-HSA-390522', 'R-HSA-445355', 'R-HSA-5576891']
...</pre>



<p>ここで試しにR-HSA-1592230の名前を出してみます。コードとしては以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="print_complete_list" data-lang="Python"><code>print(complete_list[complete_list[&quot;st_id&quot;] == &quot;R-HSA-1592230&quot;])</code></pre></div>



<p>出力は以下の通りです。</p>



<pre class="wp-block-preformatted">               st_id              pathway_name       species
11395  R-HSA-1592230  Mitochondrial biogenesis  Homo sapiens</pre>



<p>また、R-HSA-1592230の関連するgeneは以下のようにして出力できます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="print_genes" data-lang="Python"><code>print(pathway_gene_df[pathway_gene_df[&quot;st_id&quot;] == &quot;R-HSA-1592230&quot;])</code></pre></div>



<p>出力は以下の通りです。</p>



<pre class="wp-block-preformatted">               st_id    gene
65977  R-HSA-1592230   ALAS1
65978  R-HSA-1592230    APOO
65979  R-HSA-1592230   APOOL
65980  R-HSA-1592230    ATF2
65981  R-HSA-1592230   ATP5B
...              ...     ...
66066  R-HSA-1592230    TFAM
66067  R-HSA-1592230   TFB1M
66068  R-HSA-1592230   TFB2M
66069  R-HSA-1592230    TGS1
66070  R-HSA-1592230  TMEM11

[94 rows x 2 columns]</pre>



<h2 class="wp-block-heading">終わりに</h2>



<p>Reactomeで公開されているデータを読み込んでPathwayの階層構造やPathwayに関連するGeneのデータを読み込む方法を紹介しました。最初はGraph Databaseから頑張って抽出しないといけないのかと思っていましたが、調べてみると簡単なことがわかりました。ただ、何も知らない状態ではどうしていいのかわからないことが多かったので記事にまとめてみました。同じように悩んでいる方の参考になれば幸いです。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/">ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1418</post-id>	</item>
		<item>
		<title>JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</title>
		<link>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/</link>
					<comments>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sun, 07 Nov 2021 23:07:44 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<category><![CDATA[JAX]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=151</guid>

					<description><![CDATA[<p>最近微分可能な Smith Waterman アルゴリズムというものとJAXのコードが公開されました。今回はこれらを参考に、JAXの勉強がてら何パターンかSmith Watermanアルゴリズムを実装して測定してみたので [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/">JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>最近微分可能な Smith Waterman アルゴリズムというものとJAXのコードが公開されました。今回はこれらを参考に、JAXの勉強がてら何パターンかSmith Watermanアルゴリズムを実装して測定してみたので、その結果のまとめの紹介となります。</p>



<p>論文は以下のものです。</p>



<p>[1] <strong>Petti, S., Bhattacharya, N., Rao, R., Dauparas, J., Thomas, N., Zhou, J., … Ovchinnikov, S. (2021). End-to-end learning of multiple sequence alignments with differentiable Smith-Waterman. BioRxiv, 2021.10.23.465204. https://doi.org/10.1101/2021.10.23.465204</strong></p>



<p>また、著者の実装はこちらに公開されています。</p>



<p><a href="https://github.com/spetti/SMURF">https://github.com/spetti/SMURF</a></p>



<p>今回は主に私がJAXの勉強をしたかったということもあり、いくつか実装を作ってパフォーマンスを測定して、「JAXって速いの？」という疑問にある程度答えられればと思い、記事を書いています。今回の実装はすべてこちらにありますので参考にしてみてください。</p>



<p><a href="https://github.com/shu65/blog-jax-notebook/blob/main/JAX_Smooth_Smith_Waterman.ipynb">https://github.com/shu65/blog-jax-notebook/blob/main/JAX_Smooth_Smith_Waterman.ipynb</a></p>



<p>また、計算時間測定はすべてGoogle Cloab上のCPUで行っています。</p>



<h2 class="wp-block-heading">論文概要</h2>



<p>この論文では教師なし学習によるコンタクト予測において、前処理で使われるSmith Watermanアルゴリズムを微分可能なものに置き換えて、Smith Watermanアルゴリズムの中で使われるパラメータ（置換スコア）も含めて学習する手法 SMURFを提案した論文です。論文自体にはコンタクト予測の精度なども書かれていますが、微分可能なSmith Watermanの紹介をメインにしたいため、今回は割愛します。</p>



<h2 class="wp-block-heading"> 微分可能な Smith Waterman アルゴリズム「Smooth Smith Waterman」とは？ </h2>



<p>Smith Watermanアルゴリズムを微分可能にするためには、微分可能ではない関数を微分可能なものに置き換えて、近似することで実現します。まずは大本のSmith Watermanアルゴリズムの説明をしたあと、微分可能なものに変更する方法を紹介していきます。</p>



<h3 class="wp-block-heading"> Smith Watermanアルゴリズムとは</h3>



<p>Smith Watermanアルゴリズム は2つのDNAやタンパク質の配列の類似度、特にローカルアライメントのスコアと呼ばれる類似度を計算するアルゴリズムです。ローカルアラインメントとは2配列間の類似度の高い部分的な文字列を発見するときに使われます。これは以下のように行列の要素を計算する動的計画法 (Dynamic Programming, DP) により計算します。</p>



<p>$$ H_{i0} = H_{0j} = 0 \\ H_{ij} = \max\begin{cases} H_{i-1,j-1} + s(a_i,b_j), \\ H_{i-k,j} + g, \\ H_{i,j-l} +g, \\<br>0 \\<br>\end{cases}    $$</p>



<p>\(  s(a_i,b_j)  \) は 配列Aのi番目の文字と配列Bのj番目の文字の置換スコアと呼ばれるもので、同じ文字、もしくは類似度の高い文字のペアはプラス、類似度の低い文字のペアはマイナスにするのが一般的です。また、 \(  g \) はギャップペナルティと呼ばれるもので、1文字飛ばしのペナルティを表します。 </p>



<h3 class="wp-block-heading">Smith Watermanアルゴリズムを微分可能にする</h3>



<p>先ほど説明したとおり、Smith Watermanアルゴリズムではmax関数があります。この部分が微分可能ではないため、SmithWatermanアルゴリズムは微分可能ではありません。このため、このmax関数を微分可能な何等かの関数で置き換える必要があります。この論文ではmax関数を「logsumexp」で置き換えることで微分可能にします。</p>



<p> logsumexpはmax関数を滑らかに近似するための関数として使われる関数で、微分可能な関数です。このためmax関数を logsumexp に置き換えればSmith Watermanアルゴリズムの計算全体が微分可能になります。論文中ではこの微分可能なSmtth Watermanアルゴリズムを「Smooth Smith Waterman」と呼んでいます。<br>なぜlogsumexpがmax関数の近似になるかを詳しく知りたい方は、こちらのブログ記事がわかりやすかったのでお勧めです。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-hatena-blog wp-block-embed-hatena-blog"><div class="wp-block-embed__wrapper">
<iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted" title="Smooth maximumを作って遊ぼう - Corollaryは必然に。" src="https://hatenablog-parts.com/embed?url=https%3A%2F%2Fcorollary2525.hatenablog.com%2Fentry%2F2020%2F07%2F25%2F221823#?secret=g4M2CLQP1b" data-secret="g4M2CLQP1b" scrolling="no" frameborder="0"></iframe>
</div></figure>



<h3 class="wp-block-heading">numpyによるシンプルな Smooth Smith Waterman </h3>



<p>後ほどJAXの実装を示しますが、高速化したあとのJAXのコードは初見では分かりづらいため、先にシンプルなnumpyの実装を示します。この実装は著者の実装にあわせつつ、numpyとのパフォーマンス実装をするために以下のようにしています。</p>



<ol class="wp-block-list"><li>配列Aと配列Bの全文字ペアの置換スコアの行列 <code>score_matrix</code>（置換スコアの行列のサイズは|A|×|B|）と2つの配列の長さ<code>lengths</code>、その他のパラメータを入力とする</li><li>この記事では勾配を計算できないnumpyとの比較のために、著者実装では<code>score_matrix</code>の勾配を返すのに対して、今回の記事では2配列の最大スコアを返す。</li></ol>



<p>Smith Watermanアルゴリズムをご存じの方は戸惑うかもしれませんが、Smooth Smith Watermanアルゴリズムでは<code>score_matrix</code>の勾配を出力として返す関数になっています。このため、あらかじめ配列Aと配列Bの全文字ペアの置換スコアの行列を用意して入力にします。<br>このため、配列Aと配列Bは入力に出てきませんし、PAMやBLOSUMなどの置換スコアもでてきません。</p>



<p>このSmooth Smith Watermanアルゴリズムをシンプルにnumpyで実装すると以下の通りになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-np" data-lang="Python"><code>def sw_np(NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = np.maximum(y,NINF)
        return y.max(axis) + np.log(np.sum(np.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = np.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=np.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap

                h = np.stack([m, g0, g1, s], -1)
                hij[i + 1, j + 1] = _soft_maximum(h, temp=temp, axis=-1)
        hij = hij[1:, 1:]
        score = _soft_maximum(hij, temp=temp)
        return score
    return _sw</code></pre></div>



<p>こちらの実装で通常のSmith Watermanアルゴリズムと違う点は以下の2点です</p>



<ol class="wp-block-list"><li>DPの行列の要素更新のところでmax関数をlogsumexpで実装した<code>_soft_maximum()</code>という関数に置き換えている。</li><li>DPの行列の各要素を入れるところで最大値を取るところで0以下にならないようにmax関数の入力の一つとして0を入れるところを、置換スコア(<code>s</code>)を入れている。</li></ol>



<p>1が微分可能とするための改良した箇所です。一方、2に関しては私が読み飛ばしてしまった可能性がありますが、特に論文中に説明が見当たらなかった変更点です。なんとなくSmooth Smith Watermanアルゴリズムを使って深層学習のモデル更新をするときにうまく勾配が置換スコアに流れるようにするためでは？と思っているのですが、未確認な状態です。何かご存じの方がいれば教えていただければと思っています。</p>



<p>この実装を実行したときの計算時間を<code>%time</code>で測定すると以下の通りです。</p>



<pre class="wp-block-preformatted">CPU times: user 735 ms, sys: 4 ms, total: 739 ms
Wall time: 743 ms
</pre>



<h2 class="wp-block-heading"> JAXを使ったSmooth Smith アルゴリズム</h2>



<p>ここからnumpyの部分をJAXに置き換えてSmooth Smith Watermanアルゴリズムを実装し、徐々に改良していくという順番で説明していきます。まずはJAXをご存じない方のためにJAXを簡単に説明します。</p>



<h3 class="wp-block-heading">JAXってなに？</h3>



<p>JAXはPythonやnumpyの関数を微分可能なものにし、XLAというコンパイラを使ってGPUやTPUで実行で実行できるようにしたライブラリです。<br>JAXでは勾配が計算できることと、jitをはじめとした様々な高速化する仕組みが用意されているため、最近論文で利用しているケースが増えてきた印象です。特に今回紹介した論文のような、従来では微分可能でなかった計算を微分可能なものに置き換え、深層学習のモデル学習の中で利用するという手法の実装にJAXが使われるのをよく目にします。今回紹介したものの他には BRAXがあります。</p>



<p><a href="https://github.com/google/brax">https://github.com/google/brax</a></p>



<h3 class="wp-block-heading">単純なJAX実装</h3>



<p>JAXはnumpyの関数と同じAPIの関数があるので、まずはそれをそのまま利用してみます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v0" data-lang="Python"><code>def sw_v0(NINF=-1e30):
    
    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        real_a, real_b = lengths
        hij = jnp.full((real_a + 1, real_b + 1), fill_value=NINF, dtype=jnp.float32)
        for i in range(real_a):
            for j in range(real_b):
                s = score_matrix[i, j]
                m = hij[i, j] + s
                g0 = hij[i + 1, j] + gap
                g1 = hij[i, j + 1] + gap
                h = jnp.stack([m, g0, g1, s], -1)
                hij = hij.at[i + 1, j + 1].set(_soft_maximum(h, -1))
        hij = hij[1:, 1:]
        score = _soft_maximum(hij)
        return score
    return _sw</code></pre></div>



<p>これも動くには動くのですが、あまりにも遅いため、まったく使い物になりません。このためJAXを使う際はもう少し真面目に高速に動くアルゴリズムで実装する必要があります。</p>



<h3 class="wp-block-heading">Striped Smith-Watermanベースの実装</h3>



<p>論文でも紹介されているStriped Smith-Watermanベースで実装してみます。 Smith-Waterman アルゴリズムをSIMDなどで並列化する方法として、依存関係のないDP行列の斜めのセルを同時に埋めていくという方法がしばしば取られます。詳しくはこちらをご覧ください。</p>



<p><strong>Farrar, M. (2007). Striped Smith-Waterman speeds database searches six times over other SIMD implementations. Bioinformatics (Oxford, England), 23(2), 156–161. https://doi.org/10.1093/bioinformatics/btl582</strong></p>



<p>これをJAXで実装するにあたり、著者はDP行列を回転させ、依存関係のない斜めに並んだセルを横1列に並べて計算するようにしています。</p>



<figure class="wp-block-image size-large"><img loading="lazy" decoding="async" width="1024" height="440" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1024x440.png" alt="" class="wp-image-175" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1024x440.png 1024w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-300x129.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-768x330.png 768w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-1536x660.png 1536w, https://www.mattari-benkyo-note.com/wp-content/uploads/2021/11/striped_sw-2048x880.png 2048w" sizes="auto, (max-width: 1024px) 100vw, 1024px" /><figcaption>DP行列の回転 ([1] Fig. 7)</figcaption></figure>



<p>こうすることで内側のforループをJAXのベクトルの計算で実行できるようにしています。個人的にはここがこの論文の最大の貢献な気がしています。具体的にJAXで実装すると以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v1" data-lang="Python"><code>def sw_v1(unroll=2, NINF=-1e30):
        
    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx
    
    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _step(prev, gap_cell_condition, rotated_score_matrix, gap, temp):
        h2,h1 = prev   # previous two rows of scoring (hij) mtx
        h1_T = jax.lax.cond(
            gap_cell_condition,
            lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
            lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
            h1,
        )

        a = h2 + rotated_score_matrix
        g0 = h1 + gap
        g1 = h1_T + gap
        s = rotated_score_matrix

        h0 = jnp.stack([a, g0, g1, s], -1)
        h0 = _soft_maximum(h0, temp, -1)
        return (h1,h0), h0

    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape
        
        gap_cell_condition = (jnp.arange(n)+a%2)%2
        prev = (jnp.full(m, NINF), jnp.full(m, NINF))
        rotated_hij = []
        for i in range(n):
            prev, h = _step(prev, gap_cell_condition[i], rotated_score_matrix[i], gap, temp)
            rotated_hij.append(h)
        rotated_hij = jnp.stack(rotated_hij)
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp=temp)
        return score
    return _sw</code></pre></div>



<p>この実装では置換行列<code>score_matrix</code>を回転させて、DP行列のセルを埋めていき、そのあとDP行列元の方向に戻すということをしています。<br>回転させたときの注意点として、DP行列の列番号が偶数か奇数かでギャップペナルティのスコアを加算するセルの相対座標が変わります。このため、<code>jax.lax.cond()</code>を利用して使うセルを分岐しています。</p>



<p>この実装をそのまま実行したときとjitを利用したときの計算時間は以下の通りです。</p>



<pre class="wp-block-preformatted">jax default first call
CPU times: user 17.7 s, sys: 177 ms, total: 17.8 s
Wall time: 17.8 s
jax default second call
CPU times: user 17.6 s, sys: 153 ms, total: 17.7 s
Wall time: 17.7 s


jax jit first call
CPU times: user 2min 20s, sys: 715 ms, total: 2min 21s
Wall time: 2min 20s
jax jit second call
CPU times: user 1.98 ms, sys: 0 ns, total: 1.98 ms
Wall time: 1.81 ms


</pre>



<p>jitなしでそのまま実行するのはnumpyよりもかなり遅い印象です。またjitを使う場合も最初の呼び出しはコンパイルが走ることもあり、jitなしに比べるとさらに遅くなっています。さすがに1回目とはいえ、ここまで時間がかかると使いづらいと思われます。このため、まだ工夫する必要があります。</p>



<h3 class="wp-block-heading">外側のforループをjax.lax.scan()に置き換える</h3>



<p>1つ前の実装で遅い原因がどこか？というとforループです。これを速くする方法としてJAXのforループと類似する処理を実行するための関数を利用します。今回はforループ部分を <code>jax.lax.scan()</code> に置き換えます。</p>



<p>実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v2" data-lang="Python"><code>def sw_v2(unroll=2, NINF=-1e30):

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = jax.lax.cond(
                scan_xs[&quot;gap_cell_condition&quot;],
                lambda x: jnp.pad(x[:-1], [1,0], constant_values=(NINF,NINF)),
                lambda x: jnp.pad(x[1:], [0,1], constant_values=(NINF,NINF)),
                h1,
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp, axis=None)
        return score
    return _sw</code></pre></div>



<p>この実装でforループがなくなりました。実行した結果は以下の通りです。</p>



<pre class="wp-block-preformatted">jax default first call
CPU times: user 739 ms, sys: 18 ms, total: 757 ms
Wall time: 758 ms
jax default second call
CPU times: user 666 ms, sys: 1.98 ms, total: 668 ms
Wall time: 671 ms

jax jit first call
CPU times: user 1 s, sys: 5.01 ms, total: 1.01 s
Wall time: 1.01 s
jax jit second call
CPU times: user 339 µs, sys: 989 µs, total: 1.33 ms
Wall time: 1.14 ms
</pre>



<p>先ほどに比べるとjitなしでも速くなりましたが、jitありの1回目の実行もかなり速くなった印象です。これなら十分使えるのではないか？と思っています。</p>



<h3 class="wp-block-heading">jax.lax.condの置き換え</h3>



<p>著者の実装では<code> jax.lax.cond()</code>を使わずに加算と乗算だけで実装されています。試しに同様の実装にしたバージョンも示します。具体的な実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v3" data-lang="Python"><code>def sw_v3(unroll=2, NINF=-1e30):

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = _get_prev_gap_cell_score(
                scan_xs[&quot;gap_cell_condition&quot;],
                jnp.pad(h1[:-1], [1,0], constant_values=(NINF,NINF)),
                jnp.pad(h1[1:], [0,1], constant_values=(NINF,NINF)),
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        rotated_score_matrix, reverse_idx = _rotate(score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum(hij, temp, axis=None)
        return score
    return _sw</code></pre></div>



<p>この時のパフォーマンスは以下の通りです。</p>



<pre class="wp-block-preformatted">jax defaujax default first call
CPU times: user 599 ms, sys: 1.99 ms, total: 601 ms
Wall time: 608 ms
jax default second call
CPU times: user 599 ms, sys: 3.02 ms, total: 602 ms
Wall time: 607 ms

jax jit first call
CPU times: user 940 ms, sys: 2.01 ms, total: 942 ms
Wall time: 947 ms
jax jit second call
CPU times: user 4.9 ms, sys: 0 ns, total: 4.9 ms
Wall time: 3.41 ms</pre>



<p>jitなしの時は速くなっている印象ですが、jitありのときは少し遅くなっています。ただ、何度か実行してみると逆転することもあるようなので、誤差の範囲かもしれません。また、JAX特有のパフォーマンス測定のお作法をし忘れている可能性もあります。もしご存じの方があればコメントいただければと思います。</p>



<h3 class="wp-block-heading">Batch実行用の実装</h3>



<p>著者のSmooth Smith Watermanは2つの配列のペアを1つだけ実行するのではなく、複数のペアをまとめて実行することを想定されて実装してあります。ここでも同様に複数のペアをまとめて計算するのもやってみようと思います。</p>



<h4 class="wp-block-heading">簡単な実装</h4>



<p>複数のペアをまとめて実装する際、ペア毎に配列の長さが違っても動作するようにします。このため、置換スコアのうち必要な部分だけmaskするようにします。</p>



<p>実装は以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="smooth-sw-jax-v4" data-lang="Python"><code>def sw_v4(unroll=2, NINF=-1e30):
    
    def _make_mask(score_matrix, lengths):
        a,b = score_matrix.shape
        real_a, real_b = lengths
        mask = (jnp.arange(a) &lt; real_a)[:,None] * (jnp.arange(b) &lt; real_b)[None,:]
        return mask

    def _rotate(score_matrix):
        a,b = score_matrix.shape
        n,m = (a+b-1),(a+b)//2
        ar,br = jnp.arange(a)[::-1,None], jnp.arange(b)[None,:]
        i,j = (br-ar)+(a-1),(ar+br)//2
        rotated_score_matrix = jnp.full([n,m],NINF).at[i,j].set(score_matrix)
        reverse_idx = (i, j)
        return rotated_score_matrix, reverse_idx

    def _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp):
        def scan_f(prev, scan_xs):
            h2, h1 = prev
            h1_T = _get_prev_gap_cell_score(
                scan_xs[&quot;gap_cell_condition&quot;],
                jnp.pad(h1[:-1], [1,0], constant_values=(NINF,NINF)),
                jnp.pad(h1[1:], [0,1], constant_values=(NINF,NINF)),
            )
            a = h2 + scan_xs[&quot;rotated_score_matrix&quot;]
            g0 = h1 + gap
            g1 = h1_T + gap
            s = scan_xs[&quot;rotated_score_matrix&quot;]

            h0 = jnp.stack([a, g0, g1, s], -1)
            h0 = _soft_maximum(h0, temp, -1)
            return (h1,h0), h0
        
        a,b = score_matrix.shape
        n,m = rotated_score_matrix.shape

        scan_xs = {
            &quot;rotated_score_matrix&quot;: rotated_score_matrix,
            &quot;gap_cell_condition&quot;: (jnp.arange(n)+a%2)%2
        }
        scan_init = (jnp.full(m, NINF), jnp.full(m, NINF))
        return scan_f, scan_xs, scan_init

    def _rotate_in_reverse(rotated_dp_matrix, reverse_idx):
        return rotated_dp_matrix[reverse_idx]

    def _logsumexp(y, axis):
        y = jnp.maximum(y,NINF)
        return jax.nn.logsumexp(y, axis=axis)

    def _logsumexp_with_mask(y, axis, mask):
        y = jnp.maximum(y,NINF)
        return y.max(axis) + jnp.log(jnp.sum(mask * jnp.exp(y - y.max(axis, keepdims=True)), axis=axis))

    def _soft_maximum(x, temp, axis=None):
        return temp*_logsumexp(x/temp, axis)

    def _soft_maximum_with_mask(x, temp, mask, axis=None):
        return temp*_logsumexp_with_mask(x/temp, axis, mask)

    def _get_prev_gap_cell_score(cond, true, false): 
        return cond*true + (1-cond)*false
    
    def _sw(score_matrix, lengths, gap=0, temp=1.0):
        mask = _make_mask(score_matrix, lengths)
        masked_score_matrix = score_matrix + NINF * (1 - mask)
        rotated_score_matrix, reverse_idx = _rotate(masked_score_matrix)
        scan_f, scan_xs, scan_init = _prepare_scan_inputs(score_matrix, rotated_score_matrix, gap, temp)
        rotated_hij = jax.lax.scan(scan_f, scan_init, scan_xs, unroll=unroll)[-1]
        hij = _rotate_in_reverse(rotated_hij, reverse_idx)
        score = _soft_maximum_with_mask(hij, temp, mask=mask, axis=None)
        return score
    return _sw</code></pre></div>



<p>この実装をペアの数分、forループで計算していくようにします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="batch-sw-v0" data-lang="Python"><code>def batch_sw_v0(NINF=-1e30):
    def _batch_sw(batch_score_matrix, batch_lengths, gap=0, temp=1.0):
        n_batches = batch_score_matrix.shape[0]
        sw_func = jax.jit(sw_v4())
        ret = [sw_func(batch_score_matrix[i], batch_lengths[i], gap, temp) 
               for i in range(n_batches)]
        return jnp.array(ret)
    return _batch_sw</code></pre></div>



<p>これを実行すると計算時間は以下の通りでした。</p>



<pre class="wp-block-preformatted">batch jax default first call
CPU times: user 1.31 s, sys: 13 ms, total: 1.33 s
Wall time: 1.32 s
batch jax default second call
CPU times: user 1.3 s, sys: 5.02 ms, total: 1.3 s
Wall time: 1.3 s

batch jax default first call
CPU times: user 10min 43s, sys: 2.99 s, total: 10min 46s
Wall time: 10min 45s
batch jax default second call
CPU times: user 279 ms, sys: 2 ms, total: 281 ms
Wall time: 281 ms</pre>



<p>forループでそのまま実装すると、jitありのときはやはり1度目の実行に非常に時間がかかるようです。このため、この部分を速くします。</p>



<h4 class="wp-block-heading">forループをjax.vmap()で置き換える</h4>



<p>ここではforループを<code> jax.vmap()</code> で置き換えます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="batch-sw-v1" data-lang="Python"><code>def batch_sw_v1(unroll=2, NINF=-1e30):
    sw_func = sw_v4(unroll=unroll, NINF=NINF)
    batch_sw_func = jax.vmap(sw_func, (0, 0, None, None))
    return batch_sw_func</code></pre></div>



<p>この時の計算時間は以下の通りです。</p>



<pre class="wp-block-preformatted">batch jax default first call
CPU times: user 1.04 s, sys: 11 ms, total: 1.05 s
Wall time: 1.03 s
batch jax default second call
CPU times: user 1.04 s, sys: 7.97 ms, total: 1.04 s
Wall time: 1.01 s

batch jax default first call
CPU times: user 1.51 s, sys: 10 ms, total: 1.52 s
Wall time: 1.5 s
batch jax default second call
CPU times: user 120 ms, sys: 9 µs, total: 120 ms
Wall time: 97 ms</pre>



<p>先ほどと比べるとかなり高速化できました。ちなみにこれがほぼ著者の実装と同じものになります。</p>



<h2 class="wp-block-heading">結果まとめ</h2>



<p>ここまでの計算時間の結果をまとめると以下の通りです。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td>jitなし1回目</td><td>jitなし2回目</td><td>jitあり1回目</td><td>jitあり2回目</td></tr><tr><td>numpy</td><td>739 ms</td><td>&#8211;</td><td>&#8211;</td><td>&#8211;</td></tr><tr><td>Striped Smith-Watermanベースの実装</td><td>17.8 s</td><td>17.7 s</td><td>2min 21s</td><td>1.98 ms</td></tr><tr><td>外側のforループをjax.lax.scan()に置き換える</td><td>757 ms</td><td>668 ms</td><td>1.01 s</td><td>1.33 ms</td></tr><tr><td>jax.lax.condの置き換え</td><td>601 ms</td><td>602 ms</td><td>942 ms</td><td>4.9 ms</td></tr></tbody></table><figcaption>Smooth Smith Watermanの計算時間まとめ</figcaption></figure>



<figure class="wp-block-table"><table><tbody><tr><td></td><td> jitなし1回目 </td><td> jitなし2回目 </td><td> jitあり1回目 </td><td>jitあり2回目 </td></tr><tr><td>簡単な実装</td><td>1.33 s</td><td>1.3 s</td><td>10min 46s</td><td>281 ms</td></tr><tr><td>forループをjax.vmap()で置き換える</td><td>1.05 s</td><td>1.04 s</td><td>1.52 s</td><td>120 ms</td></tr></tbody></table><figcaption> Batch Smooth Smith Watermanの計算時間まとめ </figcaption></figure>



<p>各実装を比較するとforループのありなしでかなり実行時間やコンパイル時間が変化していることがわかります。このためnumpyの実装をそのままJAXにすればそれだけで速くなることはまずなさそうです。また、何も考えずに実装してjitを使うと、コンパイル時間が長すぎて使い物にならないというケースが多そうな気がしています。このため、JAXを使いこなすにはどのような計算は遅いかを理解して使うことが重要そうな印象です。</p>



<h2 class="wp-block-heading">おわりに</h2>



<p>今回、初のJAX使用だったため、パフォーマンス測定や高速化にはもっとやり方があるかもしれないと思っています。もしお気づきの点がありましたら気兼ねなくコメントいただければと思っています。</p>



<p>次はできればJAXとPyTorchのjitとどちらが速いのか試せればと思っています。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/">JAXによる微分可能Smith Watermanアルゴリズムのパフォーマンス測定</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2021/11/08/jax-smooth-sw/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">151</post-id>	</item>
	</channel>
</rss>
