GANの数学的な話

Python

はじめに

今回東大松尾研の深層生成モデルの講義を受けたり、書籍を読んでGAN(敵対生成ネットワーク)を学んだため、ノート替わりにここにメモしていこうと考えました。

生成モデルとは

そもそも、最近話題になっている画像生成モデルについて定性的にだが考えてみようと思います。生成モデルは広義において次のように定義されています[1]。

生成モデルは、データセットがどのように生成されるかを確率モデルの観点から記述する。このモデルからサンプリングすることで、新しいデータを生成できる。

これを自分なりの理解でいうと以下のようになります。

生成モデルの考え方として対象生成物は確率分布にしたがって、生成されているとしているとしています。

具体的な話をすると上記画像では画像のうち、このピクセルは髪の黒色である可能性が高いな、ここに目が来る可能性が高いな…などなど画像を構成するには確率分布がありそうに見えます。この実際のデータの確率分布を再現するといったことが生成モデルの目的となっています。この再現に向けた学習の仕方は様々な手法が提案されています。

今回は、その学習の中でもGAN(敵対生成ネットワーク)について勉強しました。

GAN

生成モデルでGANの前にVAE(Variational Auto-Encoder)といった生成モデルについて触れてみたいと思います。このVAEではどのように確率分布を学習していたかを説明していきます。

「生成モデルとは」で触れた通り生成モデルでは確率分布の学習を目的としています。ここでデータセットが持つ本来の確率分布を\(p_{data}(x)\)、学習された確率分布を\(p_{\theta}(x)\)とするとこれらが一致するように学習を行います(下イメージ図)。これを最尤学習といいます。

VAEではこの最尤学習に対して”尤もらしさ”の評価としてKullback–Leiblerダイバージェンスといった指標を用いました。KLダイバージェンス\(D_{KL}\)は次のように定義されます。

$$D_{KL}[p_{data}(x)||p_{\theta}(x)]=\int p_{data} \log \frac{p_{data}}{p_{\theta}} dx $$

このKLダイバージェンスは距離の公理である、対称性を満たしていませんが確率分布同士の距離のようなものだと理解するとわかりやすいです。VAEの学習ではこの距離が近づくように学習を行っていきます。しかし、VAEでは上図で示している通り確率分布を何らかの分布(ガウス分布やベルヌーイ分布)などの決まった形状として仮定し、確率分布を学習しています。そのため、訓練データにない部分はぼやけたような画像となってしまいます。

そこで生成モデルとして決まった形状で仮定しないで学習をすることを考えます。しかし、決まった形状を仮定しないと”尤もらしさ”の評価をする尤度計算ができません。そこで生成される確率分布の違いだけをひとまず次のように定義しました。

$$r_{\phi}(x)=\frac{p_{data}(x)}{p_{\theta}(x)} $$

ここで\(r_{\phi}(x)\)を密度比と呼びます。
この密度比の解釈としてデータ集合\({x_1,x_2,…x_N}\)のうち半分がデータセットからの画像(本物:1)、もう半分が生成された画像(偽物:0)としてラベル付け(y)すると以下のような条件付き分布となります[2]。

$$p_{data}(x)=p(x|y=1)\\
p_{\theta}(x)=p(x|y=0)$$

したがって、密度比は以下のように書き直せます。

$$
\begin{eqnarray}
r_{\phi}&=&\frac{p_{data}(x)}{p_{\theta}(x)} \\
&=& \frac{p(x|y=1)}{p(x|y=0)} \\
&=& \frac{\frac{p(y=1|x)p(x)}{p(y=1)}}{\frac{p(y=0|x)p(x)}{p(y=0)}} \\
&=& \frac{p(y=1|x)}{p(y=0|x)}・\frac{1-\pi}{\pi}
\end{eqnarray}
$$

※ベイズの定理
$$
\begin{eqnarray}
p(y|x)&=&\frac{p(y)p(x|y)}{p(x)} \\
&=& p(y)p(x|y)
\end{eqnarray}
$$

ここで\(\pi=p(y=1)\)であり、画像が本物であるかの確率\(p(y=1|x)\)を判定できればよくなることがわかります。この\(p(y=1|x)\)を判定する分布を\(D(x;\phi)\)とおき、これを識別機(discriminator)と呼びます。

