PyTorch Geometricを使ってVariational Graph Auto-Encodersを作って学習してみる

はじめに

最近読んだ論文にVariational Graph Auto-Encoders (VGAE) を使ったモデルがあったので、自分でもやってみようと思い、作ってみました。本日はそのまとめになります。

本日紹介する使うコードは以下のものです。

https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb

また、このコード自体、以下のPyTorch Geometricのexampleのコードとほぼ同じです。

https://github.com/pyg-team/pytorch_geometric/blob/ee509ad65aefa679047356bb00bc498f35ce7e20/examples/autoencoder.py

このblog記事ではVGAEで必要な機能がPyTorch Geometricでどう実装されているのかわからなかった部分がいくつかあるのでその部分を解説していく記事になります。

PyTorch Geometricとは

PyTorch GeometricはPyTorchを使って構築されたGraph Neural Network向けのライブラリになります。

GitHubのURLは以下の通りです。

https://github.com/pyg-team/pytorch_geometric

最新のPyTorchやCUDAにもちゃんと対応しており、Graph Neural Networkで必要な基本的な機能はそろっている印象です。

Variational Graph Auto-Encoders (VGAE)とは

VGAEはVariational Auto-Encoder (VAE) というモデルをGraphデータ向けに拡張したモデルです。VAEの説明を始めるとそれだけですごく長くなりますので、今回はVGAEを実装するうえで必要なところだけ紹介します。

VAEは以下のようにEncoderとDecoderという二つのモデルを組み合わせたモデルになります。

VAEの概要図

このうち、EncoderとDecoderは以下のようなモデルになります。

  • Encoder: 入力Xを受け取って潜在変数Zの分布のパラメータを出力する
  • Decoder: 潜在変数Zを受け取って入力Xを再構成する

VAEで重要なのがEncoderの部分と潜在変数Zのサンプリングの部分です。この潜在変数Zの分布が標準正規分布という仮定のもと学習させながら、Encoderで潜在変数Zの分布のパラメータを出力し、その分布のパラメータを使って潜在変数ZをサンプリングしてDecoderに渡すということを行います。

このVAEをGraph データに拡張するためにVGAEはEncoderとDecoderを以下のようなモデルにしています。

  • Encoder: ノードの特徴ベクトルXと隣接行列Aを入力として受け取り、潜在変数Zの分布のパラメータを出力する
  • Decoder: 潜在変数Zを受け取り隣接行列Aを再構築する

図にすると以下のようなイメージです。

VGAEの概要図

VGAEとVAEとの違いはEncoderでグラフの情報であるノード情報と隣接行列を受け取れるようにしたことと、Decoderが出力するものが隣接行列になることです。

VGAEをPyTorch Geometricを使って実装する

VGAEの概略を説明したので次は実際に実装を紹介していきます。まずはEncoderであるVariationalGCNEncoderから見ていきます。EncoderではPyTorch Geometricに実装されている GCNConv を使って実装します。

from torch_geometric.nn import GCNConv

class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

GCNConv はノードのインプットのチャンネル数、アウトプットのチャネル数を引数にとってインスタンスを作ります。そしてforwardではノードのtensor x と隣接行列のかわりにどのノード同士がつながっているか?を示すedge_indexを渡します。GCNConv の中身についてはドキュメントに詳しく書かれているのでそちらをご覧ください。

https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv

このEncoderではVGAEの概要でも説明した通り、潜在変数の分布のパラメータを返します。ここではガウス分布の平均を表すmuと標準偏差にlogを適用したlogstdを返しています。

モデルの実装としてはあとはPyTorch Geometricで実装されているVGAEというクラスに渡せば終わりになります。

from torch_geometric.nn import VGAE

model = VGAE(VariationalGCNEncoder(in_channels, out_channels))

ただ、これだとさすがに初見だと何が何だかわからなかったので、少し説明します。

まず、Decoderについてです。DecoderはVGAE のデフォルトではInnerProductDecoderというものが使われます。これはVGAEの元論文でも使われていたDecoderの実装で、エッジの両端のノードに対応する潜在変数の各要素の積を取って総和を取り、sigmoidを適用して0-1の値にして出力します。出力値が0-1の値になっているのでDecoderの出力値は計算に使った二つのノードの間にエッジがある確率とみることができます。

詳しくは以下のドキュメントをご覧ください。

https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.InnerProductDecoder

また、ロス関数についてですが、VGAE の中にVGAEで必要な以下の二つが実装されています。

これを以下のように学習ループで利用して学習をおこないます。

for epoch in range(0, 400):
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    recon_loss = model.recon_loss(z, train_data.pos_edge_label_index)
    kl_loss = (1 / train_data.num_nodes) * model.kl_loss()
    loss = recon_loss + kl_loss
    loss.backward()
    optimizer.step()

最後に上のコードではノード間にエッジがあるところの情報はtrain_data.pos_edge_label_indexで渡しているのですが、ノード間にエッジがないという情報はどこで渡しているか?ということについて説明します。

コードを読むと実はrecon_lossの中で自動的にエッジがないという情報を生成してそれを込みでロスが計算されています。具体的には以下の部分です。

https://github.com/pyg-team/pytorch_geometric/blob/d2b2e662488eae07d153de6d4b8c56c24bf413d9/torch_geometric/nn/models/autoencoder.py#L101

ここで引数でneg_edge_indexNoneのときは自動でエッジが存在しないノードのペアをサンプリングするという処理になっています。

以下です。その他の部分で気になるところがある場合は全体のコードを以下のところに置いてありますのでご覧ください。

https://github.com/shu65/pytorch_geometric_examples/blob/main/PyTorch_Geometric_Variational_Graph_AutoEncoder.ipynb

終わりに

今回はPyTorch Geometricの練習として、VGAEを実装してみたのでまとめの記事を書きました。PyTorch Geometricを今回初めて使ったのですが、Graph Neural Networkに必要な基本的な機能はそろっていそうなので、今後もGraph Neural Networkを使う機会があれば使ってみようと思います。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です