Kerasで自前データセットを使ってCycleGANを実装by colab

Python

はじめに

日本語の記事でKeras + CycleGAN(pix2pix) + colab環境 + 自前データセットでの実装例があまりに少なかったので少しでも参考になればと思い、書きました。

 

Kerasでの実装について

まず、今回GANで一般的に使用されているPytorchではなくKerasで実装しようと思った理由として可読性にあります。一般に機械学習moduleとしてpytorch,tensorflow,Kerasが挙げられると思います。論文や最新の研究においてPytorchが使われている理由としてはそのカスタイマイズ性や計算速度、デバッグのしやすさにあると考えていますが[1]、自分みたいな初心者には0からの実装はハードルが高かったのでKerasを用いて実装いたしました。
また、自由なモデルが使えるように自前のデータセットを用いて学習できるようにカスタマイズしました。さらに自分にはスペックの高いPCを持っていなかったのでcolabで実装しました。また、参考文献として[2][3]を参考に一部コードを参考にしました。

CycleGANについての基本的な理解は参考文献やGANの数学的な話を参考にしてください。

 

データセットはfacadesデータセットをDLしてしようしており、左半分が生成画像、右半分が元画像となっています。

 

ネットワークの構成とディレクト&データセットについて

識別機(discriminator)について

Model: "model_44"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_53 (InputLayer)       [(None, 256, 256, 3)]     0         
                                                                 
 conv2d_234 (Conv2D)         (None, 128, 128, 64)      3136      
                                                                 
 leaky_re_lu_144 (LeakyReLU)  (None, 128, 128, 64)     0         
                                                                 
 conv2d_235 (Conv2D)         (None, 64, 64, 128)       131200    
                                                                 
 leaky_re_lu_145 (LeakyReLU)  (None, 64, 64, 128)      0         
                                                                 
 instance_normalization_180   (None, 64, 64, 128)      2         
 (InstanceNormalization)                                         
                                                                 
 conv2d_236 (Conv2D)         (None, 32, 32, 256)       524544    
                                                                 
 leaky_re_lu_146 (LeakyReLU)  (None, 32, 32, 256)      0         
                                                                 
 instance_normalization_181   (None, 32, 32, 256)      2         
 (InstanceNormalization)                                         
                                                                 
 conv2d_237 (Conv2D)         (None, 16, 16, 512)       2097664   
                                                                 
 leaky_re_lu_147 (LeakyReLU)  (None, 16, 16, 512)      0         
                                                                 
 instance_normalization_182   (None, 16, 16, 512)      2         
 (InstanceNormalization)                                         
                                                                 
 conv2d_238 (Conv2D)         (None, 16, 16, 1)         8193      
                                                                 
=================================================================
Total params: 2,764,743
Trainable params: 2,764,743
Non-trainable params: 0
_________________________________________________________________

生成器(Generator)について

