论文链接:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

cycleGAN的提出

之前提到的GAN实现通用图像间转换 ——Pix2Pix详解
pix2pix实现了通用的Image to Image Translation模型

然而pix2pix的训练需要成对的图像数据集
很多时候我们无法获得这样的数据集,例如如图的转换

unpaired_image

针对这种问题,作者基于pix2pix的结构,提出了cycleGAN

cycleGAN的理论

cycleGAN

图(a)表示cycleGAN的对抗训练部分

cycleGAN含有两个生成器和两个判别器
两个生成器分别学习映射$G:X\rightarrow Y$和$F: Y\rightarrow X$,其中X与Y分别表示两个不成对的图像分布
两个判别器则分别判断生成的两个假图像分布$\hat{X},\hat{Y}$的真实性

cycle consistency loss

首先给出部分论文原文

Adversarial training can, in theory, learn mappings G and F that produce outputs identically distributed as target domains Y and X respectively (strictly speaking, this requires G and F to be stochastic functions).

However, with large enough capacity, a network can map the same set of input images to any random permutation of images in the target domain, where any of the learned mappings can induce an output distribution that matches the target distribution. Thus, adversarial losses alone cannot guarantee that the learned function can map an individual input $x_i$ to a desired output $y_i$.

即如果只使用对抗训练,在网络容量和数据量足够大的情况下
网络可以将同一分布映射到目标分布的不同排列上
此时对于单个输入x,我们并不一定能获得想要的输出y

因此作者引入了cycle consistency loss(循环一致性损失)

上图中(b)(c)即表示cycleGAN进行cycle consistency训练的过程
即$\forall x\in X$,训练目标为$x\rightarrow G(x) \rightarrow F(G(x))\approx x$,反之$\forall y\in Y$同理,即$y\rightarrow F(y) \rightarrow G(F(y))\approx y$

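这部分的损失函数使用了L1损失,如下所示
(论文中作者尝试将L1替换成对抗损失,但没有明显的效果提升)

$$\mathcal{L}_{cyc}(G,F)=\mathbb{E}_{x\sim p_{data}(x)}\left[\|F(G(x))-x\|_1\right]+\mathbb{E}_{y\sim p_{data}(y)}\left[\|G(F(y))-y\|_1\right]$$

cycleGAN的最终损失函数为

$$\mathcal{L}(G,F,D_X,D_Y)=\mathcal{L}_{GAN}(G,D_Y,X,Y)+\mathcal{L}_{GAN}(F,D_X,Y,X)+\lambda\mathcal{L}_{cyc}(G,F)$$

其中$\mathcal{L}_{GAN}$为对抗损失,$D_X$、$D_Y$分别是判断X域和Y域图像真伪的判别器

论文中取$\lambda=10$,且对抗训练采用least-squares loss(最小二乘损失)而不是对数损失

生成器的训练目标为

$$G^*,F^*=\arg\min_{G,F}\max_{D_X,D_Y}\mathcal{L}(G,F,D_X,D_Y)$$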

identity loss

上述损失已经是cycleGAN针对大部分问题的损失了
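但是对于art to photo和photo enhancement问题作者还引入了identity loss

$$\mathcal{L}_{identity}(G,F)=\mathbb{E}_{y\sim p_{data}(y)}\left[\|G(y)-y\|_1\right]+\mathbb{E}_{x\sim p_{data}(x)}\left[\|F(x)-x\|_1\right]$$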

这个损失以$0.5\lambda$的权重加入原损失

加入identity loss的原因是原损失可能出现如图的问题
即虽然达到了风格转换的目的,但也改变了一些不应该改变的输入信息

cycleGAN_identity

Keras代码实现

生成器和判别器结构依照论文附录构建

from keras.models import Model
from keras.layers import Dropout, Conv2D, LeakyReLU, Input, Add, Activation, Conv2DTranspose
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from NormalizationLayer import InstanceNormalization
import os
import cv2
import numpy as np
import random

class DataLoader:
    """Iterate over mini-batches of images read from two folders.

    Every ``next()`` call returns ``(imgA, imgB)``: two arrays of shape
    ``(n, *img_shape)`` whose pixels are mapped with ``(x - 127.5) / 127.5``.
    One epoch is one pass over the files of ``dir_A``; the final batch may
    hold fewer than ``batch_size`` images.  When the epoch is exhausted the
    loader reshuffles, rewinds and raises ``StopIteration``, so the same
    object can be iterated again for the next epoch.

    The two folders are listed independently, so the file names (and the
    number of files) in ``dir_A`` and ``dir_B`` do not have to match.  When
    ``dir_B`` holds fewer files than ``dir_A`` its files are reused
    cyclically within an epoch.
    """

    def __init__(self, dir_A, dir_B, batch_size, img_shape):
        """Record the folders and list their files.

        Args:
            dir_A: folder holding the domain-A images.
            dir_B: folder holding the domain-B images.
            batch_size: maximum number of images per batch.
            img_shape: ``(rows, cols, channels)`` of every returned image.
        """
        self.dir_A = dir_A
        self.dir_B = dir_B

        # ``flist`` / ``fnum`` describe domain A and define the epoch length.
        self.flist = os.listdir(dir_A)
        self.flist_B = os.listdir(dir_B)
        self.fnum = len(self.flist)

        self.batch_size = batch_size
        self.img_shape = img_shape

        self.idx_cur = 0

    def getNumberOfBatch(self):
        """Return the number of batches in one epoch (ceiling division)."""
        return -(-self.fnum // self.batch_size)

    def reset(self):
        """Rewind to the first file and reshuffle both file lists."""
        self.idx_cur = 0
        random.shuffle(self.flist)
        random.shuffle(self.flist_B)

    def __iter__(self):
        return self

    def _loadImage(self, fpath):
        """Read one colour image, fit it to ``img_shape`` and rescale it.

        Raises:
            IOError: if OpenCV cannot read ``fpath``.
        """
        img = cv2.imread(fpath, 1)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise IOError("cannot read image: {}".format(fpath))

        rows, cols = self.img_shape[0], self.img_shape[1]
        if img.shape[0] != rows or img.shape[1] != cols:
            # cv2.resize takes the target size as (width, height).
            img = cv2.resize(img, (cols, rows))

        return (img.astype(np.float32) - 127.5) / 127.5

    def __next__(self):
        """Return the next ``(imgA, imgB)`` batch.

        Raises:
            StopIteration: at the end of an epoch (after resetting).
            ValueError: if ``dir_B`` contained no files.
            IOError: if an image file cannot be read.
        """
        if self.idx_cur >= self.fnum:
            self.reset()
            raise StopIteration

        if not self.flist_B:
            raise ValueError("no files found in {}".format(self.dir_B))

        # The last batch of an epoch is truncated to the remaining files.
        idx_nxt = min(self.idx_cur + self.batch_size, self.fnum)
        length = idx_nxt - self.idx_cur

        imgA = np.zeros((length, *self.img_shape))
        imgB = np.zeros((length, *self.img_shape))

        num_B = len(self.flist_B)
        for k in range(length):
            idx = self.idx_cur + k
            fpath_A = os.path.join(self.dir_A, self.flist[idx])
            # Wrap around so a smaller B folder never runs out of files.
            fpath_B = os.path.join(self.dir_B, self.flist_B[idx % num_B])

            imgA[k] = self._loadImage(fpath_A)
            imgB[k] = self._loadImage(fpath_B)

        self.idx_cur = idx_nxt

        return imgA, imgB

class CycleGAN:
    """CycleGAN built from two generators and two patch discriminators.

    ``gen_AB`` maps domain-A images to domain B and ``gen_BA`` maps back.
    ``disc_A`` / ``disc_B`` score images of their domain on a grid of
    patches.  ``combined`` chains the generators with the (frozen)
    discriminators and is the model used for the generator update.
    """

    def __init__(self, L_id=False):
        """Build all networks.

        Args:
            L_id: when true, the identity loss is added to the generator
                objective.
        """
        self.img_row = 256
        self.img_col = 256
        self.img_channels = 3
        self.img_shape = (self.img_row, self.img_col, self.img_channels)

        # The discriminator applies four stride-2 convolutions, so each side
        # of its output map is 1/16 of the input side.
        patch = int(self.img_row / 2 ** 4)
        self.discPatch = (patch, patch, 1)

        self.L_id = L_id
        self.buildGAN(L_id)

    def buildGenerator(self, num_resNet):
        """Return a generator model with ``num_resNet`` residual blocks.

        Layout: 7x7 conv (64) -> two stride-2 convs (128, 256) ->
        ``num_resNet`` residual blocks (256) -> two stride-2 transposed
        convs (128, 64) -> 7x7 conv to 3 channels with tanh.  Every
        convolution is followed by instance normalization.
        """
        initWeight = RandomNormal(stddev=0.02)

        def resBlock(inputs, filters):
            x = Conv2D(filters, kernel_size=3, padding='same', kernel_initializer=initWeight)(inputs)
            x = InstanceNormalization()(x)
            x = Activation('relu')(x)

            x = Conv2D(filters, kernel_size=3, padding='same', kernel_initializer=initWeight)(x)
            x = InstanceNormalization()(x)
            x = Activation('relu')(x)

            # Skip connection around the two convolutions.
            x = Add()([x, inputs])
            outputs = Activation('relu')(x)
            return outputs

        def convLayer(inputs, filters, k_size=3, stride=1, act='relu'):
            x = Conv2D(filters, kernel_size=k_size, strides=stride, padding='same', kernel_initializer=initWeight)(inputs)
            x = InstanceNormalization()(x)
            outputs = Activation(act)(x)
            return outputs

        def deConvLayer(inputs, filters, k_size=3, stride=2):
            x = Conv2DTranspose(filters, kernel_size=k_size, strides=stride, padding='same', kernel_initializer=initWeight)(inputs)
            x = InstanceNormalization()(x)
            outputs = Activation('relu')(x)
            return outputs

        img_input = Input(shape=self.img_shape)

        # Encoder: downsample twice.
        x = convLayer(img_input, 64, k_size=7)
        x = convLayer(x, 128, stride=2)
        x = convLayer(x, 256, stride=2)

        # Transformer: residual blocks at the lowest resolution.
        for _ in range(num_resNet):
            x = resBlock(x, 256)

        # Decoder: upsample twice, then project to an image in [-1, 1].
        x = deConvLayer(x, 128)
        x = deConvLayer(x, 64)
        img_output = convLayer(x, 3, k_size=7, act='tanh')

        return Model(img_input, img_output)

    def buildDiscriminator(self):
        """Return a discriminator that outputs one score per image patch.

        Four stride-2 4x4 convolutions (64, 128, 256, 512; no normalization
        on the first) are followed by a 1-filter convolution, giving an
        output of shape ``discPatch``.
        """
        initWeight = RandomNormal(stddev=0.02)

        def discLayer(inputs, filters, k_size=4, norm=True):
            x = Conv2D(filters, kernel_size=k_size, strides=2, padding='same', kernel_initializer=initWeight)(inputs)
            if norm:
                x = InstanceNormalization()(x)
            outputs = LeakyReLU(alpha=0.2)(x)
            return outputs

        inputImg = Input(shape=self.img_shape)

        disc1 = discLayer(inputImg, 64, norm=False)
        disc2 = discLayer(disc1, 128)
        disc3 = discLayer(disc2, 256)
        disc4 = discLayer(disc3, 512)

        validity = Conv2D(filters=1, kernel_size=4, padding='same', kernel_initializer=initWeight)(disc4)

        return Model(inputImg, validity)

    def buildGAN(self, L_id=False):
        """Create and compile the discriminators, generators and ``combined``.

        ``combined`` takes ``[imgA, imgB]`` and outputs, in order: the two
        discriminator scores of the translated images, the two cycle
        reconstructions and, when ``L_id`` is true, the two identity
        mappings.  Losses are MSE for the adversarial terms and MAE (L1)
        for the rest.
        """
        lambda_cycle = 10.0
        lambda_id = 0.5 * lambda_cycle  # used for arts -> photo
        optimizer = Adam(2e-4, 0.5)

        # build discriminator
        self.disc_A = self.buildDiscriminator()
        self.disc_B = self.buildDiscriminator()

        self.disc_A.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
        self.disc_B.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

        # build generator, 6 ResBlock for 128x128, 9 for 256x256
        self.gen_AB = self.buildGenerator(num_resNet=9)
        self.gen_BA = self.buildGenerator(num_resNet=9)

        imgA = Input(shape=self.img_shape)
        imgB = Input(shape=self.img_shape)

        fake_B = self.gen_AB(imgA)
        fake_A = self.gen_BA(imgB)

        reconstr_A = self.gen_BA(fake_B)
        reconstr_B = self.gen_AB(fake_A)

        # Freeze the discriminators inside ``combined`` so that the generator
        # update does not change them; they were compiled above while still
        # trainable.
        self.disc_A.trainable = False
        self.disc_B.trainable = False

        valid_A = self.disc_A(fake_A)
        valid_B = self.disc_B(fake_B)

        outputs = [valid_A, valid_B, reconstr_A, reconstr_B]
        loss = ['mse', 'mse', 'mae', 'mae']
        loss_weights = [1, 1, lambda_cycle, lambda_cycle]

        if L_id:
            # Identity mapping: a generator fed an image that is already in
            # its target domain should return it unchanged.
            imgA_id = self.gen_BA(imgA)
            imgB_id = self.gen_AB(imgB)

            outputs += [imgA_id, imgB_id]
            loss += ['mae', 'mae']
            loss_weights += [lambda_id, lambda_id]

        self.combined = Model(inputs=[imgA, imgB], outputs=outputs)
        self.combined.compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights)

    @staticmethod
    def _summarizeGenLoss(G_loss, with_id):
        """Collapse the per-output generator losses into per-term means.

        ``G_loss`` is laid out as ``[total, adv_A, adv_B, rec_A, rec_B]``
        followed by ``[id_A, id_B]`` when the identity loss is enabled.

        Returns:
            ``(adv, recon, ident)`` where ``ident`` is None unless
            ``with_id`` is true.
        """
        adv = float(np.mean(G_loss[1:3]))
        recon = float(np.mean(G_loss[3:5]))
        # Both identity terms live at indices 5 and 6.
        ident = float(np.mean(G_loss[5:7])) if with_id else None
        return adv, recon, ident

    def trainModel(self, epochs, batch_size=1,
                   dir_A='F:/wallpaper/datas/test/trainA',
                   dir_B='F:/wallpaper/datas/test/trainB',
                   sample_path='F:/wallpaper/datas/sketch/testB/1047028.png',
                   sample_every=200):
        """Train the discriminators and generators alternately.

        Args:
            epochs: number of passes over the data.
            batch_size: images per training step.
            dir_A: folder with domain-A training images.
            dir_B: folder with domain-B training images.
            sample_path: image translated with ``gen_AB`` as a progress
                preview; a falsy value disables previews.
            sample_every: write a preview every this many steps; a falsy
                value disables previews.
        """
        self.dataLoader = DataLoader(dir_A, dir_B, batch_size, self.img_shape)

        totalStep = self.dataLoader.getNumberOfBatch()
        for epoch in range(epochs):
            for step, (imgA, imgB) in enumerate(self.dataLoader):
                # Per-patch targets: 1 for real images, 0 for generated ones.
                valid = np.ones((imgA.shape[0],) + self.discPatch)
                fake = np.zeros((imgA.shape[0],) + self.discPatch)

                fake_B = self.gen_AB.predict(imgA)
                fake_A = self.gen_BA.predict(imgB)

                discA_loss_real = self.disc_A.train_on_batch(imgA, valid)
                discA_loss_fake = self.disc_A.train_on_batch(fake_A, fake)
                D_A_loss = 0.5 * np.add(discA_loss_real, discA_loss_fake)

                discB_loss_real = self.disc_B.train_on_batch(imgB, valid)
                discB_loss_fake = self.disc_B.train_on_batch(fake_B, fake)
                D_B_loss = 0.5 * np.add(discB_loss_real, discB_loss_fake)

                # Element 0 is the loss, element 1 the accuracy metric.
                D_loss = 0.5 * np.add(D_A_loss, D_B_loss)

                # Targets follow the output order of ``combined``.
                targets = [valid, valid, imgA, imgB]
                if self.L_id:
                    targets += [imgA, imgB]
                G_loss = self.combined.train_on_batch([imgA, imgB], targets)

                adv, recon, ident = self._summarizeGenLoss(G_loss, self.L_id)
                message = ("Epoch {}/{} : Batch {}/{} -- D loss: {:.6f}, acc: {:.2f} , "
                           "G loss: {:.6f}, adv:{:.6f}, recon: {:.6f}").format(
                    epoch+1, epochs, step+1, totalStep, D_loss[0], D_loss[1] * 100,
                    G_loss[0], adv, recon
                )
                if ident is not None:
                    message += ", id: {:.6f}".format(ident)
                print(message)

                if sample_path and sample_every and step % sample_every == 0:
                    # The file name depends on the epoch only, so later
                    # previews of the same epoch overwrite earlier ones.
                    fname = 'output{}.png'.format(epoch)
                    self.colorizeImage(fpath=sample_path, outputDir='output', fname=fname)

    def colorizeImage(self, fpath, outputDir, fname):
        """Translate one image with ``gen_AB`` and write the result to disk.

        Args:
            fpath: path of the input image.
            outputDir: folder for the result; created when missing.
            fname: file name of the result inside ``outputDir``.

        Raises:
            IOError: if OpenCV cannot read ``fpath``.
        """
        img_input = cv2.imread(fpath, 1)
        if img_input is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise IOError("cannot read image: {}".format(fpath))
        # cv2.resize takes the target size as (width, height).
        img_input = cv2.resize(img_input, (self.img_col, self.img_row))

        img_input = np.expand_dims(img_input, 0)
        img_input = (img_input.astype(np.float32) - 127.5) / 127.5

        # Undo the [-1, 1] scaling of the tanh output before saving.
        img_output = self.gen_AB.predict(img_input)[0]
        img_output = img_output * 127.5 + 127.5
        img_output = img_output.astype(np.uint8)

        # makedirs also handles nested folders and an already existing one.
        os.makedirs(outputDir, exist_ok=True)
        cv2.imwrite(os.path.join(outputDir, fname), img_output)

# Build the networks without the optional identity loss, then train for
# five epochs with two images per step.
model = CycleGAN(L_id=False)
model.trainModel(5, batch_size=2)