CGAN

论文链接:Conditional Generative Adversarial Nets

原始的GAN在图像生成上取得了很大成功
但是原始GAN仅以某个先验的随机噪声作为G的输入
使得G并不总能生成我们想要的图像

因此便有了CGAN(Conditional GAN),即条件式GAN

CGAN

如图所示,除了随机噪声z,CGAN还在G和D的输入端都加入了条件输入c
条件输入c的选择以及c与z的结合方式可以有很多选择,没有特别限制

以mnist手写数字数据集为例
可以令c为数字的类别,然后用keras的Embedding层将c转换为与z尺寸相同的向量,再与z逐元素相乘

完整代码如下:
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, BatchNormalization, Reshape, Input, \
LeakyReLU, Embedding, multiply, Activation
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

class CGAN:
    """Conditional GAN (CGAN) for MNIST built from fully connected layers.

    Both the generator and the discriminator receive a digit label next to
    their regular input.  The label goes through an ``Embedding`` layer that
    produces a vector of the same length as the other input, and the two
    vectors are multiplied element-wise before entering the network.

    Attributes set in ``__init__``:
        img_shape: ``(img_row, img_col, channel)`` of the images, (28, 28, 1).
        latent_dim: length of the noise vector fed to the generator.
        num_class: number of distinct labels.
        generator, discriminator, combined: models created by ``buildGAN``.
    """

    def __init__(self):
        """Store the image/latent dimensions and build all models."""
        self.img_row = 28
        self.img_col = 28
        self.channel = 1
        self.img_shape = (self.img_row, self.img_col, self.channel)
        self.latent_dim = 100
        self.num_class = 10

        self.buildGAN()

    @staticmethod
    def _scaleImages(images):
        """Map pixel values from [0, 255] to [-1, 1] and append a channel axis."""
        images = images / 127.5 - 1.
        return np.expand_dims(images, axis=3)

    @staticmethod
    def _gridLabels(rows, num_class):
        """Return a ``(rows * num_class, 1)`` column repeating 0..num_class-1."""
        return np.tile(np.arange(0, num_class), rows).reshape(-1, 1)

    def _sampleGeneratorInput(self, batch_size):
        """Draw standard-normal noise and uniformly random labels for one batch."""
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        labels = np.random.randint(0, self.num_class, batch_size).reshape(-1, 1)
        return noise, labels

    def buildGenerator(self):
        """Return a model mapping ``[noise, label]`` to an image.

        The model is not compiled here; it is trained through the combined
        model created in ``buildGAN``.
        """
        model = Sequential()

        model.add(Dense(input_dim=self.latent_dim, units=256))
        model.add(BatchNormalization(momentum=0.9))
        model.add(Activation('relu'))

        for units in (512, 1024):
            model.add(Dense(units))
            model.add(BatchNormalization(momentum=0.9))
            model.add(Activation('relu'))

        # tanh output matches the [-1, 1] scaling applied to the training data.
        # int(...) because np.prod returns a NumPy integer, not a Python int.
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))

        label = Input(shape=(1,))
        labelEmbeded = Flatten()(Embedding(self.num_class, self.latent_dim)(label))

        # Condition the noise on the label by element-wise multiplication.
        conditioned = multiply([noise, labelEmbeded])
        img = model(conditioned)

        return Model([noise, label], img)

    def buildDiscriminator(self):
        """Return a compiled model mapping ``[image, label]`` to a validity score."""
        model = Sequential()

        for units in (512, 256, 64):
            model.add(Dense(units))
            model.add(LeakyReLU(0.2))

        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        imgFlat = Flatten()(img)

        # The label embedding has one entry per pixel so it can be multiplied
        # with the flattened image.
        label = Input(shape=(1, ))
        labelEmbeded = Flatten()(
            Embedding(self.num_class, int(np.prod(self.img_shape)))(label))

        conditioned = multiply([imgFlat, labelEmbeded])
        validity = model(conditioned)

        discriminator = Model([img, label], validity)
        discriminator.compile(optimizer=Adam(2e-4), loss='binary_crossentropy')
        return discriminator

    def buildGAN(self):
        """Create the generator, the discriminator and the stacked model.

        ``self.combined`` maps ``[noise, label]`` through the generator and
        then the discriminator; it is the model used to train the generator.
        """
        self.generator = self.buildGenerator()
        self.discriminator = self.buildDiscriminator()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])

        # Set after the discriminator was compiled on its own and before the
        # combined model is compiled; intended so that training the combined
        # model only updates the generator.
        self.discriminator.trainable = False

        validity = self.discriminator([img, label])

        self.combined = Model([noise, label], validity)
        self.combined.compile(optimizer=Adam(2e-4), loss='binary_crossentropy')

    def trainModel(self, epochs, batch_size=64):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: number of training iterations.  Each iteration uses one
                randomly drawn batch, not a full pass over the data set.
            batch_size: number of real and of generated samples per iteration.
        """
        (X_train, Y_train), _ = mnist.load_data()
        X_train = self._scaleImages(X_train)

        # Targets for the binary cross-entropy loss: 1 = real, 0 = generated.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(1, epochs + 1):
            # Pick a random batch of real images with their labels.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            orgImgs = X_train[idx]
            # Column shape to match the (1,)-shaped label input.
            label = Y_train[idx].reshape(-1, 1)

            # Standard-normal noise (mu=0, sigma=1) plus random target labels.
            noise, sampleLabel = self._sampleGeneratorInput(batch_size)
            genImgs = self.generator.predict([noise, sampleLabel])

            # Train the discriminator on the real and the generated batch.
            D_loss_real = self.discriminator.train_on_batch([orgImgs, label], valid)
            D_loss_fake = self.discriminator.train_on_batch([genImgs, sampleLabel], fake)
            D_loss = 0.5 * np.add(D_loss_real, D_loss_fake)

            # Train the generator: it is rewarded when the discriminator
            # labels its output as real.
            noise, sampleLabel = self._sampleGeneratorInput(batch_size)
            G_loss = self.combined.train_on_batch([noise, sampleLabel], valid)

            print("{} --- D loss: {:.4f} , G loss: {:.4f}".format(epoch, D_loss, G_loss))

            if epoch % 400 == 0:
                self.saveImage(epoch)

    def saveImage(self, epoch):
        """Save a 3-row grid with one generated image per class and row.

        The figure is written to ``generated/<epoch>.png``; the directory is
        created when it does not exist yet.
        """
        import os

        r, c = 3, self.num_class
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        label = self._gridLabels(r, c)

        genImgs = self.generator.predict([noise, label])

        # Rescale images from [-1, 1] to [0, 1] for plotting.
        genImgs = 0.5 * genImgs + 0.5

        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(genImgs[i * c + j, :, :, 0], cmap='gray')
                axs[i, j].axis('off')

        # Portable path instead of a hard-coded Windows separator.
        os.makedirs('generated', exist_ok=True)
        fig.savefig(os.path.join('generated', '%d.png' % epoch))
        plt.close(fig)

def main():
    """Build the CGAN and run 10000 training iterations."""
    CGAN().trainModel(epochs=10000)


# Only start training when the file is executed as a script.
if __name__ == '__main__':
    main()

ACGAN

论文链接:Conditional Image Synthesis with Auxiliary Classifier GANs

ACGAN全称为Auxiliary Classifier GAN,即辅助分类器GAN
ACGAN也是一种条件式GAN,其结构如图所示

ACGAN

它与原始CGAN的不同之处在于:条件信息c不再输入D,而是让D尝试重建(预测)条件信息
以mnist手写数字数据集为例,ACGAN的判别器不需要输入标签条件,而是需要预测图像的标签

完整代码如下:
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, BatchNormalization, Reshape, Input, \
LeakyReLU, Embedding, multiply, Activation, Conv2DTranspose, Conv2D, Dropout
from keras.optimizers import Adam
from keras.utils import np_utils
import numpy as np
import matplotlib.pyplot as plt

class CGAN:
    """Auxiliary Classifier GAN (ACGAN) for MNIST built from conv layers.

    The class keeps the name ``CGAN`` used by ``main``.  Unlike a plain
    conditional GAN, the discriminator here receives only the image and has
    two outputs: a real/fake validity score and a predicted class
    distribution.  The generator is still conditioned on the label through
    an ``Embedding`` multiplied element-wise with the noise.

    Attributes set in ``__init__``:
        img_shape: ``(img_row, img_col, channel)`` of the images, (28, 28, 1).
        latent_dim: length of the noise vector fed to the generator.
        num_class: number of distinct labels.
        generator, discriminator, combined: models created by ``buildGAN``.
    """

    def __init__(self):
        """Store the image/latent dimensions and build all models."""
        self.img_row = 28
        self.img_col = 28
        self.channel = 1
        self.img_shape = (self.img_row, self.img_col, self.channel)
        self.latent_dim = 100
        self.num_class = 10

        self.buildGAN()

    @staticmethod
    def _scaleImages(images):
        """Map pixel values from [0, 255] to [-1, 1] and append a channel axis."""
        images = images / 127.5 - 1.
        return np.expand_dims(images, axis=3)

    @staticmethod
    def _gridLabels(rows, num_class):
        """Return a ``(rows * num_class, 1)`` column repeating 0..num_class-1."""
        return np.tile(np.arange(0, num_class), rows).reshape(-1, 1)

    def _sampleGeneratorInput(self, batch_size):
        """Draw standard-normal noise and uniformly random labels for one batch."""
        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
        labels = np.random.randint(0, self.num_class, batch_size).reshape(-1, 1)
        return noise, labels

    def buildGenerator(self):
        """Return a model mapping ``[noise, label]`` to an image.

        A dense layer produces a 7x7x256 tensor which is passed through
        transposed convolutions (two of them with stride 2) down to a single
        tanh-activated channel.  The model is not compiled here; it is
        trained through the combined model created in ``buildGAN``.
        """
        model = Sequential()

        model.add(Dense(input_dim=self.latent_dim, units=7 * 7 * 256))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation('relu'))

        model.add(Reshape((7, 7, 256)))

        # (filters, strides) of the hidden transposed-convolution stages.
        for filters, strides in ((128, 2), (64, 2), (32, 1)):
            model.add(Conv2DTranspose(filters=filters, kernel_size=3,
                                      strides=strides, padding='same'))
            model.add(BatchNormalization(momentum=0.8))
            model.add(Activation('relu'))

        # tanh output matches the [-1, 1] scaling applied to the training data.
        model.add(Conv2DTranspose(filters=1, kernel_size=3, strides=1, padding='same'))
        model.add(Activation('tanh'))

        noise = Input(shape=(self.latent_dim,))

        label = Input(shape=(1,))
        labelEmbeded = Flatten()(Embedding(self.num_class, self.latent_dim)(label))

        # Condition the noise on the label by element-wise multiplication.
        conditioned = multiply([noise, labelEmbeded])
        img = model(conditioned)

        return Model([noise, label], img)

    def buildDiscriminator(self):
        """Return a compiled model mapping an image to ``[validity, label_pred]``.

        ``validity`` is a sigmoid real/fake score and ``label_pred`` a softmax
        over ``num_class`` classes; both heads share the convolutional trunk.
        """
        model = Sequential()

        model.add(Conv2D(input_shape=self.img_shape, filters=64, kernel_size=3,
                         strides=2, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.4))

        for filters in (128, 256):
            model.add(Conv2D(filters=filters, kernel_size=3, strides=2, padding='same'))
            model.add(BatchNormalization(momentum=0.8))
            model.add(LeakyReLU(0.2))
            model.add(Dropout(0.4))

        model.add(Flatten())

        img = Input(shape=self.img_shape)
        features = model(img)

        label_pred = Dense(self.num_class, activation='softmax')(features)
        validity = Dense(1, activation='sigmoid')(features)

        discriminator = Model(img, [validity, label_pred])
        discriminator.compile(optimizer=Adam(2e-4),
                              loss=['binary_crossentropy', 'categorical_crossentropy'])
        return discriminator

    def buildGAN(self):
        """Create the generator, the discriminator and the stacked model.

        ``self.combined`` maps ``[noise, label]`` through the generator and
        then the discriminator; it is the model used to train the generator.
        """
        self.generator = self.buildGenerator()
        self.discriminator = self.buildDiscriminator()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,))
        img = self.generator([noise, label])

        # Set after the discriminator was compiled on its own and before the
        # combined model is compiled; intended so that training the combined
        # model only updates the generator.
        self.discriminator.trainable = False

        validity, label_pred = self.discriminator(img)

        self.combined = Model([noise, label], [validity, label_pred])
        self.combined.compile(optimizer=Adam(2e-4),
                              loss=['binary_crossentropy', 'categorical_crossentropy'])

    def trainModel(self, epochs, batch_size=64):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: number of training iterations.  Each iteration uses one
                randomly drawn batch, not a full pass over the data set.
            batch_size: number of real and of generated samples per iteration.
        """
        (X_train, Y_train), _ = mnist.load_data()
        X_train = self._scaleImages(X_train)

        # Targets for the validity head: 1 = real, 0 = generated.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(1, epochs + 1):
            # Pick a random batch of real images with one-hot labels.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            orgImgs = X_train[idx]
            label_onehot = np_utils.to_categorical(
                Y_train[idx], num_classes=self.num_class)

            # Standard-normal noise (mu=0, sigma=1) plus random target labels.
            noise, sampleLabel = self._sampleGeneratorInput(batch_size)
            sampleLabel_onehot = np_utils.to_categorical(
                sampleLabel, num_classes=self.num_class)

            genImgs = self.generator.predict([noise, sampleLabel])

            # Train the discriminator on the real and the generated batch.
            D_loss_real = self.discriminator.train_on_batch(
                orgImgs, [valid, label_onehot])
            D_loss_fake = self.discriminator.train_on_batch(
                genImgs, [fake, sampleLabel_onehot])
            D_loss = 0.5 * np.add(D_loss_real, D_loss_fake)

            # Train the generator: it is rewarded when the discriminator
            # calls its output real and recognises the requested class.
            noise, sampleLabel = self._sampleGeneratorInput(batch_size)
            sampleLabel_onehot = np_utils.to_categorical(
                sampleLabel, num_classes=self.num_class)
            G_loss = self.combined.train_on_batch(
                [noise, sampleLabel], [valid, sampleLabel_onehot])

            # With two outputs train_on_batch is expected to return
            # [total, validity_loss, class_loss]; log the total rather than
            # the mean of all three, which would count the parts twice.
            print("epoch {} --- D_loss: {:.4f}, G loss: {:.4f}".format(
                epoch, D_loss[0], G_loss[0]))

            if epoch % 400 == 0:
                self.saveImage(epoch)

    def saveImage(self, epoch):
        """Save a 3-row grid with one generated image per class and row.

        The figure is written to ``generated/<epoch>.png``; the directory is
        created when it does not exist yet.
        """
        import os

        r, c = 3, self.num_class
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        label = self._gridLabels(r, c)

        genImgs = self.generator.predict([noise, label])

        # Rescale images from [-1, 1] to [0, 1] for plotting.
        genImgs = 0.5 * genImgs + 0.5

        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(genImgs[i * c + j, :, :, 0], cmap='gray')
                axs[i, j].axis('off')

        # Portable path instead of a hard-coded Windows separator.
        os.makedirs('generated', exist_ok=True)
        fig.savefig(os.path.join('generated', '%d.png' % epoch))
        plt.close(fig)

def main():
    """Build the model and run 10000 training iterations."""
    CGAN().trainModel(epochs=10000)


# Only start training when the file is executed as a script.
if __name__ == '__main__':
    main()