つまり、深層ニューラルネットワークが強い分類問題として置き換えることができるようになります。

この分類問題の解き方を考えます。分類問題で目的関数としてよく置かれるものとして負の交差エントロピー損失といったものがあります。これは次のように表せます。

$$\mathcal{L}(\phi;x)=\mathbb{E}_{p(x,y)}[ylogD(x;\phi)+(1-y)log(1-D(x;\phi))]$$

これを式変形すると以下のように変形できます。

$$\mathcal{L}(\phi;x)=\mathbb{E}_{p_{data}(x)}[logD(x;\phi)]+\mathbb{E}_{p_{\theta}(x)}[log(1-D(x;\phi)) ]$$

もし、識別機が適切に学習ができ、推定出来たら

$$D(x;\phi^*)=\frac{r_{\phi^*}(x)}{r_{\phi^*}(x)+1}=\frac{p_{data}(x)}{p_{data}(x)+p_{\theta}(x)}$$

に収束します。これを負の交差エントロピー損失に代入すると

$$\mathcal{L}(\phi;x)=2D_{JS} (p_{data}||p_{\theta})-2log2$$

ここで\(D_{JS}\)Jensen-Shannonダイバージェンスと呼びます。JSダイバージェンスは以下のように定義されています。

$$D_{JS} (p_{data}||p_{\theta})=\frac{1}{2} D_{KL} (p_{data}|| \frac{p_{data}+p_{\theta}}{2})+\frac{1}{2} D_{KL} (p_{\theta}|| \frac{p_{data}+p_{\theta}}{2})$$

式変形によって識別機の目的関数にKLダイバージェンスを含んだJSダイバージェンスが出てくるのは面白いですね。

次に学習した確率分布\(p_{\theta}(x)\)を用いて

$$x \sim p_{\theta}(x|z)\Leftrightarrow x \sim G$$

ここでGを生成器(Generator)と呼びます。ここで出てきたzは任意の入力で、画像生成するためのきっかけのようなものになる。また、Gを学習するための目的関数は

$$\mathcal{L}(\phi^*,\theta;x)=\mathbb{E}_{p_{data}(x)}[logD(x;\phi^*)]+\mathbb{E}_{p(z)}[log(1-D(G(z;\theta);\phi^*)) ]$$

となり、これを最小化するように学習する。ここまでの識別機と生成器の学習をまとめると識別機は

$$max_\phi \mathbb{E}_{p_{data}(x)}[logD(x;\phi)]+\mathbb{E}_{p(z)}[log(1-D(G(z;\theta);\phi)) ]$$

生成器は

$$min_\theta \mathbb{E}_{p(z)}[log(1-D(G(z;\theta);\phi^*)) ]$$

となる。このような枠組みで生成モデルを学習する枠組みをgenarative adversarial netwoks(GAN)と呼びます[3]。

このように識別機と生成器がお互いに敵対しながら学習を進めていくのが特徴になっています。このアルゴリズムは以下のようになっています[3]。

GANの特徴

GANは2014年の登場後、2015年後半に有効な学習手法が確立されてから急速に注目を集めました。書籍として実践GAN ~敵対的生成ネットワークによる深層学習~GANディープラーニング実装ハンドブックなど日本語でも実装や理論的にわかりやすい本が発売されているのでぜひ読んでほしいです。

また、GANは高次元データに強い、背景がきれいに生成されるなどの特徴から成功したモデルと呼ばれています[4]。その一方mode collapseや勾配消失問題などが発生しやすく、学習が不安定になりがちであるといった問題点も指摘されています。このため、今後の研究が期待されているモデルでもあります。

 

参考文献

[1]David Faster,”生成Deep Learning“,OREILLY (2020).

[2]杉山 将,”密度比推定に基づく統計的機械学習 “,https://www.scat.or.jp/cms/wp-content/uploads/2020/06/sugiyama.pdf (2022.10.08閲覧)

[3]Ian J. Goodfellow. et. al.,”Generative Adversarial Networks” arXiv (2014).

[4]岡野原大輔”AI技術の最前線“日経BP (2022)

[5]GAN (Generative Adversarial Networks):敵対的生成ネットワーク

コメント

タイトルとURLをコピーしました