<?xml version="1.0" encoding="UTF-8"?><rss version="2.0"
	xmlns:content="http://purl.org/rss/1.0/modules/content/"
	xmlns:wfw="http://wellformedweb.org/CommentAPI/"
	xmlns:dc="http://purl.org/dc/elements/1.1/"
	xmlns:atom="http://www.w3.org/2005/Atom"
	xmlns:sy="http://purl.org/rss/1.0/modules/syndication/"
	xmlns:slash="http://purl.org/rss/1.0/modules/slash/"
	>

<channel>
	<title>pytorch - まったり勉強ノート</title>
	<atom:link href="https://www.mattari-benkyo-note.com/tag/pytorch/feed/" rel="self" type="application/rss+xml" />
	<link>https://www.mattari-benkyo-note.com</link>
	<description>shuの日々の勉強まとめ</description>
	<lastBuildDate>Wed, 12 Feb 2025 23:22:17 +0000</lastBuildDate>
	<language>ja</language>
	<sy:updatePeriod>
	hourly	</sy:updatePeriod>
	<sy:updateFrequency>
	1	</sy:updateFrequency>
	<generator>https://wordpress.org/?v=6.8.3</generator>
<site xmlns="com-wordpress:feed-additions:1">189243286</site>	<item>
		<title>小型LLM PLaMo 2 1BをGoogle ColabでSFTしてみる</title>
		<link>https://www.mattari-benkyo-note.com/2025/02/13/plamo-2-1b-sft/</link>
					<comments>https://www.mattari-benkyo-note.com/2025/02/13/plamo-2-1b-sft/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Wed, 12 Feb 2025 23:30:00 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[CUDA]]></category>
		<category><![CDATA[llm]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[機械学習]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=3210</guid>

					<description><![CDATA[<p>今回はPreferred Networksとその子会社のPreferred Elementsが共同で開発した1Bサイズの小型のLLM、PLaMo 2 1Bに対してSFTをするコードの紹介になります。 Google Col [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2025/02/13/plamo-2-1b-sft/">小型LLM PLaMo 2 1BをGoogle ColabでSFTしてみる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>今回はPreferred Networksとその子会社のPreferred Elementsが共同で開発した1Bサイズの小型のLLM、<a href="https://huggingface.co/pfnet/plamo-2-1b" title="PLaMo 2 1B">PLaMo 2 1B</a>に対してSFTをするコードの紹介になります。</p>



<p>Google Colabの無料枠で推論を回す方法は前回記事にしましたので、そもそもPLaMo 2 1Bって何と思った方や推論を回してみたいという方はそちらをご覧ください。</p>



<figure class="wp-block-embed is-type-wp-embed"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="ZZZgUkUP2R"><a href="https://www.mattari-benkyo-note.com/2025/02/12/plamo-2-1b-infernece/">小型LLM PLaMo 2 1BをGoogle Colabの無料枠の範囲で使ってみる</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;小型LLM PLaMo 2 1BをGoogle Colabの無料枠の範囲で使ってみる&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2025/02/12/plamo-2-1b-infernece/embed/#?secret=V6iPyeE1qa#?secret=ZZZgUkUP2R" data-secret="ZZZgUkUP2R" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>また、今回説明に使うコードはこちらに置いてありますので、適宜参照してください。</p>



<p><a href="https://github.com/shu65/plamo-2-1b-sft-example">https://github.com/shu65/plamo-2-1b-sft-example</a></p>



<p>Google Colabにおける一連の実行に関してはJupyter Notebookにまとめてありますので、細かい実行方法がわからないという方はこちらをご覧ください</p>



<p><a href="https://github.com/shu65/plamo-2-1b-sft-example/blob/main/run_sft_google_colab.ipynb">https://github.com/shu65/plamo-2-1b-sft-example/blob/main/run_sft_google_colab.ipynb</a></p>



<h2 class="wp-block-heading">Supervised Fine-Tuning(SFT)とは？</h2>



<p>SFTを知らない方に簡単に説明すると、SFTは指示と想定されている回答のペアを用意し、LLMに対して学習を行い、指示に従いやすいモデルを作る方法になります。</p>



<p>特にPLaMo 2 1Bのような事前学習モデルでは、特に指示に従うように学習されていないケースもあり、そのまま利用した際、余計なことをだらだらと出力し続けたり、頓珍漢な回答が返ってきたりという問題が発生することがあります。</p>



<p>このため指示に適切にこたえてもらうための技術がいろいろあるのですが、そのうちの一つにSFTというものがあります。</p>



<h2 class="wp-block-heading">Google ColabでPLaMo 2 1BをSFTする</h2>



<p>それでは本題のGoogle ColabでPLaMo 2 1BをSFTする方法について説明します。今回はGPUメモリの関係上、おそらく無料で使えるT4だと無改造では実行できない気がするのでL4を使った説明をします。</p>



<h3 class="wp-block-heading">L4 GPUの利用</h3>



<p>まず、Google ColabでL4が使えるように、課金が必要になります。</p>



<p>課金についてはこちらをご覧ください。</p>



<p><a href="https://colab.research.google.com/signup?hl=ja">https://colab.research.google.com/signup?hl=ja</a></p>



<p>今回のコードを動かすだけであれば「Pay As You Go」で100 コンピューティング ユニットを購入すれば十分です。この記事を執筆時点では1200円に満たない程度で購入できます。</p>



<p>課金が済んだら、メニューバーから「ランタイム」→「ランタイムのタイプを変更」をクリックします。すると無料枠では選択できないL4 GPUが選択できるようになっていると思うので、L4 GPUを選択します。</p>



<p>これでGPUを使う準備ができました。</p>



<h3 class="wp-block-heading">実行環境準備</h3>



<p>L4を利用するようにしたら、実行するコードのダウンロードやPythonパッケージのインストールを行います。</p>



<p>まずGithubよりコードをcloneしてきます</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>!git clone https://github.com/shu65/plamo-2-1b-sft-example.git</code></pre></div>



<p>次に、PyTorchのバージョンを現在の最新版よりも前の以下のものに変更します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>!pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124</code></pre></div>



<p>この後は以下のようにPyTorch以外のPLaMo 2 1Bの実行に必要なパッケージやSFTに必要なパッケージなどをインストールします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>!pip install -r plamo-2-1b-sft-example/requirements.txt</code></pre></div>



<p>ここまで実行すると2025/02/12現在以下のようなバージョンがインストールされました。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>causal-conv1d                      1.5.0.post8
fastrlock                          0.8.3
mamba-ssm                          2.2.4
numba                              0.61.0
numba-cuda                         0.0.17.1
sentence-transformers              3.4.1
torch                              2.4.1+cu124
torchaudio                         2.4.1+cu124
torchsummary                       1.5.1
torchvision                        0.19.1+cu124
transformers                       4.48.2
trl                                0.14.0</code></pre></div>



<p>これであとはSFTのコードを実行すれば、SFTをすることができます。このSFTの中身に関しては次で紹介していきます。</p>



<h3 class="wp-block-heading">PLaMo 2 1BをSFTする</h3>



<p>SFTをする部分は<code>sft.py</code>　というスクリプトにまとめてあります。このスクリプトの重要な部分について簡単にですが説明していきます。</p>



<p>まず、今回はすぐに実行が終わるように少量の質問と回答のペアのデータを用います。</p>



<p>今回は日本語の指示学習でよく使われる<a href="https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja" target="_blank" rel="noopener" title="kunishou/databricks-dolly-15k-ja">kunishou/databricks-dolly-15k-ja</a>というデータセットのうち、<code>input</code> がなく<code>instruction</code> と<code>output</code> のペアになっているデータのみを取り出しその一部だけを利用します。一つ例を見せると以下のようなデータを利用します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>{
  &quot;output&quot;: &quot;イコクエイラクブカ&quot;,
  &quot;input&quot;: &quot;&quot;,
  &quot;index&quot;: &quot;1&quot;,
  &quot;category&quot;: &quot;classification&quot;,
  &quot;instruction&quot;: &quot;魚の種類はどっち？イコクエイラクブカとロープ&quot;
}</code></pre></div>



<p>一部だけ取り出すコードは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>    dataset = datasets.load_dataset(&quot;kunishou/databricks-dolly-15k-ja&quot;)
    train_dataset = dataset[&quot;train&quot;].filter(lambda data: data[&quot;input&quot;] == &quot;&quot;)</code></pre></div>



<p>次に<code>SFTConfig</code> というSFTの実行の設定のクラスのインスタンスを用意します。具体的には以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>    sft_args = SFTConfig(
        output_dir=&quot;./outputs&quot;,
        evaluation_strategy=&quot;no&quot;,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=0.1,
        lr_scheduler_type=&quot;cosine&quot;,
        warmup_ratio=0.3,
        logging_steps=10,
        save_strategy=&quot;epoch&quot;,
        report_to=&quot;tensorboard&quot;,
        bf16=True,
        max_seq_length=1024,
        gradient_checkpointing=True,
        deepspeed=&#39;./deepspeed_config.json&#39;,
    )</code></pre></div>



<p>重要なこととして、今回はGPUのメモリが少ないため、DeepSpeedのStage 3という学習時に一部のデータをCPU側に置いておくモードを利用します。</p>



<p>これによりGPUメモリが少ない環境でもSFTを回すことができます。</p>



<p>DeepSpeed周りの設定は<code>deepspeed_config.json</code> に書いてありますので気になる方はご覧ください。</p>



<p>また、今回は学習データの10%だけを利用するようにしています。これはこの学習を早く終わらせるためであり、本来はもっと回す必要があると考えられますので、本気でSFTをする場合は注意してください。</p>



<p>次にデータをどのようなフォーマットでLLMに入力するかを指定する<code>formatting_func</code> という関数を用意します。今回は以下のようにしました。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>INSTRUCTION_TEMPLATE = string.Template(
    &quot;&quot;&quot;### Question:
${input}

### Answer:
${response}&lt;|plamo:eos|&gt;
&quot;&quot;&quot;
)


def formatting_func(examples):
    output_texts = []
    for i in range(len(examples[&#39;instruction&#39;])):
        text = INSTRUCTION_TEMPLATE.substitute(input=examples[&#39;instruction&#39;][i], response=examples[&#39;output&#39;][i])
        output_texts.append(text)
    return output_texts</code></pre></div>



<p><code>INSTRUCTION_TEMPLATE</code> が今回のフォーマットで、<code>### Question:\n</code> の後に指示、<code>### Answer:\n</code> のあとに回答が続き、最後にend of sequenceである<code>&lt;|plamo:eos|&gt;</code> が来るようになっています。</p>



<p>また、学習時には回答部分だけを学習してほしいので、どこからが回答かがわかるように<code>‎DataCollatorForCompletionOnlyLM</code> のインスタンスも用意します。これは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(&quot; Answer:\n&quot;, add_special_tokens=False),
        tokenizer=tokenizer
    )
</code></pre></div>



<p><code>response_template</code> のところで回答前の部分がどのようなtoken idになるかを指定する部分があるので、上記のように指定します。前後の文字の影響で指定したtoken idが出現しないケースがあるので、その時はいろいろ<code>response_template</code> に指定する文字列を調整してみてください。</p>



<p>最後にSFTを実行するためのクラスの<code>‎SFTTrainer</code> を以下のように用意します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        formatting_func=formatting_func,
    )</code></pre></div>



<p>そして、以下のように実行し、結果を保存します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-lang="Python"><code>    trainer.train()
    trainer.save_model()</code></pre></div>



<p>これで学習が終わると<code>SFTConfig</code> の<code>output_dir</code> で指定した<code>./outputs</code> に結果が出力されます。試しに私がGoogle Colabで実行した際は13分程度で学習が終わりました。コンピューティングユニットとしてはパッケージなどのインストールも含めて4だけ消費しました。</p>



<h3 class="wp-block-heading">SFTされたモデルで推論してみる</h3>



<p>最後にSFTされたモデルで推論するというのを行います。</p>



<p>これはPLaMo 2 1Bのexampleとほぼ同じでpromptだけ少し変えたものを例として用います。コードとしては以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


model_name = &quot;./plamo-2-1b-sft-example/outputs&quot;

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)


# プロンプトの準備
prompt = &quot;### Question:\n埼玉の県庁所在地は何市？\n\n### Answer:\n&quot;

# 推論の実行
inputs = tokenizer(prompt, return_tensors=&quot;pt&quot;)
generated_tokens = model.generate(
    **inputs,
    max_new_tokens=64,
    pad_token_id=tokenizer.pad_token_id,
)[0]
generated_text = tokenizer.decode(generated_tokens)
print(generated_text)</code></pre></div>



<p>出力結果は以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>&lt;|plamo:bos|&gt;### Question:
埼玉の県庁所在地は何市？

### Answer:
埼玉県の県庁所在地はさいたま市です。&lt;|plamo:eos|&gt;</code></pre></div>



<p>ちゃんと学習で指定されたように<code>### Answer:\n</code> の後に質問に対する回答をし、その後<code>&lt;|plamo:eos|&gt;</code> を出力するということができています。</p>



<p>ちなみにSFTしていないモデルではどうなるかというと、以下のように余計なことを出力するうえ、出力が止まらないという状態になっています。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-plain"><code>&lt;|plamo:bos|&gt;### Question:
埼玉の県庁所在地は何市？

### Answer:
さいたま市

### 解説
「県庁所在地」とは、都道府県庁が置かれている都市のことです。
「さいたま市」は埼玉県の県庁所在地です。

### 関連記事
### 取り急ぎお知らせ
「埼玉の県庁所在地は何市？」の解説は以上です。
「埼玉の県庁所在地は何市？」の解説は以上です。</code></pre></div>



<p>このため、SFTでうまくフォーマットに従うよう学習できたと考えられます。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回はPLaMo 2 1Bを使ってSFTをする例を示しました。今回示したように簡単なSFTなら十分Google Colabで実行することができます。みなさんもぜひいろいろ試していただければと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2025/02/13/plamo-2-1b-sft/">小型LLM PLaMo 2 1BをGoogle ColabでSFTしてみる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2025/02/13/plamo-2-1b-sft/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">3210</post-id>	</item>
		<item>
		<title>[書評] 機械学習エンジニアのためのTransformers ー 自然言語のTransformerについてより知りたい人向けな一冊</title>
		<link>https://www.mattari-benkyo-note.com/2023/05/08/transformers_book_review/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/05/08/transformers_book_review/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sun, 07 May 2023 22:13:57 +0000</pubDate>
				<category><![CDATA[書評]]></category>
		<category><![CDATA[llm]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[transformers]]></category>
		<category><![CDATA[書籍]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=2805</guid>

					<description><![CDATA[<p>今回は毎週月曜日恒例の書評回です。今回は「機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発」を読んだところなので、この本についての記事になります。 どんな内容の本か？  [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/05/08/transformers_book_review/">[書評] 機械学習エンジニアのためのTransformers ー 自然言語のTransformerについてより知りたい人向けな一冊</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>今回は毎週月曜日恒例の書評回です。今回は「<a href="https://amzn.to/3LFWDuS" target="_blank" rel="noopener" title="機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発">機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発</a>」を読んだところなので、この本についての記事になります。</p>



<div style="text-align: center;">
<a href="https://www.amazon.co.jp/%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%82%A8%E3%83%B3%E3%82%B8%E3%83%8B%E3%82%A2%E3%81%AE%E3%81%9F%E3%82%81%E3%81%AETransformers-%E2%80%95%E6%9C%80%E5%85%88%E7%AB%AF%E3%81%AE%E8%87%AA%E7%84%B6%E8%A8%80%E8%AA%9E%E5%87%A6%E7%90%86%E3%83%A9%E3%82%A4%E3%83%96%E3%83%A9%E3%83%AA%E3%81%AB%E3%82%88%E3%82%8B%E3%83%A2%E3%83%87%E3%83%AB%E9%96%8B%E7%99%BA-Lewis-Tunstall/dp/4873119952?&#038;linkCode=li3&#038;tag=shu65-22&#038;linkId=4274a2b40be985f83809c4757dac41f5&#038;language=ja_JP&#038;ref_=as_li_ss_il" target="_blank" rel="noopener"><img decoding="async" border="0" src="//ws-fe.amazon-adsystem.com/widgets/q?_encoding=UTF8&#038;ASIN=4873119952&#038;Format=_SL250_&#038;ID=AsinImage&#038;MarketPlace=JP&#038;ServiceVersion=20070822&#038;WS=1&#038;tag=shu65-22&#038;language=ja_JP" ></a><img decoding="async" src="https://ir-jp.amazon-adsystem.com/e/ir?t=shu65-22&#038;language=ja_JP&#038;l=li3&#038;o=9&#038;a=4873119952" width="1" height="1" border="0" alt="" style="border:none !important; margin:0px !important;" />
</div>



<h2 class="wp-block-heading">どんな内容の本か？</h2>



<p>この本を一言でまとめると「Transformersを使った推論、学習など幅広くまとめた本」という感じかと思います。<br>「Transformer&#8221;s&#8221;」って何?という方向けに説明すると、ChatGPTなどで使われているTrasnfomerというモデルを扱いやすくしたPythonライブラリです。おそらく、この記事を執筆している現在、自然言語系のタスク向けにTransfomerのモデルを使って学習したり、推論したりしようと思ったら多分使うことになるライブラリかと思います。</p>



<p>このTransformersについて開発しているHugging Faceの人たちが自ら解説した本がこの本になります。扱っているテーマは幅広く、Transformerの仕組みから、Transformersを使ったテキスト分類などいくつかの応用タスクを実際に実行する方法、Transformersの高速化、学習などが書かれています。Transformersについて知りたいと思ったら、このを本をまず読んでみると全体を俯瞰できてよいかと思います。</p>



<h2 class="wp-block-heading">どんな人にお勧めか？</h2>



<p>この本は以下のような人に向いている本かなと思っています。</p>



<ol class="wp-block-list">
<li>Transformerの自然言語応用について幅広く勉強したい人</li>



<li>Transformersを使ったコードについていろいろ知りたい人</li>
</ol>



<p>特にTransformerの自然言語応用について知りたい方はちょうどよい本かと思います。一方、Transfomerの言語以外の応用、例えば画像なんかについては簡単な紹介はありますが、詳しくは書かれていません。このため、自然言語以外について知りたい人には向かない本だと思います。</p>



<h2 class="wp-block-heading">個人的に良かった点</h2>



<p>個人的には以下の点が良かったです。</p>



<ol class="wp-block-list">
<li>Transformersを使ったpretrainingについてちゃんと書いてある</li>



<li>備考的なことについてもいろいろ言及があり、しかも参考文献がしっかりついているので、詳しく知りたい場合は論文にあたりやすい</li>
</ol>



<p>Transformers + 自然言語については最近話題なこともあり、何冊か本が出ています。私自身、数冊読んだのですが、どれも応用よりなことが多く、pretrainingなどまで書いてない、もしくは書いてあったとしてもちょっとしかないみたいな本が多い印象です。この点、この本はpretrainingのやり方までちゃんと具体例を示しながら説明してあって良かったです。</p>



<p>また、単純にTransformersの使い方の説明にとどまらず、例えばデータセットの課題やTokenizerごとの違いについても簡単な言及がちゃんと書かれています。また、これらにちゃんとどの論文に書かれているのか示されているので、より詳しく知りたい場合は論文を読んで勉強するということができるようになっています。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回はTransfomrersについて書かれた「<a href="https://amzn.to/3LFWDuS" target="_blank" rel="noopener" title="機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発">機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発</a>」について紹介する記事を書きました。</p>



<p>今後もこのように読んだ本の紹介を毎週月曜日に投稿しようと思いますので、興味がある方は見に来てみてください。</p>



<p>この記事が皆様の役に立てば幸いです。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2023/05/08/transformers_book_review/">[書評] 機械学習エンジニアのためのTransformers ー 自然言語のTransformerについてより知りたい人向けな一冊</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/05/08/transformers_book_review/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">2805</post-id>	</item>
		<item>
		<title>[勉強ノート] 「拡散モデル　データ生成技術の数理」 3.1-3.5のVE-SDE部分について</title>
		<link>https://www.mattari-benkyo-note.com/2023/04/13/diffusion_model_book_3_ve_sde/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/04/13/diffusion_model_book_3_ve_sde/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Wed, 12 Apr 2023 21:59:03 +0000</pubDate>
				<category><![CDATA[未分類]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[拡散モデル]]></category>
		<category><![CDATA[書籍]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=2391</guid>

					<description><![CDATA[<p>先日紹介した「拡散モデル　データ生成技術の数理」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装をしたりしたものをまとめた記事の第4弾です。今回は3章の分散発散型確率微分方程式 (VE-SD [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/04/13/diffusion_model_book_3_ve_sde/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 3.1-3.5のVE-SDE部分について</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>先日紹介した「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noreferrer noopener">拡散モデル　データ生成技術の数理</a>」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装をしたりしたものをまとめた記事の第4弾です。今回は3章の分散発散型確率微分方程式 (VE-SDE)の部分のコードを書いたのでVE-SDEの式の簡単な説明とコードの解説記事になります。</p>



<p>今回の記事はスコアベースモデル (SBM)はすでに理解している前提で説明していきます。もしスコアベースモデルがよくわからないという方はこちらに簡単な解説を書いたので参考にしてください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="hky94vpgHi"><a href="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/embed/#?secret=2YF8VdBkdb#?secret=hky94vpgHi" data-secret="hky94vpgHi" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>今回のコードは以下のところにあげてありますので、コード全体を見たい方はこちらをご覧ください。</p>



<p><a href="https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_3_VE_SDE.ipynb">https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_3_VE_SDE.ipynb</a></p>



<h2 class="wp-block-heading">分散発散型確率微分方程式(VE-SDE)とは？</h2>



<p>この本の3章の前半部分で、2章で紹介しているスコアベースモデル (SBM)とデノイジング拡散確率モデル (DDPM) を確率微分方程式 (SDE) とみなすことができるという説明をしています。このうち、SBMのほうをSDE表現してでてくるものが分散発散型確率微分方程式 (VE-SDE)です。</p>



<h2 class="wp-block-heading">確率微分方程式(SDE)</h2>



<p>確率微分方程式（SDE; Stochastic Differential Equations）は次の式で与えられます。</p>



<p>$$ \begin{align} \text{d}\boldsymbol{x} = \boldsymbol{f}(\boldsymbol{x}, t)\text{d}t + \boldsymbol{G}(\boldsymbol{x}, t)\text{d}\boldsymbol{w} \tag{3.1} \end{align} $$</p>



<p>この式において\(\text{d}\boldsymbol{x} \)は\(\boldsymbol{x}\)の変化量です。この変化量は決定的に変化量である\(\boldsymbol{f}(\boldsymbol{x}, t)\text{d}t\)とランダムに変化する量である\(\boldsymbol{G}(\boldsymbol{x}, t)\text{d}\boldsymbol{w}\)の和で構成されています。</p>



<p>ここで、\(\boldsymbol{w}\)は標準ウィーナー過程またはブラウン運動ともよばれ、\(\text{d}\boldsymbol{w}\)は微小時間間隔\(\tau\)において平均が0、分散が\(\tau\)の正規分布とみなすことができます。</p>



<p>この確率微分方程式において\(\boldsymbol{f}(\cdot, t)\)はドリフト係数、\(\boldsymbol{G}(\cdot, t)\)は拡散係数と呼びます。</p>



<p>ただし、一般に拡散モデルで扱う確率微分方程式以下のようにドリフト係数が時間のみに依存する関数\(\boldsymbol{f}(t)\)と入力\(\boldsymbol{x}\)の積、拡散係数は時間のみに依存してスカラ値を出力する\(g(t)\)を使った確率微分方程式が利用されます。</p>



<p>$$ \begin{align} \text{d}\boldsymbol{x} = f(t)\boldsymbol{x}\text{d}t +g(t)\text{d}\boldsymbol{w} \tag{3.2} \end{align} $$</p>



<h2 class="wp-block-heading">スコアベースモデルの拡散過程をSDEで表現する</h2>



<p>スコアベースモデル(SBM)の拡散過程は以下のようになっていました。</p>



<p>$$ \begin{align} q(\boldsymbol{x}_i | \boldsymbol{x}) = \mathcal{N}(\boldsymbol{x}, \sigma_i^2\boldsymbol{I}) \tag{3.3} \end{align} $$</p>



<p>ここで\(i = 0,&#8230;, N\)です。この場合の拡散過程の1ステップは次のようになります。</p>



<p>$$ \begin{align} q(\boldsymbol{x}_i | \boldsymbol{x}_{i-1}) = \mathcal{N}(\boldsymbol{x}_i;\boldsymbol{x}_{i-1}, (\sigma_i^2 &#8211; \sigma_{i-1}^2)\boldsymbol{I}) \tag{3.4}<br>\end{align} $$</p>



<p>式(3.3), (3.4)は2章のほうで説明されています。この拡散過程の1ステップは変数変換を使うと以下のようになります。</p>



<p>$$ \begin{align} <br>\boldsymbol{x}_i &amp;=  \boldsymbol{x}_{i-1} + \sqrt{\sigma_i^2 &#8211; \sigma_{i-1}^2}\boldsymbol{z}_{i-1} \tag{3.5} \\<br>\boldsymbol{z}_{i-1} &amp;\sim  \mathcal{N}(0, \boldsymbol{I})  \tag{3.6} <br>\end{align} $$</p>



<p>ここで簡略化のために\(\sigma_0 = 0\) として考えます。</p>



<p>ここから\(N \rightarrow \infty\) とした極限を考えていきます。この時、\(i\)の代わりに\(t\)を用いて、\({\boldsymbol{x}_i}_{i=1}^N\)を連続的な確率過程\({\boldsymbol{x}_t}_{t=0}^1\)とし、\(\sigma_i\)を関数\(\sigma(t)\)、\(\boldsymbol{z}_{i}\)は\(\boldsymbol{z}(t)\)とします。</p>



<p>また、\(\Delta t=1/N\)とし、\(t \in \left\{0, \frac{1}{N},&#8230;, \frac{N-1}{N} \right\}\)とします。</p>



<p>この時式(3.5)の式は以下のようになります。</p>



<p>$$ \begin{align} \boldsymbol{x}(t + \Delta t) = \boldsymbol{x}(t) + \sqrt{\sigma(t + \Delta t)^2 &#8211; \sigma(t)^2}\boldsymbol{z}_{i-1} \tag{3.7} \end{align} $$</p>



<p>ここで\(\sigma(t + \Delta t)^2 &#8211; \sigma(t)^2\)の部分で1次近似を利用して式変形します。1次近似は以下の近似を指します。</p>



<p>$$ \begin{align} f(x + \Delta x) \approx \frac{\text{d}f(x)}{\text{d}x} \Delta x + f(x) \tag{3.8} \end{align} $$</p>



<p>この1次近似の式において\(f(x)\)の部分を\(\sigma(t)^2\)として置き換えると以下のようになります。</p>



<p>$$ \begin{align} \sigma(t + \Delta t)^2 \approx \frac{\text{d}[\sigma(t)^2]}{\text{d}t} \Delta t + \sigma(t)^2 \tag{3.9} \end{align} $$</p>



<p>この式の両辺を\(\sigma(t)^2\)で引くと以下のようになります。</p>



<p>$$ \begin{align} \sigma(t + \Delta t)^2 &#8211; \sigma(t)^2 \approx \frac{\text{d}[\sigma(t)^2]}{\text{d}t} \Delta t \tag{3.10} \end{align} $$</p>



<p>この式(3.10)を式(3.7)に代入すると以下のようになります。</p>



<p>$$ \begin{align} \boldsymbol{x}(t + \Delta t) = \boldsymbol{x}(t) + \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t} \Delta t}\boldsymbol{z}(t) \tag{3.11} \end{align} $$</p>



<p>このあとの説明のために以下のように少し式変形をします。</p>



<p>$$ \begin{align}<br>\boldsymbol{x}(t + \Delta t) &amp;= \boldsymbol{x}(t) + \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t} \Delta t}\boldsymbol{z}(t) \\<br>\boldsymbol{x}(t + \Delta t) &#8211; \boldsymbol{x}(t) &amp;= \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t} \Delta t}\boldsymbol{z}(t) \\<br>\boldsymbol{x}(t + \Delta t) &#8211; \boldsymbol{x}(t) &amp;= \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t}} \left(\sqrt{\Delta t}\boldsymbol{z}(t) \right) \tag{3.12}<br>\end{align} $$</p>



<p>さて、ここから\(\Delta t \to 0\) にしたときのことを考えます。式(3.12)の左辺のほうは以下のようになります。</p>



<p>$$ \begin{align}<br>\lim_{\Delta t \to 0} \boldsymbol{x}(t + \Delta t) &#8211; \boldsymbol{x}(t) = \text{d}\boldsymbol{x} \tag{3.13}<br>\end{align} $$</p>



<p>問題は右辺の\(\sqrt{\Delta t}\boldsymbol{z}(t) \)の部分です。これは結果的には以下のようになります。</p>



<p>$$ \begin{align}<br>\lim_{\Delta t \to 0} \sqrt{\Delta t}\boldsymbol{z}(t) = \text{d}\boldsymbol{w} \tag{3.14}<br>\end{align} $$</p>



<p>この部分ですが本も元論文[3] のほうにもこの式変形のところで言及がないのでわかりにくいので少し説明します。</p>



<p>まず、そもそも\(\text{d}\boldsymbol{w}\)は何であったか？ですが、これは最初に説明した通り標準ウィーナー過程またはブラウン運動ともよばれ、\(\text{d}\boldsymbol{w}\)は微小時間間隔\(\tau\)において平均が0、分散が\(\tau\)の正規分布とみなすことができます。このことから以下のように表すことができます。</p>



<p>$$ \begin{align}<br>\text{d}\boldsymbol{w} \sim \mathcal{N}(0,  \tau \boldsymbol{I}) \tag{3.15}<br>\end{align} $$</p>



<p>ここで\(\boldsymbol{z}(t)\)は </p>



<p>$$ \begin{align}<br>\boldsymbol{z}(t) \sim \mathcal{N}(0, \boldsymbol{I}) \tag{3.16} \\<br>\end{align} $$</p>



<p>なので、\(\text{d}\boldsymbol{w}\)は以下のようになります。</p>



<p>$$ \begin{align}<br>\text{d}\boldsymbol{w} = \sqrt{\tau} \boldsymbol{z}(t) \tag{3.17}<br>\end{align} $$</p>



<p>\(\tau\)が微小時間間隔なので式(3.14)と式(3.17)を見比べるとなんとなく式(3.14)が成り立ちそうだなぁと思います。ただ、極限を素直に考えると以下のようになるのでは？とずっと思ってました。</p>



<p>$$ \begin{align}<br>\lim_{\Delta t \to 0} \sqrt{\Delta t}\boldsymbol{z}(t) = 0<br>\end{align} $$</p>



<p>この部分、私は気になってしょうがなかったので、少し調べました。結論からいうとこの部分の式変形に関してはウィーナー過程の条件から導出できそうだということがわかりました。詳しくは以下のサイトが分かりやすかったので、詳しく知りたい方はご覧ください。</p>



<p><a href="http://takashiyoshino.random-walk.org/memo/keikaku_ensyu/node4.html" target="_blank" rel="noreferrer noopener">http://takashiyoshino.random-walk.org/memo/keikaku_ensyu/node4.html</a></p>



<p>ここでは簡単に説明します。まずウィーナー過程 \(\boldsymbol{w}(t)\)を考えます。ウィナー過程の条件より以下が成り立ちます。</p>



<p>$$ \begin{align}<br>\boldsymbol{w}(t + \Delta t) &#8211; \boldsymbol{w}(t) \sim \mathcal{N}(0, \Delta t \boldsymbol{I}) \tag{3.18}<br>\end{align} $$</p>



<p>ここで式(3.18)を右辺を見ると平均０、分散\(\Delta t\)の正規分布です。このため、式(3.18)は左辺は以下のように表すこともできます。</p>



<p>$$ \begin{align}<br>\boldsymbol{w}(t + \Delta t) &#8211; \boldsymbol{w}(t) = \sqrt{\Delta t}\boldsymbol{z}(t) \tag{3.19} \\<br>\end{align} $$</p>



<p>この式(3.19)の右辺は式(3.14)の左辺の\(\lim_{\Delta t \to 0}\)の中と同じになります。また式(3.19)の左辺は\(\Delta t \to 0\)のとき以下のようになります。</p>



<p>$$ \begin{align}<br>\lim_{\Delta t \to 0} \left( \boldsymbol{w}(t + \Delta t) &#8211; \boldsymbol{w}(t) \right) &amp;= \text{d}\boldsymbol{w} \tag{3.20}<br>\end{align} $$</p>



<p>よって式(3.14)は式(3.19)と(3.20)を使うと以下のようになります。</p>



<p>$$ \begin{align}<br>\lim_{\Delta t \to 0} \sqrt{\Delta t}\boldsymbol{z}(t) &amp;= \lim_{\Delta t \to 0} \left( \boldsymbol{w}(t + \Delta t) &#8211; \boldsymbol{w}(t) \right) \\<br>&amp;= \text{d}\boldsymbol{w} \tag{3.21}<br>\end{align} $$</p>



<p>この式変形なら個人的には納得できました。よって最終的に式(3.12)で\(\Delta t \to 0\) を考えると式(3.13)と式(3.21)より以下のようになります。</p>



<p>$$ \begin{align}<br>\text{d}\boldsymbol{x} &amp;= \lim_{\Delta t \to 0} \boldsymbol{x}(t + \Delta t) &#8211; \boldsymbol{x}(t) \\<br>&amp;= \lim_{\Delta t \to 0} \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t}} \left(\sqrt{\Delta t}\boldsymbol{z}(t) \right) \\<br>&amp;= \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t}} \boldsymbol{w}(t) \tag{3.22}<br>\end{align} $$</p>



<p>この式(3.22)を見るとドリフト係数\(f(t)\) と拡散係数\(g(t)\)が以下のようなSDEであることが分かります。</p>



<p>$$ \begin{align*}<br>f(t) &amp;= 0 \tag{3.23} \\<br>g(t) &amp;= \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t}} \tag{3.24} \\<br>\end{align*} $$</p>



<p>これでSBMをSDEで表現することができました。このSBMの式から導出したSDEを分散発散型確率微分方程式 (VE-SDE)と呼びます。</p>



<h2 class="wp-block-heading">VE-SDEの学習</h2>



<p>VE-SDEの各時刻\(t\)のスコアを学習するあために、次の条件付き確率（拡散カーネル）を知る必要があります。</p>



<p>$$ \begin{align*}<br>p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) \tag{3.25}<br>\end{align*} $$</p>



<p>ここで\(p_{0t}\)は\(\boldsymbol{x}(0)\)を条件付けしたときの\(\boldsymbol{x}(t)\)の確率を表しています。</p>



<p>ここでSDEが以下の形として考えていきます。</p>



<p>$$ \begin{align} \text{d}\boldsymbol{x} = f(t)\boldsymbol{x}\text{d}t + g(t)\text{d}\boldsymbol{w} \end{align} \tag{3.26}$$</p>



<p>この場合、式(3.26)の条件付き確率は以下のような正規分布で表すことができます[1, 2]。</p>



<p>$$ \begin{align} <br>p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) =&amp; \mathcal{N}(s(t)\boldsymbol{x}(0), s(t)^2\sigma^{\prime}(t)^2\boldsymbol{I}) \tag{3.27} \\<br>s(t) =&amp; \text{exp}\left(\int_0^tf(\xi)\text{d}\xi\right) \tag{3.28} \\<br>\sigma^{\prime}(t) =&amp; \sqrt{\int_0^t \frac{g(\xi)^2}{s(\xi)^2}\text{d}\xi} \tag{3.29} \\<br> \end{align} $$</p>



<p>本のほうでは式(3.27)と式(3.29) の\(\sigma^{\prime}(t)\)の部分は\(\sigma(t)\)という表記になっています。ただ、VE-SDEのほうにも\(\sigma(t)\)があって区別ができないので、この記事では式(3.27)と(3.29)に登場する\(\sigma(t)\)を\(\sigma^{\prime}(t)\)として説明していきます。</p>



<p>VE-SDEの場合はこの式を使うと簡単に\(p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0))\)の形がわかるので、以下に示していきます。</p>



<p>まず、\(s(t)\)の部分ですが、VE-SDEの場合、式(3.23)から以下のようになります。</p>



<p>$$ \begin{align} <br>s(t) &amp;= \text{exp}\left(\int_0^tf(\xi)\text{d}\xi\right) \\<br>&amp;= \text{exp}\left(\int_0^t 0 \text{d}\xi\right) \\<br>&amp;= \text{exp}\left(0 \right) \\<br>&amp;= 1 \tag{3.30} \\<br> \end{align} $$</p>



<p>次に\(\sigma^{\prime}(t)\)に関してです。まず式(3.26)を使って式変形します。</p>



<p>$$ \begin{align} <br>\sigma^{\prime}(t) &amp;= \sqrt{\int_0^t \frac{g(\xi)^2}{s(\xi)^2}\text{d}\xi} \\<br>&amp;= \sqrt{\int_0^t \frac{g(\xi)^2}{1^2}\text{d}\xi} \\<br>&amp;= \sqrt{\int_0^t g(\xi)^2\text{d}\xi}  \tag{3.31} <br> \end{align} $$</p>



<p>ここでVE-SDEの\(g(t)\)は式(3.24)で分かっているのでこれを利用してさらに式変形します。</p>



<p>$$ \begin{align} <br>\sigma^{\prime}(t) &amp;= \sqrt{\int_0^t g(\xi)^2\text{d}\xi} \\<br>&amp;= \sqrt{\int_0^t \left( \sqrt{\frac{\text{d}[\sigma(\xi)^2]}{\text{d}\xi}} \right)^2\text{d}\xi } \\<br>&amp;= \sqrt{\int_0^t \frac{\text{d}[\sigma(\xi)^2]}{\text{d}\xi} \text{d}\xi } \\<br>&amp;= \sqrt{\sigma(t)^2 &#8211; \sigma(0)^2} \tag{3.32} <br> \end{align} $$</p>



<p>式変形した式(3.30)、(3.32)を式(3.27)に代入すると最終的には以下のようになります。</p>



<p>$$ \begin{align} <br>p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) &amp;= \mathcal{N}(s(t)\boldsymbol{x}(0), s(t)^2\sigma^{\prime}(t)^2\boldsymbol{I}) \\<br>&amp;= \mathcal{N}(\boldsymbol{x}(0), \left[\sigma(t)^2 &#8211; \sigma(0)^2\right]\boldsymbol{I}) \tag{3.33} <br>\end{align} $$</p>



<p>これによりVE-SDEの拡散過程の条件付き確率の式がわかりました。</p>



<p>本の説明では\(\sigma(t)\)が具体的にどのような式を使うのかまでは示してないため、式変形はここまでになっています。</p>



<p>一方、このブログではコードに落とすところまでをやるため、ここからさらに式変形していきます。ここから元論文の[3]を参考にして式変形していきます。</p>



<p>[3]の論文で使われている\(\sigma(t)\)と同じものを用いて説明していきます。[3]では以下のものが使われています。</p>



<p>$$ \begin{align} <br>\sigma(t) &amp;= \sigma_{min}\left( \frac{\sigma_{max}}{\sigma_{min}} \right)^t, &amp; \ t &amp;\in (0, 1]  \\<br>\sigma(0) &amp;= 0, &amp; \ t &amp;= 0 \\<br>\tag{3.34}<br>\end{align} $$</p>



<p>ここで\(\sigma_{min}\)と\(\sigma_{max}\)はハイパーパラメータです。</p>



<p>これを使って式(3.24)の\(g(t)\)と式(3.33)の条件付き確率の式変形をしていきます。</p>



<p>まず、式(3.24)の\(g(t)\)に関してです。</p>



<p>$$ \begin{align*}<br>g(t) &amp;= \sqrt{\frac{\text{d}[\sigma(t)^2]}{\text{d}t}} \\<br>&amp;= \sqrt{\frac{\text{d}}{\text{d}t} \left( \sigma_{min}\left( \frac{\sigma_{max}}{\sigma_{min}} \right)^t  \right)^2} \\<br>&amp;= \sqrt{\frac{\text{d}}{\text{d}t} \sigma_{min}^2\left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} } \\<br>&amp;= \sqrt{\sigma_{min}^2 \frac{\text{d}}{\text{d}t} \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} } \tag{3.35}<br>\end{align*} $$</p>



<p>ここで\( \frac{\text{d}}{\text{d}t} \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} \)の部分に注目します。以下のような指数関数の微分公式を利用します。</p>



<p>$$ \begin{align*}<br>\frac{\text{d}}{\text{d}x} a^x = a^x \log a \tag{3.36}<br>\end{align*} $$</p>



<p>(参考：<a href="https://manabitimes.jp/math/1112">https://manabitimes.jp/math/1112</a>)</p>



<p>この公式を利用すると以下のようになります。</p>



<p>$$ \begin{align*}<br>\frac{\text{d}}{\text{d}t} \left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} &amp;= \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} \log \left(\frac{\sigma_{max}}{\sigma_{min}} \right)^2<br>\tag{3.37}<br>\end{align*} $$</p>



<p>この式(3.37)を式(3.35)に代入して式変形していくと以下のようになります。</p>



<p>$$ \begin{align*}<br>g(t) &amp;= \sqrt{\sigma_{min}^2 \frac{\text{d}}{\text{d}t} \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} } \\<br>&amp;= \sqrt{\sigma_{min}^2 \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} \log \left(\frac{\sigma_{max}}{\sigma_{min}} \right)^2 } \\<br>&amp;= \sigma_{min} \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{t} \sqrt{\log \left(\frac{\sigma_{max}}{\sigma_{min}} \right)^2 } \\<br>&amp;= \sigma_{min} \left( \frac{\sigma_{max}}{\sigma_{min}} \right)^{t} \sqrt{2 \log \left(\frac{\sigma_{max}}{\sigma_{min}} \right)} \tag{3.38}<br>\end{align*} $$</p>



<p>次に式(3.33)の条件付き確率のほうを式変形していきます。この式には分散のほうにだけ\(\sigma(t)\)が登場するので、この部分だけ注目します。この分散に式(3.34)の\(\sigma(t)\)を代入して式変形していくと以下のようになります。</p>



<p>$$ \begin{align} <br>\sigma(t)^2 &#8211; \sigma(0)^2 &amp;= \left[\sigma_{min}\left( \frac{\sigma_{max}}{\sigma_{min}} \right)^t \right]^2 &#8211; 0^2 \\<br>&amp;= \sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t} \\ \tag{3.39}<br>\end{align} $$</p>



<p>よって式(3.33)の条件付き確率は以下のようになります。</p>



<p>$$ \begin{align} <br>p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) &amp;= \mathcal{N}(\boldsymbol{x}(0), \left[\sigma(t)^2 &#8211; \sigma(0)^2\right]\boldsymbol{I}) \\<br>&amp;= \mathcal{N}\left(\boldsymbol{x}(0), \sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t}\boldsymbol{I}\right) \tag{3.40} <br>\end{align} $$</p>



<p>これらを用いてデノイジングスコアマッチングをロス関数としてスコア関数\(s_{\theta}\)を学習します。VE-SDEの場合のデノイジングスコアマッチングの関数はSBMのときと同じ形になります。具体的には以下のようになります。（変数はVE-SDEに合わせています。）</p>



<p>$$ \begin{align} <br>L(\theta) :=&amp; <br>E_t \left[ \lambda(t) E_{\boldsymbol{x}(0) \sim p_{data}(\boldsymbol{x}),\boldsymbol{x}(t) \sim p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0))} \left\{ \right. \right. \\ <br>&amp; \quad \left. \left. \left| \nabla_{\boldsymbol{x}(t)} \log p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) &#8211; s_{\theta}(\boldsymbol{x}(t), t) \right|^2 \right\} \right] \tag{3.41} <br>\end{align} $$</p>



<p>ここで、\(\lambda(t)\)は各\(t\)における重みづけです。</p>



<p>これを実装するために、SBMのときと同じようにスコア \( \nabla_{\boldsymbol{x}(t)} \log p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) \)の部分を式変形します。これはSBMのときと同じなので本の２章と以前私が書いたSBMの解説の記事をご覧ください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="hky94vpgHi"><a href="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</a></blockquote><iframe class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/embed/#?secret=2YF8VdBkdb#?secret=hky94vpgHi" data-secret="hky94vpgHi" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>結果として以下のようになります。</p>



<p>$$ \begin{align} <br>\nabla_{\boldsymbol{x}(t)} \log p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) &amp;=  \frac{-\epsilon}{\sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t}} \tag{3.42} \\<br>\epsilon &amp;\sim \mathcal{N}\left(0, \sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t}\boldsymbol{I}\right) \tag{3.43}<br>\end{align} $$</p>



<p>式(3.41)を式(3.42)、(3.43)を使って変形すると以下のようになります。（式が長すぎるので\(\boldsymbol{x}(0), \boldsymbol{x}(t), \epsilon\)の分布を省略してます。）</p>



<p>$$ \begin{align} <br>L(\theta) :=&amp; <br>E_t \left[ \lambda(t) E_{\boldsymbol{x}(0),\boldsymbol{x}(t)} \left\{ \left| \nabla_{\boldsymbol{x}(t)} \log p_{0t}(\boldsymbol{x}(t)|\boldsymbol{x}(0)) &#8211; s_{\theta}(\boldsymbol{x}(t), t) \right|^2 \right\} \right] \\<br>=&amp; E_t \left[ \lambda(t) E_{\boldsymbol{x}(0),\epsilon} \left\{ \left| \frac{-\epsilon}{\sigma_{min}^2\left(\frac{\sigma_{max}}{\sigma_{min}} \right)^{2t}} &#8211; s_{\theta}(\boldsymbol{x}(t), t) \right|^2 \right\} \right] \tag{3.44} <br>\end{align} $$</p>



<p>これをPyTorchを使ってコードにすると以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="loss_ve_sed.py" data-lang="Python"><code>def sigma(t, sigma_min=sigma_min, sigma_max=sigma_max):
  return sigma_min * (sigma_max / sigma_min) ** t 

def ve_sde_marginal_prob_statistics(x, t, sigma_min, sigma_max):
  mean = x
  std = sigma(t=t, sigma_min=sigma_min, sigma_max=sigma_max)
  return mean, std

def ve_sde_drift(t, sigma_min, sigma_max):
  drift = torch.zeros_like(t)
  return drift

def ve_sde_diffusion(t, sigma_min, sigma_max):
  std = sigma(t=t, sigma_min=sigma_min, sigma_max=sigma_max)
  diffusion = std * torch.sqrt(2 * (torch.log(sigma_max) - torch.log(sigma_min))) # (30)
  return diffusion

def dsm_loss(score_model, samples, sigma_min, sigma_max):
  eps = 1.0e-8
  t = torch.distributions.uniform.Uniform(torch.tensor([eps], device=samples.device), torch.tensor([1], device=samples.device)).sample([samples.shape[0]]) 
  z = torch.randn_like(samples)
  mean, std = ve_sde_marginal_prob_statistics(x=samples, t=t, sigma_min=sigma_min, sigma_max=sigma_max)
  noise = z * std
  perturbed_samples = mean + z * std
  scores = score_model(perturbed_samples, t)
  target = - 1 / (std ** 2) * noise
  
  target = target.view(target.shape[0], -1)
  scores = scores.view(scores.shape[0], -1)
  g = ve_sde_diffusion(t=t, sigma_min=sigma_min, sigma_max=sigma_max)
  lmd = g ** 2
  loss = torch.sqrt(((scores - target) ** 2).sum(dim=-1)) * lmd
  return loss.mean()</code></pre></div>



<p>ここで本によると\(\lambda(t)=g(t)^2\)のときにスコアマッチングの目的関数は負の対数尤度の上限となっていることが証明できるそうです。このため、上記のコードでは\(\lambda(t)=g(t)^2\)を利用しています。</p>



<h2 class="wp-block-heading">VE-SDEのサンプリング</h2>



<p>VE-SDEのサンプリングをするためには拡散過程を逆にたどる逆算過程を知る必要があります。</p>



<p>拡散過程のSDEは式(3.1)で与えらえるとするとこの逆算過程は以下のようになります。</p>



<p>$$ \begin{align} \text{d}\boldsymbol{x} =&amp; \left\{f(\boldsymbol{x}, t) &#8211; \nabla \left[ \boldsymbol{G}(\boldsymbol{x}, t) \boldsymbol{G}(\boldsymbol{x}, t)^\text{T} \right] \right. \\<br>&amp; \quad \left. &#8211; \left[ \boldsymbol{G}(\boldsymbol{x}, t) \boldsymbol{G}(\boldsymbol{x}, t)^\text{T} \right] \nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})\right\} \text{d}t \\<br>&amp; \quad+ \boldsymbol{G}(\boldsymbol{x}, t)\text{d}\bar{\boldsymbol{w}} \tag{3.45} \end{align} $$</p>



<p>ただし、\(\text{d}\bar{\boldsymbol{w}}\)は時刻Tから0まで客向きに辿ったときの標準ウィーナー過程です。</p>



<p>ただし、一般的に拡散もモデルで使われる確率微分方程式は式(3.2)の形だそうです。このため式(3.2)で使われている\(f(t), g(t)\)で式(3.45)を書き直すと以下のようになります。</p>



<p>$$ \begin{align} \text{d}\boldsymbol{x} =&amp; \left[f(t) &#8211; g(t)^2\nabla \log p_t(\boldsymbol{x})\right] \text{d}t + g(t)\text{d}\bar{\boldsymbol{w}} \tag{3.46} \end{align} $$</p>



<p>式(3.45)と式(3.46)の式変形の説明も本当はやろうと思ったのですが、かなり長い式変形になるのと、本の付録のほうに詳しい説明があるのでこの記事では省略します。</p>



<p>この式(3.46)に基づいて拡散モデルのサンプリングをする方法としてオイラー・丸山先生によるサンプリングが本で紹介されています。疑似コードは以下の通りです。(「拡散モデル　データ生成技術の数理」Algorithm 3.1の引用)</p>



<ol class="wp-block-list">
<li>\(\boldsymbol{x} \sim \mathcal{N}(0, \boldsymbol{I})\))</li>



<li>for \(i=T,&#8230;,1\) do</li>



<li>\(\quad \boldsymbol{z}_i \sim \mathcal{N}(0, \boldsymbol{I})\)</li>



<li>\(\quad \boldsymbol{x} :=  \boldsymbol{x}  &#8211; \left[f(t_i) &#8211; g(t_i)^2 s_{\theta}(\boldsymbol{x}, t_i)\right] \Delta t_i + g(t)\sqrt{|\Delta t_i|} \boldsymbol{z}_i \)</li>



<li>end for</li>



<li>return \(\boldsymbol{x}\)</li>
</ol>



<p>これをPyTorchで実装すると以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="ve_sde_sampling" data-lang="Python"><code>def euler_maruyama_sample(n_samples, score_model, device=device, n=1000):
  with torch.no_grad():
    x = torch.randn(n_samples, 2, device=device)
    dt = torch.tensor(1.0 / n, device=x.device)
    for t in range(n, 0, -1):
      t_tensor = torch.full((n_samples, 1), t/n, device=device)
      z = torch.randn(n_samples, 2)
      f = ve_sde_drift(t_tensor, score_model.sigma_min, score_model.sigma_max)
      g = ve_sde_diffusion(t_tensor, score_model.sigma_min, score_model.sigma_max)
      g2 = g ** 2
      score = score_model(x, t_tensor)
      x = x - (f*x - g2 * score) * dt + g * torch.sqrt(dt) * z
    return x</code></pre></div>



<h2 class="wp-block-heading">コードの実行例</h2>



<p>ここでは先ほど紹介したロス関数とサンプリング関数を利用して実際にVE-SDEでスコア関数のパラメータを学習し、サンプリングした例を示します。</p>



<p>参考例として入力となる\(\boldsymbol{x}\)のサンプリングする分布の確率密度関数は以下のように平均が違うガウス分布二つの混合分布とし、サンプリングしたデータを正規化して使用します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="generate_dataset" data-lang="Python"><code>n_samples = int(1e6)
sigma = 0.01

dist0 = torch.distributions.MultivariateNormal(torch.tensor([-2, -2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples0 = dist0.sample(torch.Size([n_samples//2]))
    
dist1 = torch.distributions.MultivariateNormal(torch.tensor([2, 2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples1 = dist1.sample(torch.Size([n_samples//2]))
samples = torch.vstack((samples0, samples1))

mean = torch.mean(samples, dim=0)
std = torch.std(samples, dim=0)

normalized_samples = (samples - mean[None, :])/std[None, :]</code></pre></div>



<p>使用する\(\boldsymbol{x}\)を2Dのヒストグラムで可視化すると以下のようになります。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img loading="lazy" decoding="async" width="266" height="264" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density.png" alt="sample density " class="wp-image-2091" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">使用するデータの可視化結果</figcaption></figure></div>


<p>次にスコア関数のモデルと学習コードです。基本的には先ほど紹介したロス関数を使ってモデルを学習形になります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="model_train_loop" data-lang="Python"><code>import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreModel(nn.Module):
  def __init__(self, sigma_min, sigma_max, n_channels=2):
    super(ScoreModel, self).__init__()
    self.sigma_min = sigma_min
    self.sigma_max = sigma_max
    self.model = nn.Sequential(
        nn.Linear(n_channels, 2*n_channels),
        nn.ELU(),
        nn.Linear(2*n_channels, 16*n_channels),
        nn.ELU(),
        nn.Linear(16*n_channels, 2*n_channels),
        nn.ELU(),
        nn.Linear(2*n_channels, n_channels),
    )

  def forward(self, x, t):
    y = self.model(x)
    sigma_t = sigma(t=t, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
    return y/sigma_t

batch_size = 512
n_steps = 100000

dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0)
dataloader_iter = iter(dataloader)

score_model = ScoreModel(sigma_min=sigma_min, sigma_max=sigma_max).to(device)

optimizer = torch.optim.Adam(score_model.parameters())
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=n_steps)

for i in range(n_steps):
  try:
    x = next(dataloader_iter)[0]
  except StopIteration:
    dataloader_iter = iter(dataloader)
    x = next(dataloader_iter)[0]
  x = x.to(device)

  optimizer.zero_grad()
  loss = dsm_loss(score_model, x, sigma_min=sigma_min, sigma_max=sigma_max)
  loss.backward()
  optimizer.step()
  lr_scheduler.step()
  if (i % 1000) == 0:
    print(f&quot;{i} steps loss:{loss}&quot;)</code></pre></div>



<p>学習が終わったら最後に以下のようにサンプリングする関数を呼び出してサンプリングします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="sampling" data-lang="Python"><code>samples_pred = euler_maruyama_sample(n_samples=100000, score_model=score_model)</code></pre></div>



<p>サンプリングされたデータの2Dのヒストグラムは以下の通りです。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img loading="lazy" decoding="async" width="266" height="264" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/04/predicted_sample.png" alt="サンプリングデータの可視化結果" class="wp-image-2660" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/04/predicted_sample.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/04/predicted_sample-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">サンプリングデータの可視化結果</figcaption></figure></div>


<p>ほぼ元の分布と同じサンプリングが得られることが確認できました。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回は「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noreferrer noopener">拡散モデル　データ生成技術の数理</a>」の中で紹介されている分散発散型確率微分方程式 (VE-SDE)の部分を紹介しました。コードは先月の中旬にはできていたのですが、今回紹介する部分の式変形でぱっと見てわからないところがいくつかあり、それを調べていたらだいぶ時間がかかりました。また、説明のために必要な式の打ち込みにもかなり時間がかかってしまいました。</p>



<p>ただ、頑張ったおかげでかなりVE-SDEの部分の理解が進んだので記事にまとめてよかったです。</p>



<p>今後に関してはVP-SDEに関してもやろうと思っていますが、先に最近流行りのChatGPT, LLM, LangChainあたりに関していろいろ調べてみようと思うのでそちらの記事をいくつか書いてからになると思います。</p>



<p>この記事が他の方の役に少しでもなれば幸いです。</p>



<h2 class="wp-block-heading">参考文献</h2>



<ol class="wp-block-list">
<li><a href="https://amzn.to/405EQn0" target="_blank" rel="noopener" title="確率微分方程式　入門から応用まで">確率微分方程式　入門から応用まで</a></li>



<li>Särkkä, S., &amp; Solin, A. (2019).&nbsp;Applied Stochastic Differential Equations&nbsp;(Institute of Mathematical Statistics Textbooks). Cambridge: Cambridge University Press. doi:10.1017/9781108186735</li>



<li>Song, Y., Sohl-Dickstein, J.N., Kingma, D.P., Kumar, A., Ermon, S., &amp; Poole, B. (2020). Score-Based Generative Modeling through Stochastic Differential Equations.&nbsp;ArXiv, abs/2011.13456.</li>
</ol><p>The post <a href="https://www.mattari-benkyo-note.com/2023/04/13/diffusion_model_book_3_ve_sde/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 3.1-3.5のVE-SDE部分について</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/04/13/diffusion_model_book_3_ve_sde/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">2391</post-id>	</item>
		<item>
		<title>PyTorch 2.0の新機能「torch.compile」使ってみた</title>
		<link>https://www.mattari-benkyo-note.com/2023/03/18/torch-compile/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/03/18/torch-compile/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Fri, 17 Mar 2023 22:20:45 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[GPU]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=2261</guid>

					<description><![CDATA[<p>今回は3/16についに出たPyTorch 2.0の目玉機能である「torch.comple」について実際に動かしてみて計算時間を測定してみたので、そのまとめになります。 時間計測の部分で測定に使ったコードはここにあげてあ [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/18/torch-compile/">PyTorch 2.0の新機能「torch.compile」使ってみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>今回は3/16についに出たPyTorch 2.0の目玉機能である「torch.comple」について実際に動かしてみて計算時間を測定してみたので、そのまとめになります。</p>



<p>時間計測の部分で測定に使ったコードはここにあげてあります。</p>



<p><a href="https://github.com/shu65/pytorch_2_compile_example/blob/main/torch_2_0_compile.ipynb">https://github.com/shu65/pytorch_2_compile_example/blob/main/torch_2_0_compile.ipynb</a></p>



<h2 class="wp-block-heading">torch.compileとは？</h2>



<p>torch.compileはPyTorch 2.0の新機能で、PyTorchの複数の機能を組み合わせて使い関数や深層学習のモデルを実行時に最適化して、その後の呼び出して高速に実行できるようにする機能です。</p>



<p>torch.compileの中身の詳しい説明はここにかかれています。</p>



<p><a href="https://pytorch.org/get-started/pytorch-2.0/#technology-overview">https://pytorch.org/get-started/pytorch-2.0/#technology-overview</a></p>



<p>簡単に説明するとtorch.compileの中身としては以下の３つで構成されています。</p>



<ol class="wp-block-list">
<li>Graph acquisition: 計算グラフの構築</li>



<li>Graph lowering: PyTorchのオペレーションをバックエンドのデバイス（CPUやGPU）に特化した細かい命令に分解</li>



<li>Graph compilation: バックエンドのデバイス特化の命令を呼び出し</li>
</ol>



<p>これらのステップを経ることで、より効率よく計算リソースを使えるようにし、高速化を実現しています。</p>



<p>また、この機能のすばらしいところは使い方も非常に簡単であるというものがあります。以下にデコレータで使う方法とtorch.compileの関数を呼び出して使う方法を示します。</p>



<h3 class="wp-block-heading">デコレータで使うやり方</h3>



<p>まずデコレータで使う方法です。これは以下のようになります (このチュートリアルの例：<a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage" target="_blank" rel="noopener" title="">https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage</a>)</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="torch_compile_decorate " data-lang="Python"><code>@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b
opt_foo2(torch.randn(10, 10), torch.randn(10, 10))</code></pre></div>



<p>torch.jit.scriptを使ったことがある方は、それと同じ感覚で使えるというと使い方がイメージしやすいかもしれません。</p>



<h3 class="wp-block-heading">torch.compileの関数を呼び出して使うやり方</h3>



<p>torch.compileの関数を呼び出してコンパイルする場合は以下のようにやります。(このチュートリアルの例：<a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage" target="_blank" rel="noopener" title="">https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#basic-usage</a>)</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="torch_compile" data-lang="Python"><code>class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
opt_mod(torch.randn(10, 100))</code></pre></div>



<p>こちらもtorch.jit.scriptのときと同じような使い方だと思います。</p>



<h2 class="wp-block-heading">torch.compileによるパフォーマンスの評価</h2>



<p>次にtorch.compileを実際に使ってみたときの計算時間を計測したので、その紹介です。今回は以下の二つのGPUで測定しました。</p>



<ol class="wp-block-list">
<li>T4</li>



<li>V100</li>
</ol>



<p>T4はTuringなので公式のドキュメントでtorch.compileのサポートが書かれてないものになっています。ただ、やってみたら少し早くなったので、測定結果を載せています。GitHubにあげたコードはT4で測定したほうです。</p>



<p>また、CUDAのバージョンはどちらのケースも12.0利用し、測定に使ったモデルはチュートリアルにあったtorchvisionのResNet18を使用しました。</p>



<p>また、torch.compileにはモードが以下の３つあります。</p>



<ol class="wp-block-list">
<li>デフォルト</li>



<li>reduce-overhead</li>



<li>max-autotune</li>
</ol>



<p>これらと何もしてない場合も含めて合計４つパターンの測定をしています。</p>



<p>具体的な測定方法が分かりやすいようにコードの一部を紹介します（torch.compleのデフォルトの場合）。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="performance_measure" data-lang="Python"><code>import time 

import torch
import torchvision.models as models
import torch._dynamo

batch_size = 64
n_warmup_iters = 10
n_iters = 500

x = torch.randn(batch_size, 3, 224, 224).cuda()

def get_mode():
    return models.resnet18()

torch._dynamo.reset()

model = get_mode().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# compile
compiled_model = torch.compile(model)
for i in range(n_warmup_iters):
    optimizer.zero_grad()
    torch.cuda.synchronize()
    start = time.time()
    out = compiled_model(x)
    torch.cuda.synchronize()
    forward_elapsed_time = time.time() - start
    torch.cuda.synchronize()
    start = time.time()
    out.sum().backward()
    backward_elapsed_time = time.time() - start
    print(f&quot;with compile {i} iter forward: {forward_elapsed_time/1000:.3e} msec., backward: {backward_elapsed_time/1000:.3e} msec.&quot;)
    optimizer.step()

print(&quot;-&quot;*10)

torch.cuda.synchronize()
start = time.time()
for i in range(n_iters):
    optimizer.zero_grad()
    out = compiled_model(x)
    out.sum().backward()
    optimizer.step()
torch.cuda.synchronize()
elapsed_time = time.time() - start

print(f&quot;with compile total:{elapsed_time:.3e} sec. {batch_size*n_iters/elapsed_time:.3e} imgs/sec.&quot;)</code></pre></div>



<p>最初に、モデルの入力とモデルを作ったあと、コンパイルする場合は<code>torch.compile(model)</code>でコンパイルします。このときコンパイルのモードを変える場合は引数の<code>mode</code>にモードの名前を渡します。</p>



<p>その後、最初の数回はforward、backwardの呼び出し時にコンパイルなどのオーバーヘッドが入って遅いので、あらかじめ何度か呼びます。そして最後に実際に時間を計測します。今回は10回あらかじめforwardとbackwardを呼んでおいて、その後500回イテレーションを回したときの時間を測定しています。バッチサイズに関しては変化させると高速化率が変化することはわかっていますが、今回固定で64で実行しています。</p>



<p>T4, V100ともに同様の方法でtorch.compileのありなし等を測定しています。</p>



<p>では、時間計測の結果です。500回イテレーションを回したときの実際の計算時間を順番に示していきます。まずはT4の場合です。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td>計算時間 (sec.)</td><td>torch.compileなしからの高速化率</td></tr><tr><td>torch.compileなし</td><td>78.68</td><td>1.00</td></tr><tr><td>torch.compile (default)</td><td>73.37</td><td>1.07</td></tr><tr><td>torch.compile (reduce-overhead)</td><td>77.52</td><td>1.01</td></tr><tr><td>torch.compile (max-autotune)</td><td>73.35</td><td>1.07</td></tr></tbody></table><figcaption class="wp-element-caption">T4を使ったResNet18の結果</figcaption></figure>



<p>T4はtorch.compileのサポートが書かれてない世代のGPUなので、効果が全くでないのかと思ったのですが、そんなことはなかったです。ただ、10％は満たない高速化にとどまっているという印象です。ちなみにT4を使ったケースではtorch.compileのmodeをmax-autotuneに変えると以下のようにサポートされてないGPUであると警告がでてきます。</p>



<pre class="wp-block-preformatted">[2023-03-17 18:31:06,314] torch._inductor.utils: [WARNING] not enough cuda cores to use max_autotune mode</pre>



<p>次にV100のResNet18の結果です。</p>



<figure class="wp-block-table"><table><tbody><tr><td></td><td>計算時間 (sec.)</td><td>torch.compileなしからの高速化率</td></tr><tr><td>torch.compileなし</td><td>26.6</td><td>1.00</td></tr><tr><td>torch.compile (default)</td><td>24.7</td><td>1.08</td></tr><tr><td>torch.compile (reduce-overhead)</td><td>24.2</td><td>1.10</td></tr><tr><td>torch.compile (max-autotune)</td><td>24.1</td><td>1.10</td></tr></tbody></table><figcaption class="wp-element-caption">V100を使ったResNet18の結果</figcaption></figure>



<p>V100のほうはtorch.compileのサポートされていると書かれているGPUです。実際、V100はtorch.compileのmodeをmax-autotuneに変えると確かにより速くなり、高速化率も最大値は10%台に入っています。</p>



<h2 class="wp-block-heading">現状のtorch.compileの注意点</h2>



<p>最後にtorch.compileの注意したほうがよさそうな点を書いておきます。</p>



<p>まず、公式で書かれいたものの紹介です。基本的な注意点はこのドキュメントに書いてあります。</p>



<p><a href="https://pytorch.org/get-started/pytorch-2.0/#pytorch-2x-faster-more-pythonic-and-as-dynamic-as-ever">https://pytorch.org/get-started/pytorch-2.0/#pytorch-2x-faster-more-pythonic-and-as-dynamic-as-ever</a></p>



<p>重要なものとして、現在提供されているtorch.compileの機能を最大限活かせるのはCPU、NVIDIAのVoltaとAmpere世代のGPUのみになっています。他のGPUでは使おうとすると警告が出てきます。ただ、私が試した範囲では警告がでるだけで現状では使えないわけではなさそうです。</p>



<p>また、私が使ったときに感じた注意点としては</p>



<ol class="wp-block-list">
<li>おそらくforwardとbackwardで別々にコンパイルが走るので、forward、backwardの両方とも最初は遅い</li>



<li>実行が遅いのは最初の１回目だけでなく、最初の数回の呼び出しが遅いケースがある</li>



<li>Google ColabなどでCellの実行を一度止めて再度実行しようとするとエラーがでて、ランタイムの再起動をしないと復帰できないケースがある</li>
</ol>



<p>1と２は時間計測をしようとしたときにはまったポイントです。まず、１に関してです。torch.compileの直後の呼び出しはコンパイルが走るので、遅いというのはドキュメントにも書かれています。ただ、forwadだけがおそいのかな？と思ってました。ただ、torch.compileの説明をちゃんと読めば想像できると思いますが、backwardも最初の実行のときは遅いです。なので、時間を計測するときは、forwardとbackwardの両方が遅いことを考慮して測定する必要があります。</p>



<p>次に２です。これに関しては私が見逃してなければドキュメントに明示的に説明が書いてあるわけではないのですが、<a href="https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html#torch-compile-tutorial" target="_blank" rel="noopener" title="チュートリアル">チュートリアル</a>の時間計測の結果や実際に測定してみるとどうやら遅いのは最初の１回目の呼び出しだけではなく、そのあと数回遅いケースが存在しているようです。このため、計算時間の測定の際、最初に数回呼び出してから測定しないとtorch.compileを使ったときよりも遅いみたいな誤った結果になるので注意してください。</p>



<p>最後に３です。これは何度かはまったのですが、どこかにキャッシュか何か残っているのか変なところで止めるとコード的には問題ないはずなのに、エラーがでるようになるときがあります。調べても解決方法が分からなかったので、エラーがでるようになったらランタイムごと再起動するということを何度かやりました。Google Colabでやるときは注意してください。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回はtorch.compileについて使ってみたのでまとめを書きました。去年発表があったときから楽しみにしていましたが、期待通りのものとなっていました。なにより使い方が非常に簡単なことには驚きました。</p>



<p>今回はT4とV100の測定結果でしたが、A100だとどうなるのかも今度測定しようかなと思っています。</p>



<p>この記事がみなさんのお役に立てば幸いです。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/18/torch-compile/">PyTorch 2.0の新機能「torch.compile」使ってみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/03/18/torch-compile/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">2261</post-id>	</item>
		<item>
		<title>[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.3 デノイジング拡散確率モデル</title>
		<link>https://www.mattari-benkyo-note.com/2023/03/16/diffusion_model_book_2_3/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/03/16/diffusion_model_book_2_3/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Wed, 15 Mar 2023 22:09:48 +0000</pubDate>
				<category><![CDATA[未分類]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[拡散モデル]]></category>
		<category><![CDATA[書籍]]></category>
		<category><![CDATA[機械学習]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=2116</guid>

					<description><![CDATA[<p>先日紹介した「拡散モデル　データ生成技術の数理」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装したりしています。 その第3弾として「2.3 デノイジング拡散確率モデル」で説明されているデノ [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/16/diffusion_model_book_2_3/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.3 デノイジング拡散確率モデル</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>先日紹介した「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noopener" title="拡散モデル　データ生成技術の数理">拡散モデル　データ生成技術の数理</a>」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装したりしています。</p>



<p>その第3弾として「2.3 デノイジング拡散確率モデル」で説明されているデノイジング拡散確率モデル（DDPM）の学習とそれを使ったサンプリングについてPython(深層学習部分はPytorch)でコードを書いて試したのでそのまとめになります。今回の記事ではDDPMの細かい数式を説明すると記事の量がすごいことになりそうなので、重要な部分だけ説明していきます。</p>



<p>また、この本を買うか迷っている方は私が読んだ感想をこちらの記事に書いてますので参考にしてみてください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="a1lYO13NFR"><a href="https://www.mattari-benkyo-note.com/2023/02/23/diffusion_model_book_review/">[書評] 拡散モデル データ生成技術の数理 ー 目覚ましい画像生成の発展の裏側を知りたい人へ</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[書評] 拡散モデル データ生成技術の数理 ー 目覚ましい画像生成の発展の裏側を知りたい人へ&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/02/23/diffusion_model_book_review/embed/#?secret=anhJj8X2Xf#?secret=a1lYO13NFR" data-secret="a1lYO13NFR" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>また、この記事で紹介したコードは以下にあげてありますので、コード全体を確認したい方はこちらをご覧ください。</p>



<p><a href="https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_2_3_ddpm.ipynb" target="_blank" rel="noopener" title="">https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_2_3_ddpm.ipynb</a></p>



<h2 class="wp-block-heading">デノイジング拡散確率モデルとは</h2>



<p>デノイジング拡散確率モデル（DDPM）はデータに対して徐々にノイズを加えていく拡散過程を逆向きに辿っていく逆拡散過程によってデータ生成を行います。図でみると分かりやすいと思うので拡散過程と逆拡散過程の関係図を以下に示します。</p>


<div class="wp-block-image">
<figure class="aligncenter size-large is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1-1024x627.png" alt="デノイジング拡散確率モデルの流れ" class="wp-image-2215" width="768" height="470" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1-1024x627.png 1024w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1-300x184.png 300w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1-768x470.png 768w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1-1536x940.png 1536w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/ddpm_flow-1.png 1992w" sizes="auto, (max-width: 768px) 100vw, 768px" /><figcaption class="wp-element-caption">デノイジング拡散確率モデル</figcaption></figure></div>


<p>まずは上の図の上段の拡散過程についてです。ここでは\(\boldsymbol{x}_0\)がオリジナルのデータで、これに徐々にノイズを加えていき、\(\boldsymbol{x}_1, \boldsymbol{x}_2, \boldsymbol{x}_3, &#8230;, \boldsymbol{x}_T\)といいうデータを作っていきます。これを確率密度関数で表すと以下のようになります。</p>



<p>$$ \begin{align*}<br>q(\boldsymbol{x}_{1:T}|\boldsymbol{x}_0) :=&amp; \prod_{t=1}^T q(\boldsymbol{x}_t|\boldsymbol{x}_{t-1}) \tag{2.1.1} \\<br>q(\boldsymbol{x}_{t}|\boldsymbol{x}_{t-1}) :=&amp; \mathcal{N}(\boldsymbol{x}_t; \sqrt{\alpha_t}\boldsymbol{x}_{t-1}, \beta_t \boldsymbol{I}) \tag{2.1.2} \\<br>\end{align*} $$</p>



<p>ここで\(\beta_t\) は分散の大きさを制御するパラメータで、\(0&lt;\beta_1&lt;\beta_2&lt;&#8230;&lt;\beta_T&lt;1\)です。また、\(\alpha_t := 1 &#8211; \beta_t\) で、\(\alpha_t, \beta_t\)を合わせてノイズスケジュールと呼びます。</p>



<p>ここで、\(\mathcal{N}(\boldsymbol{x}_t; \sqrt{\alpha_t}\boldsymbol{x}_{t-1}, \beta_t \boldsymbol{I})\) について詳しく見ていきます。この拡散過程を繰り返していくと、\(\beta_t\) は徐々に大きくなります。結果として、ノイズ成分は大きくなっていきます。</p>



<p>一方、\(\beta_t\)が大きくなるということは\(\alpha_t := 1 &#8211; \beta_t\) なので\(\boldsymbol{x}_{t-1}\)の係数の\(\sqrt{\alpha_t}\)はどんどん小さくなっていきます。結果として、拡散過程を繰り返していくと任意の\(\boldsymbol{x}_{0}\)に対して、以下のような近似ができるようになります。</p>



<p>$$ \begin{align*}<br>q(\boldsymbol{x}_{T}|\boldsymbol{x}_{0}) :=&amp; \mathcal{N}(\boldsymbol{x}_T; 0,  \boldsymbol{I}) \tag{2.1.3} \\<br>\end{align*} $$</p>



<p>次に、図の下段の逆拡散過程です。これは\(\mathcal{N}(\boldsymbol{x}_T; 0,  \boldsymbol{I}) \)からスタートして拡散過程を逆向きに辿っていく処理になります。</p>



<p>この逆拡散過程の各ステップを正規分布で表し、この正規分布の平均と共分散行列を一つ前のステップの変数\(\boldsymbol{x}_{t}\)と時刻\(t\)を入力としてパラメータ\(\theta\)を使ったモデル(\(\mu_{\theta}(\boldsymbol{x}_{t}, t), \Sigma(\boldsymbol{x}_{t}, t))\))として表します。これを使うと逆拡散過程は以下のような式で表すことができます。</p>



<p>$$ \begin{align*}<br>p_{\theta}(\boldsymbol{x}_{0:T}) :=&amp; p(\boldsymbol{x}_T)\prod_{t=1}^T p_{\theta}(\boldsymbol{x}_t-1|\boldsymbol{x}_{t}) \tag{2.1.4} \\<br>p_{\theta}(\boldsymbol{x}_{t-1}|\boldsymbol{x}_{t}) :=&amp; \mathcal{N}(\boldsymbol{x}_{t-1}; \mu_{\theta}(\boldsymbol{x}_{t}, t), \Sigma_{\theta}(\boldsymbol{x}_{t}, t)) \tag{2.1.5} \\<br>p(\boldsymbol{x}_{T}) =&amp; \mathcal{N}(\boldsymbol{x}_T; 0,  \boldsymbol{I}) \tag{2.1.6}<br>\end{align*} $$</p>



<p>この\(\mu_{\theta}(\boldsymbol{x}_{t}, t), \Sigma_{\theta}(\boldsymbol{x}_{t}, t)\)は後ほど示す通りニューラルネットワークなどを使ってモデル化します。これによって逆拡散過程を実現しています。</p>



<p>では次から実際にこのモデルのパラメータ\(\theta\)の学習方法とそれを使ったサンプリングを見ていきます。</p>



<h2 class="wp-block-heading">デノイジング拡散確率モデルの学習</h2>



<p>デノイジング拡散確率モデルのモデルの学習方法の説明を本来はしたいのですが、この説明はすごく長いものになります。このため、詳しい説明は本を見ていただくとして、ここでは学習を回すうえで重要な変数と式に関する簡単な説明にとどめておきます。</p>



<p>まず、先ほど示した式(2.1.2)などを利用すると以下の式が導けます。</p>



<p>$$ \begin{align*}<br>q(\boldsymbol{x}_{t}|\boldsymbol{x}_{0}) :=&amp; \mathcal{N}(\boldsymbol{x}_t; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_{0}, \bar{\beta}_t \boldsymbol{I}) \tag{2.1.7} \\<br>\bar{\alpha}_t := \prod_{s=1}^t \alpha_s \tag{2.1.8} \\<br>\bar{\beta}_t := 1 &#8211; \bar{\alpha}_t \tag{2.1.9} \\<br>\end{align*} $$</p>



<p>式(2.1.7)から式(2.1.9)の導出の証明に関しては本の式(2.1)の下に証明がありますので詳しく知りたい方はそちらをご覧ください。</p>



<p>これにより、わざわざ式(2.1.2)に従って、\(\boldsymbol{x}_{0}\)から徐々に\(t\)を大きくして\(\boldsymbol{x}_{t}\)を生成しなくても、式(2.1.7)を使うことで正規分布のサンプリングを１度すれば任意の\(t\)の\(\boldsymbol{x}_{t}\)のデータを生成できることになります。学習ではこれを利用して高速にランダムな\(t\)のデータを生成して学習に利用します。</p>



<p>次にデノイジング拡散確率モデルの学習で用いるロス関数についてです。これに関しては前の記事でも紹介したスコアベースモデルと同じようにデノイジングスコアマッチングを利用します。また、ロス関数の導出は先ほど説明した通り、ちゃんと説明しようとするとすごく長いので本で示されている最終的な結果を以下に示します。</p>



<p>$$ \begin{align*}<br>L_{\gamma}(\theta) =&amp; \sum_{t=1}^T w_t E_{\boldsymbol{x}_0, \epsilon} \left\{ \left\| \epsilon &#8211; \epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{\bar{\beta}_t}\epsilon, t) \right\|^2 \right\} \tag{2.1.10} \\<br>\gamma =&amp; \left\{ w_1, w_2, &#8230;, w_T \right\} \tag{2.1.11} <br>\end{align*} $$</p>



<p>この式の\(w_t\)に関しては本によると\(w_t = 1\) がよくつかわれるとのことなので、このあとの実装のコードでも\(w_t = 1\)としています。</p>



<p>ここで、式の導出を省略して関係が分かりにくくなっていため、逆拡散過程の説明で出てきた式(2.1.5)の中の\(\mu_{\theta}(\boldsymbol{x}_{t}, t), \Sigma_{\theta}(\boldsymbol{x}_{t}, t)\)と式(2.1.10)の中で出てきている\(\epsilon_{\theta}(\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{\bar{\beta}_t}\epsilon, t)\) との関係を説明しておきます。</p>



<p>逆拡散過程を行う上で\(\mu_{\theta}(\boldsymbol{x}_{t}, t), \Sigma_{\theta}(\boldsymbol{x}_{t}, t)\)のパラメータ\(\theta\)を学習する必要があります。ここで、本の説明によると\(\Sigma_{\theta}(\boldsymbol{x}_{t}, t)\)に関してはパラメータ\(\theta\)に依存しない固定の\(\Sigma_{\theta}(\boldsymbol{x}_{t}, t)) = \sigma_t^2 \boldsymbol{I}\)を使うことが多いそうです。先ほど示したロス関数も\(\Sigma_{\theta}(\boldsymbol{x}_{t}, t)) = \sigma_t^2 \boldsymbol{I}\)として式変形しています。</p>



<p>次に\(\mu_{\theta}(\boldsymbol{x}_{t}, t)\)の部分です。こちらはロス関数の導出の過程で結局は\(t\)の時点で加えられたノイズを予測できるモデルに置き換えることができます。このため、\(\mu_{\theta}(\boldsymbol{x}_{t}, t)\)ではなく、ノイズを予測する\(\epsilon_{\theta}(\boldsymbol{x}_{t}, t)\)がロス関数の中で登場しています。</p>



<p>また、\(\boldsymbol{x}_{t}\) は式(2.1.7)から以下のようになります。</p>



<p>$$ \begin{align*}<br>\boldsymbol{x}_{t} =&amp; \sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{\bar{\beta}_t}\epsilon \\<br>\epsilon \sim&amp; \mathcal{N}(0, \boldsymbol{I}) \\<br>\tag{2.1.12}<br>\end{align*} $$</p>



<p>本のほうには丁寧にこのロス関数の導出が書かれているので詳細を知りたい方はぜひ本を読んでください。</p>



<p>この式(2.1.10)をロス関数として利用したデノイジング拡散確率モデルの学習の疑似コードは以下の通りです。(「拡散モデル　データ生成技術の数理」Algorithm 2.2の引用)</p>



<ol class="wp-block-list">
<li>repeat</li>



<li>\(\quad \boldsymbol{x}_0 \sim p_{\text{data}}(\boldsymbol{x}_0)\)</li>



<li>\(\quad t \sim \text{Uniform}({1, &#8230;, T})\)</li>



<li>\(\quad \epsilon \sim \mathcal{N}(0, \boldsymbol{I})\)</li>



<li>\(\quad g := \nabla_{\theta} w_t \left\| \epsilon &#8211; \epsilon_{\theta}(\sqrt{\bar{\alpha_t}} \boldsymbol{x}_0 + \sqrt{\bar{\beta_{t}}} \epsilon, t)\right\|^2 \)</li>



<li>\(\quad \theta := \theta &#8211; \alpha g \)</li>



<li>until converged</li>
</ol>



<p>ここで本では特に説明されてないですが、式(2.1.10)の最初の\(\sum_{t=1}^T\)の部分は少し変更して、\(t=1\)から\(t=T\)ランダムに\(t\)を選んで使用するように変更されています。</p>



<p>このアルゴリズムでは、まず最初にデータ\(\boldsymbol{x}_0 \)と\(t\)を選び、それらを用いて式(2.1.10)のロス関数を計算します。その後、勾配を計算してパラメータのアップデートをするということを繰り返します。6行目の勾配を使ったパラメータのアップデートは深層学習の基本的なパラメータ更新の確率的勾配降下法を利用したコードになっています。この部分は確率的勾配降下法以外のもの、例えばAdamなどでも問題ありません。</p>



<p>この疑似コードを基にPyTorchで実装するとこのようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="train_ddpm.py" data-lang="Python"><code>batch_size = 512
n_steps = 100000

dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0)
dataloader_iter = iter(dataloader)

model = Model().to(device)

optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=n_steps)
loss_func = torch.nn.MSELoss(reduction=&quot;none&quot;)

for i in range(n_steps):
  try:
    x0 = next(dataloader_iter)[0]
  except StopIteration:
    dataloader_iter = iter(dataloader)
    x0 = next(dataloader_iter)[0]
  x0 = x0.to(device)

  optimizer.zero_grad()

  noise = torch.randn_like(x0)
  t = torch.randint(0, len(x0), (x0.shape[0],), device=device)
  x = torch.sqrt(alpha_bars[t])[:, None] * x0 + torch.sqrt(beta_bars[t])[:, None] * noise
  noise_pred = model(x, t)
  w = 1.0
  losses = w * loss_func(noise_pred, noise)
  loss = losses.mean()

  loss.backward()
  optimizer.step()
  lr_scheduler.step()</code></pre></div>



<p>基本的には疑似コードとほぼ同じですが、ロス関数周りで違うところがあるので、簡単に説明します。まず、深層学習なのでミニバッチを使った学習に置き換えています。このため、\(\boldsymbol{x}_0 \)も一つだけでなく、バッチサイズ分ランダムに選んでいます。これに伴って、\(t\)もバッチサイズ分ランダムに選んで利用します。そして、最終的には各データ点のロス関数の値の平均を計算して勾配を計算するという形に置き換えています。</p>



<p>また、パラメータの最適化の部分は確率勾配降下法ではなくAdamを利用しています。</p>



<h2 class="wp-block-heading">デノイジング拡散確率モデルを使ったサンプリング</h2>



<p>ここから先ほど紹介した方法で学習したモデルを利用してどのようにサンプリングしていくか、について説明します。</p>



<p>基本的には式(2.1.5)に従って逆拡散過程のステップを繰り返すことで実現します。</p>



<p>ここで、説明を省略してしまいましたが、式(2.1.5)の中にでてくる\(\mu_{\theta}(\boldsymbol{x}_{t}, t)\)を学習したノイズを予測するモデル\(\epsilon_{\theta}(\boldsymbol{x}_{t}, t)\)を使って表すと以下のようになります。</p>



<p>$$ \begin{align*}<br>\mu_{\theta}(\boldsymbol{x}_{t}, t) =&amp; \frac{1}{\sqrt{\bar{\alpha}}} \left( \boldsymbol{x}_{t} &#8211; \frac{\beta_t}{\sqrt{\bar{\beta_t}}} \epsilon_{\theta}(\boldsymbol{x}_{t}, t) \right) \tag{2.1.13}<br>\end{align*} $$</p>



<p>また、一方、\(\Sigma_{\theta}(\boldsymbol{x}_{t}, t))\)は先ほど説明した通り、\(\Sigma_{\theta}(\boldsymbol{x}_{t}, t)) = \sigma_t^2 \boldsymbol{I}\)です。</p>



<p>これらと式(2.1.5)に従ったサンプリングの疑似コードは以下の通りです。(「拡散モデル　データ生成技術の数理」Algorithm 2.3の引用)</p>



<ol class="wp-block-list">
<li>\(\boldsymbol{x}_T \sim \mathcal{N}(0, \boldsymbol{I})\)</li>



<li>for \(t=T, &#8230;, 1\) do</li>



<li>\(\quad \boldsymbol{u}_t \sim \mathcal{N}(0, \boldsymbol{I})\)</li>



<li>\(\quad\) if \(t=1\) then \(\boldsymbol{u}_t := 0\)</li>



<li>\(\quad \boldsymbol{x}_{t-1} := \frac{1}{\sqrt{\bar{\alpha}}} \left\{ \boldsymbol{x}_{t} &#8211; \frac{\beta_t}{\sqrt{\bar{\beta_t}}} \epsilon_{\theta}(\boldsymbol{x}_{t}, t) \right\} + \sigma_t  \boldsymbol{u}_t \)</li>



<li>end for</li>



<li>return \(\boldsymbol{x}_0\)</li>
</ol>



<p>基本的には徐々にノイズを取り除くことで目的のデータをサンプリングするという流れです。</p>



<p>PyTorchのコードとしては以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="ddpm_sample" data-lang="Python"><code>def ddpm_sample(n_samples, model, alphas, betas, beta_bars):
    xt = torch.randn(n_samples, 2)
    T = len(alphas)
    for t in range(T -1, -1, -1):
      print(f&quot;t:{t}&quot;)
      ut = torch.randn(n_samples, 2)
      if t == 0:
        ut[:, :] = 0.0
      with torch.no_grad():
        noise_pred = model(xt, t*torch.ones(n_samples, dtype=xt.dtype))
        sigma_t = torch.sqrt(betas[t])
        xt = 1 / torch.sqrt(alphas[t]) * (xt - betas[t]/torch.sqrt(beta_bars[t])*noise_pred) + sigma_t*ut
    return xt</code></pre></div>



<p>ここで、<code>n_samples</code>がサンプリングするサンプル数、<code>model</code>が\(\epsilon_{\theta}(\boldsymbol{x}_{t}, t)\)、<code>alphas</code>、<code>betas</code>, <code>beta_bars</code>がそれぞれ\(\alpha_t, \beta_t, \bar{\beta}_t\) のリストです。</p>



<h2 class="wp-block-heading">実行例</h2>



<p>先ほど紹介したPythonコードを実際に動かした例も示しておきます。参考例として入力となる\(\boldsymbol{x}\)のサンプリングする分布の確率密度関数は以下のように平均が違うガウス分布二つの混合分布とし、サンプリングしたデータを正規化して使用します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="generate_dataset" data-lang="Python"><code>n_samples = int(1e6)
sigma = 0.01

dist0 = torch.distributions.MultivariateNormal(torch.tensor([-2, -2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples0 = dist0.sample(torch.Size([n_samples//2]))
    
dist1 = torch.distributions.MultivariateNormal(torch.tensor([2, 2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples1 = dist1.sample(torch.Size([n_samples//2]))
samples = torch.vstack((samples0, samples1))

mean = torch.mean(samples, dim=0)
std = torch.std(samples, dim=0)

normalized_samples = (samples - mean[None, :])/std[None, :]</code></pre></div>



<p>使用する\(\boldsymbol{x}\)を2Dのヒストグラムで可視化すると以下のようになります。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img loading="lazy" decoding="async" width="266" height="264" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density-1.png" alt="使用するデータの可視化結果" class="wp-image-2211" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density-1.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density-1-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">使用するデータの可視化結果</figcaption></figure></div>


<p>このデータを再現できるようにデノイジング拡散確率モデルを学習します。コードとしては先ほど示した通りです。</p>



<p>学習が終わったら次は以下のようにサンプリングを行います。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="run_ddpm_sample" data-lang="Python"><code>samples_pred = ddpm_sample(n_samples=100000, model=model, alphas=alphas, betas=betas, beta_bars=beta_bars)</code></pre></div>



<p>サンプリングされたデータの2Dのヒストグラムは以下の通りです。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img loading="lazy" decoding="async" width="266" height="264" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density-3.png" alt="デノイジング拡散確率モデルによるサンプリングデータの可視化結果" class="wp-image-2214" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density-3.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density-3-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">デノイジング拡散確率モデルによるサンプリングデータの可視化結果</figcaption></figure></div>


<p>可視化結果をみると元の分布の平均の近くにデータが集中しているので、うまくいっていると考えられます。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回は「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noopener" title="拡散モデル　データ生成技術の数理">拡散モデル　データ生成技術の数理</a>」の2.3のデノイジング拡散確率モデルの簡単な説明とコードを書いたのでそのまとめの記事になります。先日スコアベースモデルのコードを用意したことで、そのコードを参考に今回のデノイジング拡散確率モデルをすぐに作ることができたのですが、説明はすごい大変でした。</p>



<p>スコアベースモデルのほうも気になるという方はこちらをご覧ください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="QbbiRh4C0Y"><a href="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/embed/#?secret=J4EXuySmX8#?secret=QbbiRh4C0Y" data-secret="QbbiRh4C0Y" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>今後、3章で紹介されている連続時間化拡散モデルのVE-SDEのほうも紹介予定です。コードは昨日できました。ただ、思ったよりも説明が大変そうなので、記事を書くのに時間がかかると思います。</p>



<p>この記事が少しでもみなさんの理解の助けになれば幸いです。</p>



<h2 class="wp-block-heading">参考文献</h2>



<ol class="wp-block-list">
<li>Ho, J., Jain, A., &amp; Abbeel, P. (2020). Denoising Diffusion Probabilistic Models. ArXiv, abs/2006.11239.</li>



<li><a href="https://github.com/hojonathanho/diffusion" target="_blank" rel="noopener" title="">https://github.com/hojonathanho/diffusion</a></li>
</ol><p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/16/diffusion_model_book_2_3/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.3 デノイジング拡散確率モデル</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/03/16/diffusion_model_book_2_3/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">2116</post-id>	</item>
		<item>
		<title>[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</title>
		<link>https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Tue, 07 Mar 2023 21:33:32 +0000</pubDate>
				<category><![CDATA[未分類]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[SBM]]></category>
		<category><![CDATA[スコアベースモデル]]></category>
		<category><![CDATA[拡散モデル]]></category>
		<category><![CDATA[書籍]]></category>
		<category><![CDATA[機械学習]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1932</guid>

					<description><![CDATA[<p>先日紹介した「拡散モデル　データ生成技術の数理」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装したりしています。 その第2弾として「2.2 スコアベースモデル」で説明されているスコアベース [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>先日紹介した「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noopener" title="拡散モデル　データ生成技術の数理">拡散モデル　データ生成技術の数理</a>」をちゃんと理解するために数式を改めて追ったり、説明されているアルゴリズムを実装したりしています。</p>



<p>その第2弾として「2.2 スコアベースモデル」で説明されているスコアベースモデルの学習とそれを使ったサンプリングについてPython(深層学習部分はPytorch)でコードを書いて試したのでそのまとめになります。</p>



<p>また、この本を買うか迷っている方は私が読んだ感想をこちらの記事に書いてますので参考にしてみてください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="feZfZwZCVZ"><a href="https://www.mattari-benkyo-note.com/2023/02/23/diffusion_model_book_review/">[書評] 拡散モデル データ生成技術の数理 ー 目覚ましい画像生成の発展の裏側を知りたい人へ</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[書評] 拡散モデル データ生成技術の数理 ー 目覚ましい画像生成の発展の裏側を知りたい人へ&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/02/23/diffusion_model_book_review/embed/#?secret=rS3Hmaks6Y#?secret=feZfZwZCVZ" data-secret="feZfZwZCVZ" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div><figcaption class="wp-element-caption">[書評] 拡散モデル データ生成技術の数理 ー 目覚ましい画像生成の発展の裏側を知りたい人へ</figcaption></figure>



<p>また、この記事で紹介したコードは以下にあげてありますので、コード全体を確認したい方はこちらをご覧ください。</p>



<p><a href="https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_2_2_score_based_model.ipynb" target="_blank" rel="noopener" title="">https://github.com/shu65/diffusion-model-book/blob/main/diffusion_model_book_2_2_score_based_model.ipynb</a></p>



<h2 class="wp-block-heading">スコアベースモデルとは</h2>



<p>１章で紹介されているデノイジングスコアマッチングは以下の２つの問題点があると本では紹介されています。</p>



<ol class="wp-block-list">
<li>デノイジングスコアマッチングで推定されたスコア関数はデータ分布の密度が小さい領域で不正確</li>



<li>データ分布が多峰性を持つ場合、あるモード（確率が大きい領域）から他のモードに移る際、確率が小さい領域を通過するために非常に多くのステップを必要とする</li>
</ol>



<p>これらの問題を解決するためにスコアベースモデル（SBM）[1, 2] では複数の異なる強度のノイズによって攪乱した攪乱後分布を用意して、それらの攪乱後分布上のスコアを求めるようにしています。</p>



<h2 class="wp-block-heading">スコアベースモデルの学習</h2>



<p>スコア関数 \(s_{\theta}(\boldsymbol{x}, \sigma_t)\) を学習する際は以下のロス関数を使います。</p>



<p>$$ \begin{align*}<br>L_{\text{SBM}}(\theta) := \sum_{t=1}^T w_t E_{p_{\sigma_t}}(\tilde{\boldsymbol{x}}) \left\{ \left\| \nabla_{\tilde{\boldsymbol{x}}} \log p_{\sigma_t}(\tilde{\boldsymbol{x}}) &#8211; s_{\theta}(\tilde{\boldsymbol{x}}, \sigma_t) \right\|^2 \right\} \tag{2.2.1}<br>\end{align*} $$</p>



<p>ここで\(\sigma_t \) はノイズの強さを表す変数で\( \sigma_{min} = \sigma_1 &lt; \sigma_2 &lt;&#8230; &lt; \sigma_T = \sigma_{max}\)の合計\(T\)個をスコアベースモデルでは利用します。そして、\(p_{\sigma_t}(\tilde{\boldsymbol{x}}) \) \(x\)は\(x\)の分布\(p(x)\)を\(\sigma_t\)の強さで攪乱したあとの分布を表しています。</p>



<p>この式(2.2.1)を本の1.5.5の「デノイジングスコアマッチング」で説明されている通り、デノイジングスコアマッチングを使って式を書き換えると以下のようになります。</p>



<p>$$ \begin{align*}<br>L_{\text{DSM-SBM}}(\theta) := \sum_{t=1}^T w_t E_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x}),\tilde{\boldsymbol{x}} \sim \mathcal{N}(\boldsymbol{x}, \sigma_t^2\boldsymbol{I})} \left\{ \left\| \frac{\boldsymbol{x} &#8211; \tilde{\boldsymbol{x}}}{\sigma_t^2} &#8211; s_{\theta}(\tilde{\boldsymbol{x}}, \sigma_t) \right\|^2 \right\} \tag{2.2.2}<br>\end{align*} $$</p>



<p>詳細は本にわかりやすくかいてあるので本を参照してください。</p>



<p>ここで本の式(1.9)のデノイジングスコアマッチングの式において最初に\(1/2\)があるのに式(2.2.2)ではそれが省略されています。これに関して本にはちゃんと書いてない気がしますが、おそらくこれは\(w_t\)の中に\(1/2\)が含まれているから、もしくは\(1/2\)は定数であり、最適化の際にパラメータが移動する方向は\(1/2\)のありなしで変わらないということで省略しているのではないかと思っています。</p>



<p>ここから２章にはちゃんと書いてないですが、Pythonで実装するためにさらに式変形していきます。\(\tilde{\boldsymbol{x}} \sim \mathcal{N}(\boldsymbol{x}, \sigma_t^2\boldsymbol{I})\)なので、\(\tilde{\boldsymbol{x}}\)を\(\epsilon \sim \mathcal{N}(0, \sigma_t^2 \boldsymbol{I})\)を使って表すと以下のようになります。</p>



<p>$$ \begin{align*}<br>\tilde{\boldsymbol{x}} = \boldsymbol{x} + \epsilon \tag{2.2.3}<br>\end{align*} $$</p>



<p>この式(2.2.3)を使って式(2.2.2)を式変形すると以下の通りです。</p>



<p>$$ \begin{align*}<br>L_{\text{DSM-SBM}}(\theta) :=&amp; \sum_{t=1}^T w_t E_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x}),\tilde{\boldsymbol{x}} \sim \mathcal{N}(\boldsymbol{x}, \sigma_t^2\boldsymbol{I})} \left\{ \left\| \frac{\boldsymbol{x} &#8211; \tilde{\boldsymbol{x}}}{\sigma_t^2} &#8211; s_{\theta}(\tilde{\boldsymbol{x}}, \sigma_t) \right\|^2 \right\} \\<br>=&amp; \sum_{t=1}^T w_t E_{\boldsymbol{x} \sim p_{data}(\boldsymbol{x}),\epsilon \sim \mathcal{N}(\boldsymbol{x}, \sigma_t \boldsymbol{I})} \left\{ \left\| \frac{-\epsilon}{\sigma_t^2} &#8211; s_{\theta}(\tilde{\boldsymbol{x}}, \sigma_t) \right\|^2 \right\} \tag{2.2.4}<br>\end{align*} $$</p>



<p>この式を見たときに\(t=1\)から\(t=T\)までの和をとっている部分、\(T\)のサイズによっては計算量がすごいことにならないか？ということを思いました。このため、何か実装するときに工夫があるのかも？ということで[2]著者実装である[3]を見にいきました。すると2023/03/03時点では\(t=1\)から\(t=T\)ランダムに\(t\)を選び、その平均をとるということをしていました。</p>



<p>Pythonのコードのほうが分かりやすいと思うので、以下にPythonのコードも示しておきます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="dsm_loss" data-lang="Python"><code>def dsm_loss(score_model, samples, sigmas):
  t = torch.randint(0, len(sigmas), (samples.shape[0],), device=sigmas.device)
  used_sigmas = sigmas[t].view(samples.shape[0], *([1] * len(samples.shape[1:])))
  noise = torch.randn_like(samples) * used_sigmas
  perturbed_samples = samples + noise
  target = - 1 / (used_sigmas ** 2) * noise
  scores = score_model(perturbed_samples, used_sigmas)
  target = target.view(target.shape[0], -1)
  scores = scores.view(scores.shape[0], -1)
  w = used_sigmas.squeeze(-1) ** 2
  loss = ((scores - target) ** 2).sum(dim=-1) * w
  return loss.mean()</code></pre></div>



<p>ここで<code>score_model</code>がスコア関数 \(s_{\theta}(\boldsymbol{x}, \sigma_t )\) 、<code>samples</code>が\(\boldsymbol{x}\)、<code>sigmas</code>が\(\{\sigma_1,&#8230;,\sigma_T\}\)の配列となっています。また、\(w_t\)は本にならって\(w_t=\sigma_t^2\)を使っています。</p>



<p>この関数では最初にランダムに\(t\)を選び、それに従ってノイズを生成し、\(\tilde{\boldsymbol{x}}\)を作ります。その後、スコア関数の<code>score_model</code>を使ってスコアを計算し、式(2.2.4)を使ってロス関数を計算します。</p>



<p>このロス関数を使ってスコア関数のパラメータを学習していきます。</p>



<p>ここで１つ、スコア関数のモデルに関して注意点があります。スコア関数は\(s_{\theta}(\boldsymbol{x}, \sigma_t) \)は\(\boldsymbol{x}\)だけでなく\(\sigma_t\)も引数にとります。このため、モデルの中でどうにかして\(\sigma_t \)と\(\boldsymbol{x}\)の入力を組み合わせる必要があります。これに関して今回のコードでは[3]の実装にならって、以下のようにして\(\boldsymbol{x}\)だけを入力として受け取るスコア関数\(s_{\theta}^{\prime}(\boldsymbol{x})\)の出力を\(\sigma_t\)で割るという形にしています。</p>



<p>$$ \begin{align*}<br>s_{\theta}(\boldsymbol{x}, \sigma_t) = s_{\theta}^{\prime}(\boldsymbol{x}) / \sigma_t \tag{2.2.5}<br>\end{align*} $$</p>



<p>また、後ほど示しますが、今回は２つのガウス分布の混合分布を入力とします。この分布はシンプルな分布なため、今回は簡単なMLPをスコア関数のモデル使用します。コードとては以下のようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="score_model" data-lang="Python"><code>import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreModel(nn.Module):
  def __init__(self, n_channels=2):
    super(ScoreModel, self).__init__()

    self.model = nn.Sequential(
        nn.Linear(n_channels, 2*n_channels),
        nn.ELU(),
        nn.Linear(2*n_channels, 16*n_channels),
        nn.ELU(),
        nn.Linear(16*n_channels, 2*n_channels),
        nn.ELU(),
        nn.Linear(2*n_channels, n_channels),
    )

  def forward(self, x, sigma):
    y = self.model(x)
    return y/sigma</code></pre></div>



<h2 class="wp-block-heading">スコアベースモデルを使ったサンプリング</h2>



<p>ここから学習済みのスコア関数\(s_{\theta}(\boldsymbol{x}, \sigma_t) \)を使ったサンプリングについて説明していきます。</p>



<p>スコアベースモデルを使ったサンプリング１章で紹介されたランジュバン・モンテカルロ法をベースにしています。ランジュバン・モンテカルロ法の部分についてはこちらに解説しています。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="fM7uHxbfRU"><a href="https://www.mattari-benkyo-note.com/2023/03/03/diffusion_model_book_1_5_1/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 1.5.1 ランジュバン・モンテカルロ法</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[勉強ノート] 「拡散モデル　データ生成技術の数理」 1.5.1 ランジュバン・モンテカルロ法&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2023/03/03/diffusion_model_book_1_5_1/embed/#?secret=oaLMCcsv0l#?secret=fM7uHxbfRU" data-secret="fM7uHxbfRU" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>詳細は上の記事にかいてありますが、ランジュバン・モンテカルロ法は最初、ランダムに\(\boldsymbol{x}_0\)を生成後、以下のランジュバン・モンテカルロ法の更新則を\(K\)回繰り返すことで\(p(\boldsymbol{x})\)からサンプリングしたようなデータを作ります。</p>



<p>$$ \begin{align*}<br>\boldsymbol{x}_k := \boldsymbol{x}_{k-1} + \alpha \nabla_\boldsymbol{x} \log p(\boldsymbol{x}_{k-1}) + \sqrt{2\alpha}\boldsymbol{u}_k \tag{2.2.6}<br>\end{align*} $$</p>



<p>スコアベースモデルのサンプリングでは更新則のスコア（\(\nabla_\boldsymbol{x} \log p(\boldsymbol{x}_{k-1})\)）を学習したスコア関数に置き換えた以下の更新則を利用します。</p>



<p>$$ \begin{align*}<br>\boldsymbol{x}_{t, k} := \boldsymbol{x}_{t, k-1} + \alpha_t s_{\theta}(\boldsymbol{x}_{t, k-1}, \sigma_t)+ \sqrt{2\alpha_t}\boldsymbol{u}_k \tag{2.2.7}<br>\end{align*} $$</p>



<p>この更新則を用いたスコアベースモデルのサンプリングの疑似コードは以下の通りです。(「拡散モデル　データ生成技術の数理」Algorithm 2.1の引用)</p>



<ol class="wp-block-list">
<li>\(\boldsymbol{x}_0\)を初期化(\(\boldsymbol{x}_0 \sim \mathcal{N}(0, \sigma_T^2 \boldsymbol{I})\))</li>



<li>for \(t=1,&#8230;,T\) do</li>



<li>\(\quad \alpha_t := \alpha \sigma_t^2\/\sigma_T^2)\</li>



<li>\(\quad\) for \(k=1,&#8230;,K\) do</li>



<li>\(\qquad \boldsymbol{u}_k \sim \mathcal{N}(0, \boldsymbol{I})\)</li>



<li>\(\qquad\) if \(t=1\) and \(k=K\) then \(\boldsymbol{u}_k := 0\)</li>



<li>\(\qquad \boldsymbol{x}_{t, k} := \boldsymbol{x}_{t, k-1} + \alpha_t s_{\theta}(\boldsymbol{x}_{t, k-1}, \sigma_t)+ \sqrt{2\alpha_t}\boldsymbol{u}_k \)</li>



<li>\(\quad\) end for</li>



<li>\(\quad \boldsymbol{x}_{t-1, 0} := \boldsymbol{x}_{t, K}\) </li>



<li>end for</li>



<li>return \(\boldsymbol{x}_{0, 0}\)</li>
</ol>



<p>ここで\(\alpha\)はステップ幅のスケール、\(K\)はステップ回数です。アルゴリズムを見て分かる通り、ノイズの強度を変えながらランジュバン・モンテカルロ法を使って少しずつ\(\boldsymbol{x}_{t, k}\)を変化させています。また、７行目にある通り、各ノイズの強度の最後のステップではデノイジングのみを行うことでサンプリングの品質を向上させています。</p>



<p>この疑似コードをPythonのコードにするとこのようになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="sbm_sampling" data-lang="Python"><code>def sbm_sample(n_samples, score_model, sigmas, alpha=0.1):
    sigma_T = sigmas[-1]
    x_0 = torch.randn(n_samples, 2)*sigma_T
    x_tk = x_0
    K = 200
    for t in range(len(sigmas) -1, -1, -1):
      sigma_t = sigmas[t]
      alpha_t = alpha*(sigma_t**2)/(sigma_T**2)
      print(f&quot;t:{t}, sigma_t:{sigma_t}, alpha_t:{alpha_t}&quot;)
      for k in range(K+1):
        u_k = torch.randn(n_samples, 2)
        if (k == K) and t == 0:
          u_k[:, :] = 0.0
        with torch.no_grad():
          score = score_model(x_tk, sigma_t)
          x_tk = x_tk + alpha_t * score + np.sqrt(2 * alpha_t) * u_k
    return x_tk</code></pre></div>



<p><code>n_samples</code>が生成するサンプル数、<code>score_model</code>がスコア関数、<code>sigmas</code>がノイズ強度の配列、<code>alpha</code>がステップ幅のスケールになっています。</p>



<h2 class="wp-block-heading">実行例</h2>



<p>先ほど紹介したPythonコードを実際に動かした例も示しておきます。参考例として入力となる\(\boldsymbol{x}\)のサンプリングする分布の確率密度関数は以下のように平均が違うガウス分布二つの混合分布とし、サンプリングしたデータを正規化して使用します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="generate_dataset" data-lang="Python"><code>n_samples = int(1e6)
sigma = 0.01

dist0 = torch.distributions.MultivariateNormal(torch.tensor([-2, -2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples0 = dist0.sample(torch.Size([n_samples//2]))
    
dist1 = torch.distributions.MultivariateNormal(torch.tensor([2, 2], dtype=torch.float).to(device), sigma*torch.eye(2, dtype=torch.float).to(device))
samples1 = dist1.sample(torch.Size([n_samples//2]))
samples = torch.vstack((samples0, samples1))

mean = torch.mean(samples, dim=0)
std = torch.std(samples, dim=0)

normalized_samples = (samples - mean[None, :])/std[None, :]</code></pre></div>



<p>使用する\(\boldsymbol{x}\)を2Dのヒストグラムで可視化すると以下のようになります。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density.png" alt="使用するデータの可視化結果" class="wp-image-2091" width="266" height="264" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/sample_density-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">使用するデータの可視化結果</figcaption></figure></div>


<p>このデータを再現できるようにスコア関数を学習します。学習コードは以下の通りです。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="train" data-lang="Python"><code>import torch

batch_size = 512
n_steps = 100000

dataset = torch.utils.data.TensorDataset((normalized_samples))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True, num_workers=0)
dataloader_iter = iter(dataloader)

score_model = ScoreModel().to(device)

optimizer = torch.optim.Adam(score_model.parameters())
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=n_steps)

for i in range(n_steps):
  try:
    x = next(dataloader_iter)[0]
  except StopIteration:
    dataloader_iter = iter(dataloader)
    x = next(dataloader_iter)[0]
  x = x.to(device)

  optimizer.zero_grad()
  loss = dsm_loss(score_model, x, sigmas)
  loss.backward()
  optimizer.step()
  lr_scheduler.step()
  if (i % 1000) == 0:
    print(f&quot;{i} steps loss:{loss}&quot;)</code></pre></div>



<p>学習が終わったら、以下のようにして学習したモデルを利用してサンプリングします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="sbm_samping" data-lang="Python"><code>samples_pred = sbm_sample(n_samples=100000, score_model=score_model, sigmas=sigmas)</code></pre></div>



<p>サンプリングされたデータの2Dのヒストグラムは以下の通りです。</p>


<div class="wp-block-image">
<figure class="aligncenter size-full"><img loading="lazy" decoding="async" width="266" height="264" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density.png" alt="スコアベースモデルによるサンプリングデータの可視化結果" class="wp-image-2092" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density.png 266w, https://www.mattari-benkyo-note.com/wp-content/uploads/2023/03/predicted_sample_density-150x150.png 150w" sizes="auto, (max-width: 266px) 100vw, 266px" /><figcaption class="wp-element-caption">スコアベースモデルによるサンプリングデータの可視化結果</figcaption></figure></div>


<p>可視化結果をみると元の分布の平均の近くにデータが集中しているので、うまくいっていると考えられます。</p>



<p>ただ、やってみるとわかるのですがちゃんとした結果を得るために人手で決めないといけないハイパーパラメータの選択が難しい印象です。この結果もかなり試行錯誤してなんとかこの結果を作ることができたというイメージです。</p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回は「<a href="https://amzn.to/3SC3LMc" target="_blank" rel="noopener" title="拡散モデル　データ生成技術の数理">拡散モデル　データ生成技術の数理</a>」の2.2のスコアベースモデルの説明の部分のコードを書いたのでそのまとめの記事になります。最初、MNISTのデータでやろうとして、MNISTのデータを学習できるコードを説明するのは結構大変、ということでシンプルな混合ガウス分布にしました。ただ、それでも結構な分量になった印象です。ちなみに次のDDPMも紹介用のコードはできているので、近日中に記事を書いて公開しようと思います。</p>



<p>この記事が少しでもみなさんの理解の助けになれば幸いです。</p>



<h2 class="wp-block-heading">参考文献</h2>



<ol class="wp-block-list">
<li>Song, Y., &amp; Ermon, S. (2019). Generative Modeling by Estimating Gradients of the Data Distribution. ArXiv, abs/1907.05600.</li>



<li>Song, Y., &amp; Ermon, S. (2020). Improved Techniques for Training Score-Based Generative Models. ArXiv, abs/2006.09011.</li>



<li><a href="https://github.com/ermongroup/ncsnv2" target="_blank" rel="noopener" title="">https://github.com/ermongroup/ncsnv2</a></li>
</ol><p>The post <a href="https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/">[勉強ノート] 「拡散モデル　データ生成技術の数理」 2.2 スコアベースモデル</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/03/08/diffusion_model_book_2_2/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1932</post-id>	</item>
		<item>
		<title>統計的因果推論、統計的因果探索の勉強で読んだ本まとめ(2023/1版)</title>
		<link>https://www.mattari-benkyo-note.com/2023/01/10/causal_book/</link>
					<comments>https://www.mattari-benkyo-note.com/2023/01/10/causal_book/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Mon, 09 Jan 2023 22:18:18 +0000</pubDate>
				<category><![CDATA[書評]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[R]]></category>
		<category><![CDATA[因果推論]]></category>
		<category><![CDATA[本]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1534</guid>

					<description><![CDATA[<p>あけましておめでとうございます。年末年始の休みを利用して今年もまとめて本を読んでいたのですが、統計的因果推論に関する本を読んだらすごく面白くて今年はそれ関係の本をまとめて読みました。 今回の記事では私が読んだ本の軽い解説 [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2023/01/10/causal_book/">統計的因果推論、統計的因果探索の勉強で読んだ本まとめ(2023/1版)</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>あけましておめでとうございます。年末年始の休みを利用して今年もまとめて本を読んでいたのですが、統計的因果推論に関する本を読んだらすごく面白くて今年はそれ関係の本をまとめて読みました。</p>



<p>今回の記事では私が読んだ本の軽い解説とどういう人向けかのまとめです。今後同じように統計的因果推論関係の勉強をしたいという人の参考になれば幸いです。</p>



<h1 class="wp-block-heading">因果推論の科学 「なぜ?」の問いにどう答えるか</h1>


		<div class="pochipp-box"
			data-id="2931"
			data-img="l"
			data-lyt-pc="dflt"
			data-lyt-mb="vrtcl"
			data-btn-style="dflt"
			data-btn-radius="on"
			data-sale-effect="none"
			 data-cvkey="28d3a006" data-auto-update="true"		>
							<div class="pochipp-box__image">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6" rel="nofollow noopener" target="_blank">
						<img loading="lazy" decoding="async" src="https://thumbnail.image.rakuten.co.jp/@0_mall/book/cabinet/5968/9784163915968_1_6.jpg?_ex=400x400" alt="" width="120" height="120" />					</a>
				</div>
						<div class="pochipp-box__body">
				<div class="pochipp-box__title">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6" rel="nofollow noopener" target="_blank">
						因果推論の科学 「なぜ？」の問いにどう答えるか					</a>
				</div>

									<div class="pochipp-box__info">ジューディア・パール , ダナ・マッケンジー, 夏目 大</div>
				
				
							</div>
				<div class="pochipp-box__btns"
		data-maxclmn-pc="fit"
		data-maxclmn-mb="1"
	>
					<div class="pochipp-box__btnwrap -amazon">
								<a href="https://www.amazon.co.jp/s?k=%E5%9B%A0%E6%9E%9C%E6%8E%A8%E8%AB%96%E3%81%AE%E7%A7%91%E5%AD%A6&#038;tag=shu65-22" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						Amazon					</span>
									</a>
			</div>
							<div class="pochipp-box__btnwrap -rakuten">
								<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%25A7%2591%25E5%25AD%25A6" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						楽天市場					</span>
									</a>
			</div>
											</div>
								<div class="pochipp-box__logo">
					<img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/plugins/pochipp/assets/img/pochipp-logo-t1.png" alt="" width="32" height="32">
					<span>ポチップ</span>
				</div>
					</div>
	


<p>まずは私が統計的因果推論おもしろい！と感じさせてくれた「<a href="https://amzn.to/4aEPNkt" target="_blank" rel="noopener" title="因果推論の科学 「なぜ?」の問いにどう答えるか">因果推論の科学 「なぜ?」の問いにどう答えるか</a>」です。本の内容としては、因果に関する研究が如何に難しいかやどれほど役立つのかはもちろん、因果に関する研究の歴史にも言及し、他の分野、とりわけ統計学とどういう関わりがあってどのように発展してきたか？について書かれています。個人的にはPearsonやFisherなど手法名で名前をしっている人たちがどういうことをしていた人なのかも少ししれて面白かったです。</p>



<p>こちらの本は一般人向けに因果推論がどういうものかを書いた本になっています。書いた人が研究者ということもあってか、一部数式も出てきますが、基本、誰でもわかるような表現で書かれていて、「因果推論ってなに？」ということを背景的なことを重点的に知りたいという方はとっつきやすい１冊ではないかと思います。一方、数式がでてきますが、詳しい説明はそれほどないので、式変形に関してや、統計的因果推論の実際の応用例をたくさん知りたいという方には向かない本です。また、統計的因果推論についてすでにご存じの方向けの説明になってしまいますが、この本は構造的因果モデルとdo演算子を利用した因果推論についての本で、潜在的結果変数の枠組みを使うRubin流の因果推論の話はほとんどでてこないので、後者のほうを知りたいという人にも向かないので注意してください。</p>



<h1 class="wp-block-heading">入門 統計的因果推論</h1>


		<div class="pochipp-box"
			data-id="2932"
			data-img="l"
			data-lyt-pc="dflt"
			data-lyt-mb="vrtcl"
			data-btn-style="dflt"
			data-btn-radius="on"
			data-sale-effect="none"
			 data-cvkey="29acec85" data-auto-update="true"		>
							<div class="pochipp-box__image">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596" rel="nofollow noopener" target="_blank">
						<img loading="lazy" decoding="async" src="https://thumbnail.image.rakuten.co.jp/@0_mall/book/cabinet/2411/9784254122411.jpg?_ex=400x400" alt="" width="120" height="120" />					</a>
				</div>
						<div class="pochipp-box__body">
				<div class="pochipp-box__title">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596" rel="nofollow noopener" target="_blank">
						入門 統計的因果推論					</a>
				</div>

									<div class="pochipp-box__info">Judea Pearl, Madelyn Glymour, Nicholas P. Jewell, 落海 浩</div>
				
				
							</div>
				<div class="pochipp-box__btns"
		data-maxclmn-pc="fit"
		data-maxclmn-mb="1"
	>
					<div class="pochipp-box__btnwrap -amazon">
								<a href="https://www.amazon.co.jp/s?k=%E5%85%A5%E9%96%80%20%E7%B5%B1%E8%A8%88%E7%9A%84%E5%9B%A0%E6%9E%9C%E6%8E%A8%E8%AB%96&#038;tag=shu65-22" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						Amazon					</span>
									</a>
			</div>
							<div class="pochipp-box__btnwrap -rakuten">
								<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E5%2585%25A5%25E9%2596%2580%2520%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						楽天市場					</span>
									</a>
			</div>
											</div>
								<div class="pochipp-box__logo">
					<img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/plugins/pochipp/assets/img/pochipp-logo-t1.png" alt="" width="32" height="32">
					<span>ポチップ</span>
				</div>
					</div>
	


<p>次は「<a href="https://amzn.to/3Etg8V0" target="_blank" rel="noopener" title="入門 統計的因果推論">入門 統計的因果推論</a>」です。こちらは「因果推論の科学 「なぜ?」の問いにどう答えるか」の方が著者の一人に入っている統計的因果推論に関する教科書です。こちらは「因果推論の科学 「なぜ?」の問いにどう答えるか」とは違い、一般人向けではなく、技術者、研究者向けの本になっています。内容としては統計的因果推論の話をするのに必要な確率と統計、回帰、グラフィカルモデルの基本的な話から始まり、統計的因果推論で重要な介入効果や反実仮想とその応用について解説しています。「因果推論の科学 「なぜ?」の問いにどう答えるか」でも登場した因果グラフの具体例はもちろん、他にもいろいろ出てくるので、因果グラフのイメージはつきやすい印象です。一方、数式の式変形に関する記述が少なくて、この式どうしてでてきたの？という疑問が結構読んでてありました。また、因果グラフの構築の仕方に関する具体的な説明があまりなく、具体的な数値まで使った例というのはすくないので、実務に向けて勉強したい方はこの本だけだと、自分の持っているデータにどう適用すればいいのか？というのがわからないと思っています。</p>



<p>こちらの本、「因果推論の科学 「なぜ?」の問いにどう答えるか」の内容をちゃんと勉強したいという向け方はおすすめの本です。一方、構造的因果モデルとdo演算子を利用する法の因果推論の紹介が主で潜在的結果変数の枠組みを使うRubin流の因果推論の話を期待していると思っているのと違った、ということが起きると思っていますので注意してください。また、先ほども述べたように、この本だけでは統計的因果推論を使いこなすのは難しいと思っていて、実務で統計的因果推論を使いたいという方は後ほど紹介する具体的なコード付きの本を参考を読むとよいかと思っています。</p>



<h1 class="wp-block-heading">統計的因果探索 (機械学習プロフェッショナルシリーズ)</h1>


		<div class="pochipp-box"
			data-id="2935"
			data-img="l"
			data-lyt-pc="dflt"
			data-lyt-mb="vrtcl"
			data-btn-style="dflt"
			data-btn-radius="on"
			data-sale-effect="none"
			 data-cvkey="7e4092a8" data-auto-update="true"		>
							<div class="pochipp-box__image">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529" rel="nofollow noopener" target="_blank">
						<img loading="lazy" decoding="async" src="https://thumbnail.image.rakuten.co.jp/@0_mall/book/cabinet/9250/9784061529250.jpg?_ex=400x400" alt="" width="120" height="120" />					</a>
				</div>
						<div class="pochipp-box__body">
				<div class="pochipp-box__title">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529" rel="nofollow noopener" target="_blank">
						統計的因果探索 （機械学習プロフェッショナルシリーズ）					</a>
				</div>

									<div class="pochipp-box__info">清水 昌平</div>
				
				
							</div>
				<div class="pochipp-box__btns"
		data-maxclmn-pc="fit"
		data-maxclmn-mb="1"
	>
					<div class="pochipp-box__btnwrap -amazon">
								<a href="https://www.amazon.co.jp/s?k=%E7%B5%B1%E8%A8%88%E7%9A%84%E5%9B%A0%E6%9E%9C%E6%8E%A2%E7%B4%A2%20%28%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%83%97%E3%83%AD%E3%83%95%E3%82%A7%E3%83%83%E3%82%B7%E3%83%A7%E3%83%8A%E3%83%AB%E3%82%B7%E3%83%AA%E3%83%BC%E3%82%BA%29&#038;tag=shu65-22" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						Amazon					</span>
									</a>
			</div>
							<div class="pochipp-box__btnwrap -rakuten">
								<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A2%25E7%25B4%25A2%2520%2528%25E6%25A9%259F%25E6%25A2%25B0%25E5%25AD%25A6%25E7%25BF%2592%25E3%2583%2597%25E3%2583%25AD%25E3%2583%2595%25E3%2582%25A7%25E3%2583%2583%25E3%2582%25B7%25E3%2583%25A7%25E3%2583%258A%25E3%2583%25AB%25E3%2582%25B7%25E3%2583%25AA%25E3%2583%25BC%25E3%2582%25BA%2529" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						楽天市場					</span>
									</a>
			</div>
											</div>
								<div class="pochipp-box__logo">
					<img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/plugins/pochipp/assets/img/pochipp-logo-t1.png" alt="" width="32" height="32">
					<span>ポチップ</span>
				</div>
					</div>
	


<p>次は「<a href="https://amzn.to/3PKVvJv" target="_blank" rel="noopener" title="統計的因果探索">統計的因果探索</a>」です。前の２冊は統計的因果&#8221;推論&#8221;に関する本でしたがこちらは統計的因果&#8221;探索&#8221;に関する本です。前の本では構造的因果グラフは与えられているものと仮定して説明されている部分が多く、どうやって作るか？についての言及がほとんどありません。この本は構造的因果グラフをどうやって作るか？ということに関して、LiNGAMという手法の説明をした本になります。</p>



<p>内容としては簡単な統計的因果推論の話はあるものの、基本は統計的因果探索がメインです。また、統計的因果探索に関してもLiNGAMに関する記述が多く、それ以外に関しては簡単な解説が少しある程度です。ただ、LiNGAMに関しては説明が幅広く、未観測共通原因がある場合のLiNGAMや、LiNGAMで仮定を緩めることに関する記述もあります。なので、LiNGAMを深堀したい人にはちょうどよい１冊かと思います。</p>



<p>一方、他の機械学習プロフェッショナルシリーズの本の例に漏れることなく、コードを使った説明がなかったり、細かい数式変形までは書いてないので、詳細に理解するのは結構大変な印象です。</p>



<p>こちらの本は他の本でLiNGAMに触れて面白そうだからもっと勉強したい！という人向けかなと思っています。</p>



<h1 class="wp-block-heading">つくりながら学ぶ! Pythonによる因果分析 ~因果推論・因果探索の実践入門 (Compass Data Science) </h1>


		<div class="pochipp-box"
			data-id="2937"
			data-img="l"
			data-lyt-pc="dflt"
			data-lyt-mb="vrtcl"
			data-btn-style="dflt"
			data-btn-radius="on"
			data-sale-effect="none"
			 data-cvkey="b9265a27"		>
							<div class="pochipp-box__image">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590" rel="nofollow noopener" target="_blank">
						<img loading="lazy" decoding="async" src="https://thumbnail.image.rakuten.co.jp/@0_mall/book/cabinet/3575/9784839973575.jpg?_ex=400x400" alt="" width="120" height="120" />					</a>
				</div>
						<div class="pochipp-box__body">
				<div class="pochipp-box__title">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590" rel="nofollow noopener" target="_blank">
						つくりながら学ぶ! Pythonによる因果分析 因果推論・因果探索の実践入門					</a>
				</div>

									<div class="pochipp-box__info">小川雄太郎</div>
				
				
							</div>
				<div class="pochipp-box__btns"
		data-maxclmn-pc="fit"
		data-maxclmn-mb="1"
	>
					<div class="pochipp-box__btnwrap -amazon">
								<a href="https://www.amazon.co.jp/s?k=%E3%81%A4%E3%81%8F%E3%82%8A%E3%81%AA%E3%81%8C%E3%82%89%E5%AD%A6%E3%81%B6%21%20Python%E3%81%AB%E3%82%88%E3%82%8B%E5%9B%A0%E6%9E%9C%E5%88%86%E6%9E%90&#038;tag=shu65-22" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						Amazon					</span>
									</a>
			</div>
							<div class="pochipp-box__btnwrap -rakuten">
								<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E3%2581%25A4%25E3%2581%258F%25E3%2582%258A%25E3%2581%25AA%25E3%2581%258C%25E3%2582%2589%25E5%25AD%25A6%25E3%2581%25B6%2521%2520Python%25E3%2581%25AB%25E3%2582%2588%25E3%2582%258B%25E5%259B%25A0%25E6%259E%259C%25E5%2588%2586%25E6%259E%2590" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						楽天市場					</span>
									</a>
			</div>
											</div>
								<div class="pochipp-box__logo">
					<img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/plugins/pochipp/assets/img/pochipp-logo-t1.png" alt="" width="32" height="32">
					<span>ポチップ</span>
				</div>
					</div>
	


<p>次は「<a href="https://amzn.to/3KLfhmz" target="_blank" rel="noopener" title="つくりながら学ぶ! Pythonによる因果分析 ~因果推論・因果探索の実践入門">つくりながら学ぶ! Pythonによる因果分析 ~因果推論・因果探索の実践入門</a>」です。</p>



<p>これまで紹介した本の内容をPythonで実際に実装するか？を勉強したい人には現状ベストな本です。内容としては構造的因果グラフとdo演算子を使った統計的因果推論の基本的な説明に加えて、LiNGAMとベイジアンネットワーク、最新の深層学習を使った因果探索の話に関する説明＋Pythonを使ったコードになっています。特にPythonコードはJupyter Notebookで書かれていて、Google Colabを使えばすぐに実行できるように工夫されています。このため、これまで紹介した本のどれよりも具体的でかつ分かりやすい印象です。コードも一通り読みつつ実行もしてみましたが、きれいなコードで分かりやすかった印象です。</p>



<p>このため、今すぐ統計的因果推論や統計的因果探索を実務で使いたい人に対して一番最初に読むといい本はどれですか？と聞かれたらこれが良いと答えると思います。</p>



<h1 class="wp-block-heading">統計的因果推論の理論と実装 (Wonderful R)</h1>


		<div class="pochipp-box"
			data-id="2939"
			data-img="l"
			data-lyt-pc="dflt"
			data-lyt-mb="vrtcl"
			data-btn-style="dflt"
			data-btn-radius="on"
			data-sale-effect="none"
			 data-cvkey="8e12cc90"		>
							<div class="pochipp-box__image">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585" rel="nofollow noopener" target="_blank">
						<img loading="lazy" decoding="async" src="https://thumbnail.image.rakuten.co.jp/@0_mall/book/cabinet/2452/9784320112452_1_4.jpg?_ex=400x400" alt="" width="120" height="120" />					</a>
				</div>
						<div class="pochipp-box__body">
				<div class="pochipp-box__title">
					<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585" rel="nofollow noopener" target="_blank">
						統計的因果推論の理論と実装 潜在的結果変数と欠測データ （Wonderful R　5）					</a>
				</div>

									<div class="pochipp-box__info">高橋 将宜</div>
				
				
							</div>
				<div class="pochipp-box__btns"
		data-maxclmn-pc="fit"
		data-maxclmn-mb="1"
	>
					<div class="pochipp-box__btnwrap -amazon">
								<a href="https://www.amazon.co.jp/s?k=%E7%B5%B1%E8%A8%88%E7%9A%84%E5%9B%A0%E6%9E%9C%E6%8E%A8%E8%AB%96%E3%81%AE%E7%90%86%E8%AB%96%E3%81%A8%E5%AE%9F%E8%A3%85&#038;tag=shu65-22" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						Amazon					</span>
									</a>
			</div>
							<div class="pochipp-box__btnwrap -rakuten">
								<a href="https://hb.afl.rakuten.co.jp/hgc/39c5c75e.b3799909.39c5c75f.e5fd5e7d/?pc=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585&#038;m=https%3A%2F%2Fsearch.rakuten.co.jp%2Fsearch%2Fmall%2F%25E7%25B5%25B1%25E8%25A8%2588%25E7%259A%2584%25E5%259B%25A0%25E6%259E%259C%25E6%258E%25A8%25E8%25AB%2596%25E3%2581%25AE%25E7%2590%2586%25E8%25AB%2596%25E3%2581%25A8%25E5%25AE%259F%25E8%25A3%2585" class="pochipp-box__btn" rel="nofollow noopener" target="_blank">
					<span>
						楽天市場					</span>
									</a>
			</div>
											</div>
								<div class="pochipp-box__logo">
					<img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/plugins/pochipp/assets/img/pochipp-logo-t1.png" alt="" width="32" height="32">
					<span>ポチップ</span>
				</div>
					</div>
	


<p>最後に「<a href="https://amzn.to/3khMGKu" target="_blank" rel="noopener" title="統計的因果推論の理論と実装 (Wonderful R)">統計的因果推論の理論と実装 (Wonderful R)</a>」です。</p>



<p>今までの本と同じく「統計的因果推論」と書かれていますが、こっちは潜在的結果変数の枠組みを使うRubin流の因果推論の本です。このため、因果グラフは一部出てきますが、do演算子についてはでてきません。ただ、かなり幅広い内容の本となっているので、Rubin流の因果推論について詳しく知りたいという人です。説明にはRによる具体的なコードが付いていて、言葉だけではわかりづらいところもコードで補えるような形になっています。</p>



<p>PythonはわからないけどRを使える人や、Rubin流の因果推論についても詳しく知りたいという人にはお勧めな本かと思います。</p>



<h1 class="wp-block-heading">終わりに</h1>



<p>年末年始のお休みで読んだ本のうち、統計的因果推論、統計的因果探索に関する本をまとめて紹介しました。最初に紹介した「因果推論の科学 「なぜ?」の問いにどう答えるか」が思いのほか面白くてその勢いのまま読んだという感じがしてますが、因果推論はいつかまとめてちゃんと勉強したいと思っていたのでちょうどよかったです。</p>



<p>今後、今日紹介した一部の本に関しては本に書かれていた説明だけだと意味がわからないところがあったので、そういうところに関してはまたメモとして記事にしておこうと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2023/01/10/causal_book/">統計的因果推論、統計的因果探索の勉強で読んだ本まとめ(2023/1版)</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2023/01/10/causal_book/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1534</post-id>	</item>
		<item>
		<title>Kaggleの「Open Problems &#8211; Multimodal Single-Cell Integration」の振り返り</title>
		<link>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sat, 19 Nov 2022 00:05:55 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[Bioinformatics]]></category>
		<category><![CDATA[kaggle]]></category>
		<category><![CDATA[python]]></category>
		<category><![CDATA[pytorch]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1447</guid>

					<description><![CDATA[<p>今回は2022/11/15 (日本時間の2022/11/16の朝)まで行われていた「Open Problems &#8211; Multimodal Single-Cell Integration」に参加した際、どうして [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/">Kaggleの「Open Problems – Multimodal Single-Cell Integration」の振り返り</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<p>今回は2022/11/15 (日本時間の2022/11/16の朝)まで行われていた「<a href="https://www.kaggle.com/competitions/open-problems-multimodal/overview" target="_blank" rel="noopener" title="Open Problems - Multimodal Single-Cell Integration">Open Problems &#8211; Multimodal Single-Cell Integration</a>」に参加した際、どうして参加したのかや、参加中にやってよかったこと、課題などを忘れないようにまとめたので、せっかくなので記事にしました。</p>



<p>ちなみに、自分の手法に関してはこちらをご覧ください。</p>



<p><a href="https://www.kaggle.com/competitions/open-problems-multimodal/discussion/366961">https://www.kaggle.com/competitions/open-problems-multimodal/discussion/366961</a></p>



<h1 class="wp-block-heading">Open Problems &#8211; Multimodal Single-Cell Integration とは？</h1>



<p>コンペのサイトのoverviewに書かれていることをざっくり日本語に訳すと「骨髄幹細胞が血液細胞になるにつれて単一の細胞のDNAとRNA、タンパク質がどのように変化するかを予測するコンペ」ということになります。もう少し専門家の人がわかりやすいように説明すると以下の二つのsingle-cellのデータを使ったモデルを構築するコンペです。</p>



<ol class="wp-block-list">
<li>Chromatin accessibility(peak countをTF-IDFで変換したデータ)からRNAの発現量 (library-sizeでノーマライズされたcountデータ)の予測</li>



<li>RNAの発現量 (library-sizeでノーマライズされたcountデータ)からsurface protein levels (dsb でノーマライズされたデータ)の予測</li>
</ol>



<h1 class="wp-block-heading">そもそもなぜ参加したのか？</h1>



<p>参加してがんばろうと思ったきっかけをせっかくなので書いておくと</p>



<ol class="wp-block-list">
<li>Bioチームの同僚がこういうコンペがあるよとslackで教えてくれた (参加したきっかけ)</li>



<li>育休に入る前にやっておくべきことは？ということを会社の先輩パパさんに相談したら、「solo gold medalは取っておけ」と言われた (がんばったきっかけ)</li>
</ol>



<p>ということがあります。特に2つ目は「確かに！」と思いました。なので、今回が初ソロ参加にして、solo gold medalを取る最後のチャンスということで頑張りました。アドバイスをくれた先輩パパさんには感謝しかありません。</p>



<h1 class="wp-block-heading">やってよかったこと</h1>



<p>さて、ここからやってよかったことについて忘れないように書いておきます。</p>



<h2 class="wp-block-heading">まずはとにかく簡単な方法でいいのでsubmitする</h2>



<p>社内のkaggle強い人に前に言われた気がするので、まず意識したことがこれです。やってみて思ったのですが以下のような効果があることを実感しました。</p>



<ol class="wp-block-list">
<li>submitして順位やスコアがでるようになるとモチベーションが上がる。</li>



<li>何かベースラインがあると手法開発がしやすい</li>
</ol>



<p>特に1の効果がすごかった気がします。ちなみに私の場合、KaggleのCodeで公開されてたシンプルな手法をそのままコピペしてsubmitして一番最初のスコアを出しました。シンプルな手法だったため、最初は200位にも入れなかったと記憶していますが、それでもモチベーションはそれまでと比べてすごく上がりました。</p>



<h2 class="wp-block-heading">毎日決めた数の改良を試す</h2>



<p>これはコンペに限らず重要なことだと思いますが、とにかくコンスタントに改良を続けていき、最低限local CV scoreを出すということを意識してました。</p>



<p>今回のコンペでは3，4個の改良を毎日試すことを目標にしてやってました。私の場合、gitの1 branchが1改良になっていて、9/1からがんばり始めて11/15までに最終的には351個branchができていました。なので、通算4から5個くらいの改良を毎日試していたことになります。</p>



<p>このとき改良としてうまくいきそうなものはもちろんですが、思いついたタイミングではあまり筋が良くなさそうだけど他にやることがないというときは、筋が良くないアイディアもダメ元で試すようにしてました。</p>



<p>結果として、やってよかったと感じた理由としては、仮説を立てて試すとうまくいかなかったときの問題への理解度がすごい上がるため、とにかくいろいろ試すことで、そこからいいアイディアをひらめくということが多くあったからです。</p>



<p>今回のコンペで何度か大きくスコアを上げたタイミングがありましたが、総じて何かの失敗から気が付いたアイディアをもとにしたことが多かった印象です。ちなみにどれくらい失敗し続けたのかの体感ですが、1週間スコアが上がらないということがよくあったので、うまくいった改良というのはこれだけ試しても数えるくらいしかなかった印象です。</p>



<h2 class="wp-block-heading">ensembleはすぐに試さない</h2>



<p>これも社内のkaggle強い人に前に言われたことなのですが、ensembleをすぐに試さないというのはやってよかったなと思います。<br>今回はensembleを試さなくてよかったなと感じた理由としては、ensembleはやれば簡単にスコアが上がるのでいいように感じますが、ensembleは始めるとすごいいろいろ試せることがある一方、おそらくそこまで高いスコア上昇をしていなかったような印象があります。</p>



<p>このためensembleを頑張ることに時間使うよりも1つのモデルのスコアが上がるように頑張るという作戦でいたのですが、結果としてそれが良かった印象です。</p>



<h2 class="wp-block-heading">最後まで諦めない</h2>



<p>最後はこれです。特に今回はsolo gold medalが欲しかったので、一人で出ていたのですが、途中全然順位が上がらず何度ももうやめようかなぁと思うタイミングがありました。結果としてそのときダメもとで試した改良やそこから思いついたアイディアがうまくいってスコアを伸ばすということが何度もあったので諦めなくてよかったです。</p>



<h1 class="wp-block-heading">課題</h1>



<p>次にコンペの参加中、もしくは終わったあと振り返って感じた課題的な部分も列挙しておきます。</p>



<h2 class="wp-block-heading">モチベーションの維持が大変（特に一人のとき）</h2>



<p>「最後まで諦めない」のところに似たようなことを書きましたが、とにかくモチベーション維持が大変でした。今まではチームででていたのと、すごい応援してくれる上司がいてくれたりとこの部分はそこまで問題にならなかった印象でした。</p>



<p>今後一人ででるならこの部分はまだまだどうにかしないと最後まで戦うのは難しい気がしています。</p>



<h2 class="wp-block-heading">public scoreがどうしても気になってしまう</h2>



<p>今回のコンペはデータセットの説明を読んだ段階でpublic scoreとprivate scoreのギャップが激しそうだなぁということを思ってたので、それほどpublic scoreを気にしないほうがいいかも？ということを最初思ってました。ただ、それでも最後はpublic scoreを気にしてモデルの改良をしてしまっていました。</p>



<p>コンペのサイトのデータセットの説明のところに詳しくかかれていますが、今回public scoreは4人中一人の最終日一つ前までのデータを、private scoreは4人全員の最後の日のデータを予測するというものでした。この説明を見ると予測する対象が全然違うことがわかります。実際、public scoreの上位陣が軒並みprivate scoreでは順位を落としていました。なので、後から振り返るとやっぱり最初に思った通りpublic scoreをそれほど気にせずモデルの改良をしていて正解だったと思います。ただ、そうはいっても、public scoreの順位が気になってしまって、結局最後はpublic scoreを気にして最終submittionを決めていました。ただ、ふたを開けたらlocal CVがベストなものがprivate scoreも一番よかったので、public scoreを気にしすぎたなぁと思っていました。</p>



<p>ただ、この部分はふたを開けてみないとわからないところなので、難しいポイントな気がしています。</p>



<h2 class="wp-block-heading">論文を読んで最新研究をコンペの期間中に試すことが心理的に難しい</h2>



<p>期間の前半ならまだましですが、どうしても後半になればなるほど、心理的に追い込まれていきます。なので、試すアイディアがないと思いつつも、自分の全然読んだことない論文の手法を試すということが難しかったです。この結果、特に後半はアイディアが前に試したもののちょっと変更したものばかりになり、結果としてそれほど精度が向上しないということが起きました。ただ、この部分一人でやっていると意識しても難しい部分な気がしています。</p>



<p>このため、日ごろから論文を読んでいろいろ手法を勉強し、できれば実装を動かして感触を確かめておくことが重要かなと思いました。実際、今回のコンペではReactomeのpathwayデータを使ったのですが、これは論文の追試のついでにいろいろ試したことがあったからこそできたことだったと思っています。</p>



<p>ちなみにその時の記事はこちらにあります。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="2XpReJuJFD"><a href="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/">ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;ReactomeからPathwayの階層構造とPathwayに関連するGeneのデータを取得する&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2022/05/06/reactome-gene-pathway-data/embed/#?secret=1JxmvC3IrQ#?secret=2XpReJuJFD" data-secret="2XpReJuJFD" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<h2 class="wp-block-heading">ensembleの準備をせずにいると最後の土壇場で困る</h2>



<p>ensembleをすぐに試さないことを意識していた結果、実は今回期間最後のほうでensembleをしようとしたときにいろいろ準備できてないことに気が付いてすごい慌てました。具体的には以下の二つが締め切り1週間前の段階でできていませんでした。</p>



<ol class="wp-block-list">
<li>予測結果をどのように集計して最終的なファイルを作るか？</li>



<li>ensembleした結果をどのように評価するか？</li>
</ol>



<p>結局、1はギリギリできたのですが、2の評価方法はとても間に合いそうになかったのでpublic scoreを見てうまくいっているかいっていないか？を判断していました。ただ、これは評価としては微妙なので、次からはそんなに頑張らないにしても中盤くらいにはensembleの準備はしておこうと思います。</p>



<h1 class="wp-block-heading">最後に</h1>



<p>折角なので個人的な振り返りをblog記事にしました。仕事ではない状態でのkaggle参加は初めてだったのですが、思ったよりも疲れたのと、無事子供がコンペ締め切りの前日に生まれて、これから子育てもあるのででしばらくはコンペにでない気がします。ただ、次出たときに今回のコンペで何を思ったのか忘れないようにしておければと思っていたので、記事にできてよかったです。</p>



<p>これが他の人の参考になれば幸いです。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/">Kaggleの「Open Problems – Multimodal Single-Cell Integration」の振り返り</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/11/19/open-problems-multimodal/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1447</post-id>	</item>
		<item>
		<title>PyTorch Geometricを使ってVariational Graph Auto-Encodersを作って学習してみる</title>
		<link>https://www.mattari-benkyo-note.com/2022/05/23/pytorch-geometric-vgae/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/05/23/pytorch-geometric-vgae/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sun, 22 May 2022 23:08:14 +0000</pubDate>
				<category><![CDATA[プログラミング]]></category>
		<category><![CDATA[GNN]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[pytorch-geometric]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=1436</guid>

					<description><![CDATA[<p>はじめに 最近読んだ論文にVariational Graph Auto-Encoders (VGAE) を使ったモデルがあったので、自分でもやってみようと思い、作ってみました。本日はそのまとめになります。 本日紹介する使 [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/23/pytorch-geometric-vgae/">PyTorch Geometricを使ってVariational Graph Auto-Encodersを作って学習してみる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<h2 class="wp-block-heading">はじめに</h2>



<p>最近読んだ論文に<a href="https://arxiv.org/abs/1611.07308" target="_blank" rel="noreferrer noopener">Variational Graph Auto-Encoders</a> (VGAE) を使ったモデルがあったので、自分でもやってみようと思い、作ってみました。本日はそのまとめになります。</p>



<p>本日紹介する使うコードは以下のものです。</p>



<p><a href="https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb">https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb</a></p>



<p>また、このコード自体、以下のPyTorch Geometricのexampleのコードとほぼ同じです。</p>



<p><a href="https://github.com/pyg-team/pytorch_geometric/blob/ee509ad65aefa679047356bb00bc498f35ce7e20/examples/autoencoder.py">https://github.com/pyg-team/pytorch_geometric/blob/ee509ad65aefa679047356bb00bc498f35ce7e20/examples/autoencoder.py</a></p>



<p>このblog記事ではVGAEで必要な機能がPyTorch Geometricでどう実装されているのかわからなかった部分がいくつかあるのでその部分を解説していく記事になります。</p>



<h2 class="wp-block-heading">PyTorch Geometricとは</h2>



<p>PyTorch GeometricはPyTorchを使って構築されたGraph Neural Network向けのライブラリになります。</p>



<p>GitHubのURLは以下の通りです。</p>



<p><a href="https://github.com/pyg-team/pytorch_geometric">https://github.com/pyg-team/pytorch_geometric</a></p>



<p>最新のPyTorchやCUDAにもちゃんと対応しており、Graph Neural Networkで必要な基本的な機能はそろっている印象です。</p>



<h2 class="wp-block-heading">Variational Graph Auto-Encoders (VGAE)とは</h2>



<p>VGAEは<a href="https://arxiv.org/abs/1312.6114" target="_blank" rel="noreferrer noopener" title="Variational Auto-Encoder">Variational Auto-Encoder</a> (VAE) というモデルをGraphデータ向けに拡張したモデルです。VAEの説明を始めるとそれだけですごく長くなりますので、今回はVGAEを実装するうえで必要なところだけ紹介します。</p>



<p>VAEは以下のようにEncoderとDecoderという二つのモデルを組み合わせたモデルになります。</p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VAE.png" alt="" class="wp-image-1441" width="227" height="338" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VAE.png 480w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VAE-201x300.png 201w" sizes="auto, (max-width: 227px) 100vw, 227px" /><figcaption>VAEの概要図</figcaption></figure></div>



<p>このうち、EncoderとDecoderは以下のようなモデルになります。</p>



<ul class="wp-block-list"><li>Encoder: 入力Xを受け取って潜在変数Zの分布のパラメータを出力する</li><li>Decoder: 潜在変数Zを受け取って入力Xを再構成する</li></ul>



<p>VAEで重要なのがEncoderの部分と潜在変数Zのサンプリングの部分です。この潜在変数Zの分布が標準正規分布という仮定のもと学習させながら、Encoderで潜在変数Zの分布のパラメータを出力し、その分布のパラメータを使って潜在変数ZをサンプリングしてDecoderに渡すということを行います。</p>



<p>このVAEをGraph データに拡張するためにVGAEはEncoderとDecoderを以下のようなモデルにしています。</p>



<ul class="wp-block-list"><li>Encoder: ノードの特徴ベクトルXと隣接行列Aを入力として受け取り、潜在変数Zの分布のパラメータを出力する</li><li>Decoder: 潜在変数Zを受け取り隣接行列Aを再構築する</li></ul>



<p>図にすると以下のようなイメージです。</p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VGAE.png" alt="" class="wp-image-1442" width="288" height="356" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VGAE.png 578w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/05/VGAE-243x300.png 243w" sizes="auto, (max-width: 288px) 100vw, 288px" /><figcaption>VGAEの概要図</figcaption></figure></div>



<p>VGAEとVAEとの違いはEncoderでグラフの情報であるノード情報と隣接行列を受け取れるようにしたことと、Decoderが出力するものが隣接行列になることです。</p>



<h2 class="wp-block-heading">VGAEをPyTorch Geometricを使って実装する</h2>



<p>VGAEの概略を説明したので次は実際に実装を紹介していきます。まずはEncoderである<code>VariationalGCNEncoder</code>から見ていきます。EncoderではPyTorch Geometricに実装されている <code>GCNConv</code> を使って実装します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="vgae_encoder" data-lang="Python"><code>from torch_geometric.nn import GCNConv

class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)</code></pre></div>



<p> <code>GCNConv</code> はノードのインプットのチャンネル数、アウトプットのチャネル数を引数にとってインスタンスを作ります。そしてforwardではノードのtensor <code>x</code> と隣接行列のかわりにどのノード同士がつながっているか？を示す<code>edge_index</code>を渡します。<code>GCNConv</code> の中身についてはドキュメントに詳しく書かれているのでそちらをご覧ください。</p>



<p><a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv">https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv</a></p>



<p>このEncoderではVGAEの概要でも説明した通り、潜在変数の分布のパラメータを返します。ここではガウス分布の平均を表すmuと標準偏差にlogを適用したlogstdを返しています。</p>



<p>モデルの実装としてはあとはPyTorch Geometricで実装されている<code>VGAE</code>というクラスに渡せば終わりになります。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="vgae" data-lang="Python"><code>from torch_geometric.nn import VGAE

model = VGAE(VariationalGCNEncoder(in_channels, out_channels))</code></pre></div>



<p>ただ、これだとさすがに初見だと何が何だかわからなかったので、少し説明します。</p>



<p>まず、Decoderについてです。Decoderは<code>VGAE</code> のデフォルトでは<code>InnerProductDecoder</code>というものが使われます。これはVGAEの元論文でも使われていたDecoderの実装で、エッジの両端のノードに対応する潜在変数の各要素の積を取って総和を取り、sigmoidを適用して0-1の値にして出力します。出力値が0-1の値になっているのでDecoderの出力値は計算に使った二つのノードの間にエッジがある確率とみることができます。</p>



<p>詳しくは以下のドキュメントをご覧ください。</p>



<p><a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.InnerProductDecoder">https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.InnerProductDecoder</a></p>



<p>また、ロス関数についてですが、<code>VGAE</code> の中にVGAEで必要な以下の二つが実装されています。</p>



<ul class="wp-block-list"><li><code>recon_loss</code>: 潜在変数zとノード同士のつながりを示すpos_edge_indexを入力にとり、Decoderを利用して各エッジのある確率を計算、その確率に対してbinary cross entropyを計算してlossとして返す関数<br><a href="https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/autoencoder.html#GAE.recon_loss" target="_blank" rel="noreferrer noopener">https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/autoencoder.html#GAE.recon_loss</a></li><li><code>kl_loss</code>: Encoderの出力したmuとlogstdを使って標準正規分布とのKLダイバージェンスを計算しlossとして返す関数<br><a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.VGAE.kl_loss" target="_blank" rel="noreferrer noopener">https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.VGAE.kl_loss</a></li></ul>



<p>これを以下のように学習ループで利用して学習をおこないます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="vgae_train" data-lang="Python"><code>for epoch in range(0, 400):
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    recon_loss = model.recon_loss(z, train_data.pos_edge_label_index)
    kl_loss = (1 / train_data.num_nodes) * model.kl_loss()
    loss = recon_loss + kl_loss
    loss.backward()
    optimizer.step()</code></pre></div>



<p>最後に上のコードではノード間にエッジがあるところの情報は<code>train_data.pos_edge_label_index</code>で渡しているのですが、ノード間にエッジがないという情報はどこで渡しているか？ということについて説明します。</p>



<p>コードを読むと実はrecon_lossの中で自動的にエッジがないという情報を生成してそれを込みでロスが計算されています。具体的には以下の部分です。</p>



<p><a href="https://github.com/pyg-team/pytorch_geometric/blob/d2b2e662488eae07d153de6d4b8c56c24bf413d9/torch_geometric/nn/models/autoencoder.py#L101">https://github.com/pyg-team/pytorch_geometric/blob/d2b2e662488eae07d153de6d4b8c56c24bf413d9/torch_geometric/nn/models/autoencoder.py#L101</a></p>



<p>ここで引数で<code>neg_edge_index</code>が<code>None</code>のときは自動でエッジが存在しないノードのペアをサンプリングするという処理になっています。</p>



<p>以下です。その他の部分で気になるところがある場合は全体のコードを以下のところに置いてありますのでご覧ください。</p>



<p><a href="https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb">https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb</a></p>



<h2 class="wp-block-heading">終わりに</h2>



<p>今回はPyTorch Geometricの練習として、VGAEを実装してみたのでまとめの記事を書きました。PyTorch Geometricを今回初めて使ったのですが、Graph Neural Networkに必要な基本的な機能はそろっていそうなので、今後もGraph Neural Networkを使う機会があれば使ってみようと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/05/23/pytorch-geometric-vgae/">PyTorch Geometricを使ってVariational Graph Auto-Encodersを作って学習してみる</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/05/23/pytorch-geometric-vgae/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">1436</post-id>	</item>
		<item>
		<title>[勉強ノート]「ベイズ推論による機械学習入門」5.7 ニューラルネットワークで紹介されたものをPyTorchで実装してみた</title>
		<link>https://www.mattari-benkyo-note.com/2022/02/28/bayes-book-5-7-pytorch/</link>
					<comments>https://www.mattari-benkyo-note.com/2022/02/28/bayes-book-5-7-pytorch/#respond</comments>
		
		<dc:creator><![CDATA[Shuji Suzuki (shu)]]></dc:creator>
		<pubDate>Sun, 27 Feb 2022 23:17:33 +0000</pubDate>
				<category><![CDATA[未分類]]></category>
		<category><![CDATA[pytorch]]></category>
		<category><![CDATA[ベイズ推論]]></category>
		<category><![CDATA[ベイズ推論による機械学習入門]]></category>
		<category><![CDATA[勉強ノート]]></category>
		<guid isPermaLink="false">https://www.mattari-benkyo-note.com/?p=849</guid>

					<description><![CDATA[<p>はじめに 最近ベイズ推論の勉強をしていて機械学習スタートアップシリーズの「ベイズ推論による機械学習入門」を読んでいます。今回はこの本の5.7 のニューラルネットワークの章で紹介されていたモデルをPyTorchで実装したの [&#8230;]</p>
<p>The post <a href="https://www.mattari-benkyo-note.com/2022/02/28/bayes-book-5-7-pytorch/">[勉強ノート]「ベイズ推論による機械学習入門」5.7 ニューラルネットワークで紹介されたものをPyTorchで実装してみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></description>
										<content:encoded><![CDATA[<h1 class="wp-block-heading">はじめに</h1>



<p>最近ベイズ推論の勉強をしていて機械学習スタートアップシリーズの「ベイズ推論による機械学習入門」を読んでいます。今回はこの本の5.7 のニューラルネットワークの章で紹介されていたモデルをPyTorchで実装したので、実装と苦労した点を紹介していきます。</p>



<p>参考にしたJuliaのサンプルコードはこちらです。</p>



<p><a href="https://github.com/sammy-suyama/BayesBook/blob/master/src/demo_BayesNeuralNet.jl">https://github.com/sammy-suyama/BayesBook/blob/master/src/demo_BayesNeuralNet.jl</a></p>



<p>今回はこのサンプルコードをもとにして、PyTorchで実装したものを作り、以下に公開しました。</p>



<p><a href="https://github.com/shu65/blog-bayes-book/blob/main/%E3%83%99%E3%82%A4%E3%82%BA%E6%8E%A8%E5%AE%9A%E3%81%AB%E3%82%88%E3%82%8B%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E5%85%A5%E9%96%80%E3%80%805_7_%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A9%E3%83%AB%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF.ipynb" target="_blank" rel="noreferrer noopener" title="ベイズ推定による機械学習入門　5_7_ニューラルネットワーク.ipynb">https://github.com/shu65/blog-bayes-book/blob/main/%E3%83%99%E3%82%A4%E3%82%BA%E6%8E%A8%E5%AE%9A%E3%81%AB%E3%82%88%E3%82%8B%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E5%85%A5%E9%96%80%E3%80%805_7_%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A9%E3%83%AB%E3%83%8D%E3%83%83%E3%83%88%E3%83%AF%E3%83%BC%E3%82%AF.ipynb</a></p>



<p>今回はこのPyTorch実装についての紹介になります。パラメータは少しいじっていますが大体Juliaのサンプルと合わせています。</p>



<p>PyTorchで実装することで、より複雑なモデルを作ることが簡単になるのと、単純作業だけどやるのは大変な偏微分の計算のところをPyTorchの自動微分に任せることができるため、より本質的なことろが理解しやすくなると思ったためです。</p>



<p>それではまず本に書かれているモデルについて簡単に説明して、その後実際のPyTorch実装の説明という流れで説明していきたいと思います。</p>



<h1 class="wp-block-heading">本で紹介されているモデルと変分推論</h1>



<h2 class="wp-block-heading">モデルの定義</h2>



<p>まずは本で紹介されているモデルについて説明していきます。訓練データの入力値と出力値の集合をそれぞれ \( \boldsymbol{X} \)、 \( \boldsymbol{Y} \) と置きます。この集合の要素数を \( N \) とし \( n \in N \) 番目のデータの入力値を \( \boldsymbol{x}_n \in \mathbb{R}^M \) 、出力値を \( \boldsymbol{y}_n \in \mathbb{R}^D \) とするとき、ガウス分布によってモデル化すると以下のようになります。</p>



<p>$$ \begin{align*} <br>p(\boldsymbol{Y}|\boldsymbol{X}, \boldsymbol{W}) <br>=&amp; \prod_{n=1}^N p(\boldsymbol{y}_n|\boldsymbol{x}_n, \boldsymbol{W}) \\<br>=&amp; \prod_{n=1}^N \mathcal{N}(\boldsymbol{y}_n|f(\boldsymbol{W}, \boldsymbol{x}_n), \lambda_y^{-1}\boldsymbol{I}_D) \tag{5.257} <br>\end{align*} $$</p>



<p>ここで、\( \boldsymbol{W} \) は モデルのパラメータの集合、\( \lambda_y^{-1} \) 固定の精度パラメータです。また、非線形関数\(f\)は次のように定義することにします。</p>



<p>$$ \begin{align*}<br>f(\boldsymbol{W}, \boldsymbol{x}_n) = {\boldsymbol{W}^{(2)}}^\mathrm{T} \text{Tanh}({\boldsymbol{W}^{(1)}}^\mathrm{T}\boldsymbol{x}_n) \tag{5.258}<br>\end{align*} $$</p>



<p>ここでモデルパラメータ \( \boldsymbol{W} \) の要素を\( \boldsymbol{W}^{(1)} \) と \( \boldsymbol{W}^{(2)} \)として、\( \boldsymbol{W}^{(1)} \in \mathbb{R}^{M \times K}\) 及び\( \boldsymbol{W}^{(2)} \in \mathbb{R}^{K \times D}\) という行列としています。このモデルパラメータの各要素は次のようなシンプルなガウス事前分布を仮定することにします。</p>



<p>$$ \begin{align*}<br>p (w_{m,k}^{(1)}) = \mathcal{N} (w_{m,k}^{(1)}|0, \lambda_w^{-1}) \tag{5.259} \\<br>p (w_{m,k}^{(2)}) = \mathcal{N} (w_{m,k}^{(2)}|0, \lambda_w^{-1}) \tag{5.259}<br>\end{align*} $$</p>



<p>また \( \text{Tanh}(\cdot) \) は以下のように定義されます。</p>



<p>$$<br>\text{Tanh}(a) = \frac{\text{exp}(a) &#8211; \text{exp}(-a)}{\text{exp}(a) + \text{exp}(-a)} \tag{5.261}<br>$$</p>



<p>以上が本に書かれたモデルの定義になります。</p>



<h2 class="wp-block-heading">変分推論</h2>



<p>ここではニューラルネットワークモデルのパラメータ \( \boldsymbol{W} = \{ \boldsymbol{W}^{(1)}, \boldsymbol{W}^{(2)} \} \) の事後分布を推論する問題を考えます。こちらは本で一つ前に紹介されていたロジスティック回帰のときとほぼ同じように行っていきます。</p>



<p>今回は細かい式の説明は省いていますが、ニューラルネットワークの勾配の計算は5.6ロジスティック回帰と似たような式変形になります。5.6ロジスティック回帰の式の導出はこちらの記事に詳しく書いてあるので参考にしてみてください。</p>



<figure class="wp-block-embed is-type-wp-embed is-provider-まったり勉強ノート wp-block-embed-まったり勉強ノート"><div class="wp-block-embed__wrapper">
<blockquote class="wp-embedded-content" data-secret="jbuQ8qITUF"><a href="https://www.mattari-benkyo-note.com/2022/02/24/bayes-book-5-6/">[勉強ノート]「ベイズ推論による機械学習入門」5.6ロジスティック回帰</a></blockquote><iframe loading="lazy" class="wp-embedded-content" sandbox="allow-scripts" security="restricted"  title="&#8220;[勉強ノート]「ベイズ推論による機械学習入門」5.6ロジスティック回帰&#8221; &#8212; まったり勉強ノート" src="https://www.mattari-benkyo-note.com/2022/02/24/bayes-book-5-6/embed/#?secret=qxmIOB0tZv#?secret=jbuQ8qITUF" data-secret="jbuQ8qITUF" width="600" height="338" frameborder="0" marginwidth="0" marginheight="0" scrolling="no"></iframe>
</div></figure>



<p>事後分布を推定するために、対角ガウス分布を使った以下のような近似事後分布を利用します。</p>



<p>$$ \begin{align*}<br>q(\boldsymbol{W}^{(1)}; \boldsymbol{\eta}^{(1)}) =&amp; \prod_{m=1}^M \prod_{k=1}^K \mathcal{N}(w_{m,k}|\mu_{m,k}^{(1)},{\sigma_{m,k}^{(1)}}^2) \tag{5.262.1} \\<br>q(\boldsymbol{W}^{(2)}; \boldsymbol{\eta}^{(2)}) =&amp; \prod_{m=1}^M \prod_{k=1}^K \mathcal{N}(w_{m,k}|\mu_{m,k}^{(2)},{\sigma_{m,k}^{(2)}}^2) \tag{5.262.2} \\<br>\end{align*} $$</p>



<p>ここで \( \boldsymbol{\eta} = \{ \boldsymbol{\eta}^{(1)}, \boldsymbol{\eta}^{(2)} \} \) が変分パラメータの集合となります。今回は以下のような近似事後分布と真の事後分布のKLダイバージェンスを最小化するような変分パラメータ&nbsp;\( \boldsymbol{\eta} \)を見つけることを目指します。</p>



<p>$$ \begin{align*}<br>&amp;\text{KL} [q(\boldsymbol{W};\boldsymbol{\eta})||p(\boldsymbol{W}|\boldsymbol{Y},\boldsymbol{X})] \\<br>&amp; \ = \langle \ln q(\boldsymbol{W};\boldsymbol{\eta}) \rangle _{q(\boldsymbol{W};\boldsymbol{\eta})} <br>&#8211; \langle \ln p(\boldsymbol{W}) \rangle _{q(\boldsymbol{W};\boldsymbol{\eta})} \\<br>&amp; \qquad &#8211; \sum_{n=1}^N \langle \ln p(\boldsymbol{y}_n | \boldsymbol{x}_n, \boldsymbol{W}) \rangle _{q(\boldsymbol{W};\boldsymbol{\eta})} <br>+ \text{const} \tag{5.236}<br>\end{align*} $$</p>



<p>この式 (5.236)の最小化するにあたり、以下のような再パラメータ化トリックを利用して、 \( \boldsymbol{W} \) の各要素 \( w \) (インデックスは省略してます。) を以下のように置きます。</p>



<p>$$ \begin{align*}<br>\tilde{w} = \mu + \sigma \tilde{\epsilon} \tag{5.237} \\<br>\text{ただし} \tilde{\epsilon} \sim \mathcal{N} (\epsilon|0,1) \tag{5.238}<br>\end{align*} $$</p>



<p>これを利用すると以下のような \( g(\boldsymbol{\tilde{W}}, \boldsymbol{\eta})  \) を最小化することになります。</p>



<p><br>$$ \begin{align*}<br>&amp; \text{KL} [q(\boldsymbol{W};\boldsymbol{\eta})||p(\boldsymbol{W}|\boldsymbol{Y},\boldsymbol{X})] \\<br>&amp; \ \approx \ln q(\boldsymbol{\tilde{W}};\boldsymbol{\eta})<br>&#8211; \ln p(\boldsymbol{\tilde{W}}) \\<br>&amp; \qquad &#8211; \sum_{n=1}^N \ln p(\boldsymbol{y}_n | \boldsymbol{x}_n, \boldsymbol{\tilde{W}}) + \text{const} \\<br>&amp; \ = g(\boldsymbol{\tilde{W}}, \boldsymbol{\eta}) \tag{5.239}<br>\end{align*} $$</p>



<p>ただし、本ではすべてのデータで尤度の勾配を計算する方法ではなく、確率的勾配降下法（stochastic gradient descent, SGD）にも触れられているので、この記事ではSGDを使って最適化します。ただ、後ほどまた説明しますが、本で書かれているデータを1つ1つ逐次的に与えて勾配を計算するオンラインのSGDではうまく収束してくれなかったので、ミニバッチを用いるSGDを使います。これは基本的に本で書かれているように式(5.239) の事前分布と近似事後分布の項の影響をデータ数に応じて抑えます。今回はミニバッチを利用するので、ミニバッチ内の訓練データを\( \{\boldsymbol{X}_B, \boldsymbol{Y}_B\}\)とし、 \( b \in B \) 番目のデータの入力値を \( \boldsymbol{x}_b\) 、出力値を \( \boldsymbol{y}_b \)として式 (5.239) を変形した以下の式の勾配を利用します。</p>



<p>$$ \begin{align*}<br>&amp; \text{KL} [q(\boldsymbol{W};\boldsymbol{\eta})||p(\boldsymbol{W}|\boldsymbol{Y}_B,\boldsymbol{X}_B)] \\<br>&amp; \ \approx \frac{B}{N} \lbrace \ln q(\boldsymbol{\tilde{W}};\boldsymbol{\eta})<br>&#8211; \ln p(\boldsymbol{\tilde{W}}) \rbrace \\<br>&amp; \qquad &#8211; \sum_{b=1}^B \ln p(\boldsymbol{y}_b | \boldsymbol{x}_b, \boldsymbol{\tilde{W}}) + \text{const} \\<br>\tag{5.269.1}<br>\end{align*} $$</p>



<p>本ではこれ以外に勾配を計算するための式変形が細かく書いてありますが、今回はPyTorchの自動微分の機能を使うため、式 (5.269.1)  の値を計算し、この値をlossとして<code>backward()</code> を実行するため、説明はここまでにします。<br></p>



<h1 class="wp-block-heading">実装について</h1>



<p>今回はPyTorchの自動微分の機能を使って変分パラメータの最適化に必要な勾配を計算します。このため、処理の基本的な流れとしては式 (5.269.1) の値を計算し、この値をlossとして<code>backward()</code> を実行するというのを指定した回数繰り返して最適化します。</p>



<h2 class="wp-block-heading">モデル部分の実装</h2>



<p>式 (5.258) をPyTorchで実装します。具体的なものは最初に示したJupyter Notebookの<code> BayeNNModel </code>クラスの実装をご覧ください。ここでは重要な部分だけ示します。まず、式 (5.258)を <code>forward()</code> に実装します。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="BayeNNModel_forward" data-lang="Python"><code>    def forward(self, X):
      W1 = self.sample_W(self.mu1, self.rho1)
      h1 = torch.nn.functional.linear(X, W1, bias=None)
      h2 = torch.tanh(h1)
      W2 = self.sample_W(self.mu2, self.rho2)
      Y_est = torch.nn.functional.linear(h2, W2, bias=None)
      return Y_est, W1, W2</code></pre></div>



<p>今回は \( \boldsymbol{\tilde{W}} \) はサンプリングしてくる必要があるので、<code>sample_W() </code>という関数でサンプリングしてそれを<code>torch.nn.functional.linear()</code>に入れるということをしています。</p>



<p>\( \boldsymbol{\tilde{W}} \) はサンプリングしてくる部分は以下のようにします。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="BayeNNModel_sample_W" data-lang="Python"><code>    def sample_W(self, mu, rho):
      eps = torch.randn(mu.shape)
      sigma = self.rho2sigma(rho) + self.approximate_posterior_sigma_eps
      W = mu + sigma * eps
      return W</code></pre></div>



<p>基本的には式(5.237) の実装になります。ただ、後ほど説明する \( \ln q(\boldsymbol{\tilde{W}};\boldsymbol{\eta}) \) の計算のところで、ガウス分布に入れる \( \sigma \) が0になってしまいエラーになるケースが発生してしまうため、0にならないようにするための補正 (<code>self.approximate_posterior_sigma_eps</code>) を加算しています。</p>



<h2 class="wp-block-heading">変分推論部分の実装</h2>



<p>勾配を計算するための式(5.269.1) の各項を計算します。これらの項はすべてガウス分布になっているのでPyTorchの <code>torch.distributions.normal.Normal()</code> を使えば簡単に実装できます。</p>



<p>\( \ln q(\boldsymbol{\tilde{W}};\boldsymbol{\eta}) \) と\( \ln p(\boldsymbol{\tilde{W}}) \) 、\( \sum_{b=1}^B \ln p(\boldsymbol{y}_b | \boldsymbol{x}_b, \boldsymbol{\tilde{W}}) \)のそれぞれの項は以下の関数で計算するようにしています。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="BayeNNModel_log_prob" data-lang="Python"><code>    def _compute_approximate_posterior_log_prob_core(self, W, mu, rho):
      sigma = self.rho2sigma(rho) + self.approximate_posterior_sigma_eps
      W_dist = torch.distributions.normal.Normal(mu, sigma)
      log_prob_W = W_dist.log_prob(W)
      sum_log_prob_W = torch.sum(log_prob_W)
      return sum_log_prob_W

    def _compute_prior_log_prob_core(self, W, sigma_w):
      W_dist = torch.distributions.normal.Normal(0, sigma_w)
      log_prob_W = W_dist.log_prob(W)
      sum_log_prob_W = torch.sum(log_prob_W)
      return sum_log_prob_W

    def compute_approximate_posterior_log_prob(self, W1, W2):
      log_prob_W1 = self._compute_approximate_posterior_log_prob_core(W1, self.mu1, self.rho1)
      log_prob_W2 = self._compute_approximate_posterior_log_prob_core(W2, self.mu2, self.rho2)
      return log_prob_W1 + log_prob_W2

    def compute_prior_log_prob(self, W1, W2):
      log_prob_W1 = self._compute_prior_log_prob_core(W1, self.sigma_w)
      log_prob_W2 = self._compute_prior_log_prob_core(W2, self.sigma_w)
      return log_prob_W1 + log_prob_W2

    def compute_log_prob_p(self, Y, Y_est):
      Y_dist = torch.distributions.normal.Normal(Y_est, self.sigma_y)
      log_p = Y_dist.log_prob(Y)
      return torch.sum(log_p)</code></pre></div>



<p>これらを使ってKLダイバージェンスの勾配に関係する部分だけ計算して、<code>backward()</code>とOptimizerの<code>step()</code>を以下のように呼びます。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="vi_step" data-lang="Python"><code>def vi_step(X, Y, model, optimizer, N, max_grad_norm=1e2):
  model.zero_grad()

  batch_size = X.shape[0]
  Y_est, W1, W2 = model(X)
  prior_log_prob_W = model.compute_prior_log_prob(W1, W2)
  posterior_log_prob_W = model.compute_approximate_posterior_log_prob(W1, W2)
  log_prob_p = model.compute_log_prob_p(Y, Y_est)
  kl_divergence = batch_size/N * (posterior_log_prob_W - prior_log_prob_W) - log_prob_p

  kl_divergence.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
  optimizer.step()
  return kl_divergence.item()</code></pre></div>



<p>これらを使ってJuliaのサンプルと同様の訓練データで事後分布の推定を行って本の図5.18と似たような図を以下のように作ってみました。</p>



<div class="wp-block-image"><figure class="aligncenter size-full is-resized"><img loading="lazy" decoding="async" src="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/02/BayeNNModel_plot.png" alt="学習後のニューラルネットワークによる予測" class="wp-image-899" width="380" height="248" srcset="https://www.mattari-benkyo-note.com/wp-content/uploads/2022/02/BayeNNModel_plot.png 380w, https://www.mattari-benkyo-note.com/wp-content/uploads/2022/02/BayeNNModel_plot-300x196.png 300w" sizes="auto, (max-width: 380px) 100vw, 380px" /><figcaption>学習後のニューラルネットワークによる予測</figcaption></figure></div>



<p>図を見ると十分な精度がでていると思っています。</p>



<h1 class="wp-block-heading">実装で苦労した点について</h1>



<p>いざ実装してみると以下の点で工夫が必要だったので簡単に紹介します。</p>



<h2 class="wp-block-heading">最初の数イテレーションの勾配が大きすぎてモデルのパラメータがnanになる</h2>



<p>最初の数イテレーションはどうしても勾配が大きくなりがちです。今回は特にバッチサイズが大きいとモデルのパラメータが途中でnanになってしまうという問題が発生しました。この手の問題は深層学習ではよくあるためいくつか対処する手段はありますが、今回はシンプルなgradient clippingを利用しています。具体的には<code> vi_step()</code> で<code>torch.nn.utils.clip_grad_norm_()</code>を呼んでいる部分がそれにあたります。</p>



<h2 class="wp-block-heading">近似事後分布のσが0になる</h2>



<p>バッチサイズや他のパラメータによっても発生したりしなかったりしますが、時々\( \ln q(\boldsymbol{\tilde{W}};\boldsymbol{\eta}) \)を計算する部分で \( \sigma \) が計算誤差で0に丸められてしまうとうケースが発生しました。このため、以下のようにして小さい値を加算して0になるのを防ぐ必要がありました。</p>



<div class="hcb_wrap"><pre class="prism line-numbers lang-python" data-file="W_sampling" data-lang="Python"><code>      sigma = self.rho2sigma(rho) + self.approximate_posterior_sigma_eps
      W_dist = torch.distributions.normal.Normal(mu, sigma)</code></pre></div>



<h1 class="wp-block-heading">終わりに</h1>



<p>「ベイズ推論による機械学習入門」の5.7 ニューラルネットワークのモデルをPyTorchで実装したので、実装についてと苦労した点についての記事を書きました。</p>



<p>実装する前はすぐできるだろうと思っていましたが、苦労した点に書いたような問題が出てきて思ったより時間がかかった印象です。ただ、実際に実装してみてベイズ推論の理解が深まったのでやってよかったです。実はこの本の他の実装もいつくかしてあるので機会があればまたblogの記事にしようと思います。</p><p>The post <a href="https://www.mattari-benkyo-note.com/2022/02/28/bayes-book-5-7-pytorch/">[勉強ノート]「ベイズ推論による機械学習入門」5.7 ニューラルネットワークで紹介されたものをPyTorchで実装してみた</a> first appeared on <a href="https://www.mattari-benkyo-note.com">まったり勉強ノート</a>.</p>]]></content:encoded>
					
					<wfw:commentRss>https://www.mattari-benkyo-note.com/2022/02/28/bayes-book-5-7-pytorch/feed/</wfw:commentRss>
			<slash:comments>0</slash:comments>
		
		
		<post-id xmlns="com-wordpress:feed-additions:1">849</post-id>	</item>
	</channel>
</rss>