Model: "model_45"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_54 (InputLayer)          [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d_239 (Conv2D)            (None, 128, 128, 32  1568        ['input_54[0][0]']               
                                )                                                                 
                                                                                                  
 leaky_re_lu_148 (LeakyReLU)    (None, 128, 128, 32  0           ['conv2d_239[0][0]']             
                                )                                                                 
                                                                                                  
 instance_normalization_183 (In  (None, 128, 128, 32  2          ['leaky_re_lu_148[0][0]']        
 stanceNormalization)           )                                                                 
                                                                                                  
 conv2d_240 (Conv2D)            (None, 64, 64, 64)   32832       ['instance_normalization_183[0][0
                                                                 ]']                              
                                                                                                  
 leaky_re_lu_149 (LeakyReLU)    (None, 64, 64, 64)   0           ['conv2d_240[0][0]']             
                                                                                                  
 instance_normalization_184 (In  (None, 64, 64, 64)  2           ['leaky_re_lu_149[0][0]']        
 stanceNormalization)                                                                             
                                                                                                  
 conv2d_241 (Conv2D)            (None, 32, 32, 128)  131200      ['instance_normalization_184[0][0
                                                                 ]']                              
                                                                                                  
 leaky_re_lu_150 (LeakyReLU)    (None, 32, 32, 128)  0           ['conv2d_241[0][0]']             
                                                                                                  
 instance_normalization_185 (In  (None, 32, 32, 128)  2          ['leaky_re_lu_150[0][0]']        
 stanceNormalization)                                                                             
                                                                                                  
 conv2d_242 (Conv2D)            (None, 16, 16, 256)  524544      ['instance_normalization_185[0][0
                                                                 ]']                              
                                                                                                  
 leaky_re_lu_151 (LeakyReLU)    (None, 16, 16, 256)  0           ['conv2d_242[0][0]']             
                                                                                                  
 instance_normalization_186 (In  (None, 16, 16, 256)  2          ['leaky_re_lu_151[0][0]']        
 stanceNormalization)                                                                             
                                                                                                  
 up_sampling2d_72 (UpSampling2D  (None, 32, 32, 256)  0          ['instance_normalization_186[0][0
 )                                                               ]']                              
                                                                                                  
 conv2d_243 (Conv2D)            (None, 32, 32, 128)  524416      ['up_sampling2d_72[0][0]']       
                                                                                                  
 instance_normalization_187 (In  (None, 32, 32, 128)  2          ['conv2d_243[0][0]']             
 stanceNormalization)                                                                             
                                                                                                  
 concatenate_54 (Concatenate)   (None, 32, 32, 256)  0           ['instance_normalization_187[0][0
                                                                 ]',                              
                                                                  'instance_normalization_185[0][0
                                                                 ]']                              
                                                                                                  
 up_sampling2d_73 (UpSampling2D  (None, 64, 64, 256)  0          ['concatenate_54[0][0]']         
 )                                                                                                
                                                                                                  
 conv2d_244 (Conv2D)            (None, 64, 64, 64)   262208      ['up_sampling2d_73[0][0]']       
                                                                                                  
 instance_normalization_188 (In  (None, 64, 64, 64)  2           ['conv2d_244[0][0]']             
 stanceNormalization)                                                                             
                                                                                                  
 concatenate_55 (Concatenate)   (None, 64, 64, 128)  0           ['instance_normalization_188[0][0
                                                                 ]',                              
                                                                  'instance_normalization_184[0][0
                                                                 ]']                              
                                                                                                  
 up_sampling2d_74 (UpSampling2D  (None, 128, 128, 12  0          ['concatenate_55[0][0]']         
 )                              8)                                                                
                                                                                                  
 conv2d_245 (Conv2D)            (None, 128, 128, 32  65568       ['up_sampling2d_74[0][0]']       
                                )                                                                 
                                                                                                  
 instance_normalization_189 (In  (None, 128, 128, 32  2          ['conv2d_245[0][0]']             
 stanceNormalization)           )                                                                 
                                                                                                  
 concatenate_56 (Concatenate)   (None, 128, 128, 64  0           ['instance_normalization_189[0][0
                                )                                ]',                              
                                                                  'instance_normalization_183[0][0
                                                                 ]']                              
                                                                                                  
 up_sampling2d_75 (UpSampling2D  (None, 256, 256, 64  0          ['concatenate_56[0][0]']         
 )                              )                                                                 
                                                                                                  
 conv2d_246 (Conv2D)            (None, 256, 256, 3)  3075        ['up_sampling2d_75[0][0]']       
                                                                                                  
==================================================================================================
Total params: 1,545,425
Trainable params: 1,545,425
Non-trainable params: 0
__________________________________________________________________________________________________

また、ファイルについては以下のようになっています。

├── datasets
│   └── facades
│       ├── test
│       ├── train
│       └── val
└── Keras_cycleGAN.ipynb

また、画像は

このように右側の元画像から左の画像を生成するような場合を考えています。

コード

参考文献

[1]TensorflowとKeras、PyTorchの比較

[2]実践GAN ~敵対的生成ネットワークによる深層学習~

[3]/Keras-GAN

コメント

タイトルとURLをコピーしました