はじめに
日本語の記事でKeras + CycleGAN(pix2pix) + colab環境 + 自前データセットでの実装例があまりに少なかったので少しでも参考になればと思い、書きました。
Kerasでの実装について
まず、今回GANで一般的に使用されているPytorchではなくKerasで実装しようと思った理由として可読性にあります。一般に機械学習moduleとしてpytorch,tensorflow,Kerasが挙げられると思います。論文や最新の研究においてPytorchが使われている理由としてはそのカスタイマイズ性や計算速度、デバッグのしやすさにあると考えていますが[1]、自分みたいな初心者には0からの実装はハードルが高かったのでKerasを用いて実装いたしました。
また、自由なモデルが使えるように自前のデータセットを用いて学習できるようにカスタマイズしました。さらに自分にはスペックの高いPCを持っていなかったのでcolabで実装しました。また、参考文献として[2][3]を参考に一部コードを参考にしました。
CycleGANについての基本的な理解は参考文献やGANの数学的な話を参考にしてください。
データセットはfacadesデータセットをDLしてしようしており、左半分が生成画像、右半分が元画像となっています。
ネットワークの構成とディレクト&データセットについて
識別機(discriminator)について
Model: "model_44"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_53 (InputLayer) [(None, 256, 256, 3)] 0
conv2d_234 (Conv2D) (None, 128, 128, 64) 3136
leaky_re_lu_144 (LeakyReLU) (None, 128, 128, 64) 0
conv2d_235 (Conv2D) (None, 64, 64, 128) 131200
leaky_re_lu_145 (LeakyReLU) (None, 64, 64, 128) 0
instance_normalization_180 (None, 64, 64, 128) 2
(InstanceNormalization)
conv2d_236 (Conv2D) (None, 32, 32, 256) 524544
leaky_re_lu_146 (LeakyReLU) (None, 32, 32, 256) 0
instance_normalization_181 (None, 32, 32, 256) 2
(InstanceNormalization)
conv2d_237 (Conv2D) (None, 16, 16, 512) 2097664
leaky_re_lu_147 (LeakyReLU) (None, 16, 16, 512) 0
instance_normalization_182 (None, 16, 16, 512) 2
(InstanceNormalization)
conv2d_238 (Conv2D) (None, 16, 16, 1) 8193
=================================================================
Total params: 2,764,743
Trainable params: 2,764,743
Non-trainable params: 0
_________________________________________________________________
生成器(Generator)について
Model: "model_45"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_54 (InputLayer) [(None, 256, 256, 3 0 []
)]
conv2d_239 (Conv2D) (None, 128, 128, 32 1568 ['input_54[0][0]']
)
leaky_re_lu_148 (LeakyReLU) (None, 128, 128, 32 0 ['conv2d_239[0][0]']
)
instance_normalization_183 (In (None, 128, 128, 32 2 ['leaky_re_lu_148[0][0]']
stanceNormalization) )
conv2d_240 (Conv2D) (None, 64, 64, 64) 32832 ['instance_normalization_183[0][0
]']
leaky_re_lu_149 (LeakyReLU) (None, 64, 64, 64) 0 ['conv2d_240[0][0]']
instance_normalization_184 (In (None, 64, 64, 64) 2 ['leaky_re_lu_149[0][0]']
stanceNormalization)
conv2d_241 (Conv2D) (None, 32, 32, 128) 131200 ['instance_normalization_184[0][0
]']
leaky_re_lu_150 (LeakyReLU) (None, 32, 32, 128) 0 ['conv2d_241[0][0]']
instance_normalization_185 (In (None, 32, 32, 128) 2 ['leaky_re_lu_150[0][0]']
stanceNormalization)
conv2d_242 (Conv2D) (None, 16, 16, 256) 524544 ['instance_normalization_185[0][0
]']
leaky_re_lu_151 (LeakyReLU) (None, 16, 16, 256) 0 ['conv2d_242[0][0]']
instance_normalization_186 (In (None, 16, 16, 256) 2 ['leaky_re_lu_151[0][0]']
stanceNormalization)
up_sampling2d_72 (UpSampling2D (None, 32, 32, 256) 0 ['instance_normalization_186[0][0
) ]']
conv2d_243 (Conv2D) (None, 32, 32, 128) 524416 ['up_sampling2d_72[0][0]']
instance_normalization_187 (In (None, 32, 32, 128) 2 ['conv2d_243[0][0]']
stanceNormalization)
concatenate_54 (Concatenate) (None, 32, 32, 256) 0 ['instance_normalization_187[0][0
]',
'instance_normalization_185[0][0
]']
up_sampling2d_73 (UpSampling2D (None, 64, 64, 256) 0 ['concatenate_54[0][0]']
)
conv2d_244 (Conv2D) (None, 64, 64, 64) 262208 ['up_sampling2d_73[0][0]']
instance_normalization_188 (In (None, 64, 64, 64) 2 ['conv2d_244[0][0]']
stanceNormalization)
concatenate_55 (Concatenate) (None, 64, 64, 128) 0 ['instance_normalization_188[0][0
]',
'instance_normalization_184[0][0
]']
up_sampling2d_74 (UpSampling2D (None, 128, 128, 12 0 ['concatenate_55[0][0]']
) 8)
conv2d_245 (Conv2D) (None, 128, 128, 32 65568 ['up_sampling2d_74[0][0]']
)
instance_normalization_189 (In (None, 128, 128, 32 2 ['conv2d_245[0][0]']
stanceNormalization) )
concatenate_56 (Concatenate) (None, 128, 128, 64 0 ['instance_normalization_189[0][0
) ]',
'instance_normalization_183[0][0
]']
up_sampling2d_75 (UpSampling2D (None, 256, 256, 64 0 ['concatenate_56[0][0]']
) )
conv2d_246 (Conv2D) (None, 256, 256, 3) 3075 ['up_sampling2d_75[0][0]']
==================================================================================================
Total params: 1,545,425
Trainable params: 1,545,425
Non-trainable params: 0
__________________________________________________________________________________________________
また、ファイルについては以下のようになっています。
├── datasets
│ └── facades
│ ├── test
│ ├── train
│ └── val
└── Keras_cycleGAN.ipynb
また、画像は
このように右側の元画像から左の画像を生成するような場合を考えています。
コード
参考文献
[1]TensorflowとKeras、PyTorchの比較
[3]/Keras-GAN
コメント