StarGAN v2: Diverse Image Synthesis for Multiple Domains
abstract
generate diverse images (여러개의 target domain으로 매핑 가능)
scalable over multiple domains
이 논문에서 domain은 특정 사람의 얼굴, 성별과 같은 내용이고 source image에 대응하고, style은 헤어/메이크업/수염과 같은 것들을 가리키며 reference image에 대응한다
contribution
domain label -> domain specific style code
mapping network: Gaussian noise z -> style code s, style encoder: reference image x -> style code s
method/architecture: 총 네개의 모듈, 네 개의 loss function
1) generator G(x, s), s는 2) or 3)에서.
s를 넣는 방법은 adaIN, s는 y의 style을 나타내도록 디자인이 됨.
4 downsampling, 4 intermiediate block, 4 upsampling (IN, adaIN), style code는 adaIN으로 들어감. 모두 다 residual block임
2) mapping network s = Fy(z)
: latent code to style, z는 Gaussian noise
FC4개가 있고, 또다른 FC4개 쌓은게 K개만큼 있음 (K=도메인 수)
3) style encoder s = Ey(x)
: image to style, 다양한 레퍼런스 이미지를 이용해서 다양한 스타일 코드를 생성함
CNN stack이고...마지막에 K개의 linear가 있음
4) discriminator
: 구조는 3)과 똑같고,, 대신에 D의 dimension은 1임 (0 또는 1)
loss function
1) adv.
: eq1, 실제 이미지로 만든거랑 target domain의 style code로 만든 fake image를 구분 못하도록 하는게 목적.
2) style reconstruction
: eq2, 이전에는 encoder가 스타일마다 여러개였지만 여기선 인코더가 한개임...output은 multiple domain이지만...
스타일 코드 s에 대한 L1 loss
3) style diversification
: eq3, Generator의 output에 대한 regularizer같은 역할 (L1)이고 이걸 최대로 만들면...스타일 종류가 다양해지게 만들 수 있음
4) cycle consistency loss
: 생성된 이미지가 포즈와 같이 domain-invariant한 특징을 보존한다는 것을 보장하기 위해서,,
total loss = 1+ 2 - 3 + 4
코드 봤더니 K값은 보통 2로 두는거같음 남자/여자인감...
결국 contribution은 latent를 이용할 때 이전처럼 바로 z = N(0, I)에서 이용하는게 아니고.... z를 s라는 다른 subspace로 보내는데에 있는거같음. 이걸 생각하면 styleGAN하고도 좀 비슷하고 그렇다. starGAN자체가 attribute를 그냥 class의 combination으로 가져갔으니까 거기서 latent부분을 심화?시켰다고 보면 되겠군...
*** 비교 starGAN
input: x: real image, c_ori: original label, c_smpl: sampled categorical...
Discriminator D: input 1개(image) output 2개 (real/fake, class) // Generator G: input 2개 (c, x), output 1개 (image)
1a) x -> [ D ] -> 0/1, cls 에서 d_loss_real, d_loss_class 구하고
1b) x -> [ G ] -> [ D ] -> 0/1, _ 에서 d_loss_fake 구하고
1c) d_loss_gp를 구함 (from WGAN-GP)
-> 여기 나오는 loss를 다 더해서 D에 대해 BP
2a) x_fake = G(x, c_smpl)일 때 D(x_fake)로 g_loss_fake, g_loss_cls 구하고
2b) G(x_fake, c)를 이용하여 x_recover 구해서 g_cyclic loss 구하고
-> 여기 나오는 loss를 다 더해서 G에 대해 BP
*** GAN update code들
* vanilla GAN
input: x: real images // z: latent variable sampling from normal distribution
1) z -> [ G ] -> x_fake, g_loss = BCE( D(x_fake), 1 ) # to fool G and update G w/ g_loss
2) d_loss = BCE( D(x_real), 1 ) + BCE( D(x_fake), 0 ) # and update D w/ d_loss
*conditional GAN
input: x: real_images y: real_labels// z: latent sampling from N(0, 1), y_input: generated lables from categorical
1) x_fake = G(z, y_input), g_loss =
* infoGAN
input: x: real images, y: real label // z: latent sampling from N(0, 1), y_smpl: sampling from categorical, c: code sampling from U(-1, 1)
- code라고 부르는 latent c가 추가되고, 이 c와 x_fake 사이의 MI가 최대가 되도록 트레이닝. 여기서 c는 class외의 attribute라고 볼 수 있겠음. c는 uniform, y는 cat
1) z, c, y_smpl -> [ G ] -> x_fake, g_loss = MSE( D(x_fake), 1) # fool and update G
2) d_loss = MSE( D(x_real), 1 ) + MSE( D(x_fake), 0 ) # and update D
3) info_loss = CE( D(x_fake), y ) + MSE( D(x_fake), c ) # and update all parameters G and D
여기서 Discriminator는 마지막에 FC가 3개 달려있는데...각각 0/1, class, code이렇게 연결됨
** multi-scale discriminator/generator (pix2pixhd이고, melGAN에서 나오는 내용인데 mel을 reconstruction 할때는 multi-scale Generator가 좋았다는 얘기가 있음)
** mel-spectrum 자체를 disentangling하는 다양한 방법들과 이미지에서 찾아볼 수 있는 차이점은 무엇???