论文链接:InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

Disentangled Representation

infoGAN的核心是通过无监督模型学得数据实例的Disentangled Representation(解耦表征)

对于什么是Disentangled Representation,目前并没有比较正式的定义
这个概念最早出现于2013年Bengio Representation Learning: A Review and New Perspectives 一文中

Single latent units are sensitive to changes in single generative factors, while being relatively invariant to changes in other factors.

即“单个隐单元仅对单个生成因子的变化敏感,而对其他因子的变化不敏感”

直观理解为,将某实例的representation分解为若干个相互独立变化因子(variation factors)
当单个因子变化时,仅使生成数据中单个因素发生变化,且每个因子都有一定的语义含义

若能使用无监督模型学得一组好的Disentangled Representation,那么对于未知的下游任务将十分有意义,这就是InfoGAN的目的

InfoGAN

互信息Mutual Information

InfoGAN的理论基础是信息论中互信息Mutual Information的概念

互信息是衡量随机变量之间相互依赖程度的度量
对于随机变量$X,Y$,它们之间的互信息为

其直观解释就是,在已知$Y$的情况下,$X$不确定度的减少量,显然减少越多两者相关性越强
互信息和信息增益本质是一样的,只是不同语境中的不同解释

互信息最大化

假设数据的语义特征由一组隐变量$c1,c_2,…,c_L$控制,且有$P(c_1,c_2,…,c_L)=\prod{i=1}^LP(c_i)$,即他们相互独立,并令$c=concat(c1, c2, …, c_L)$,称为隐编码latent code

则InfoGAN生成器的输入为噪声向量$z$和latent code $c$

此时若以标准GAN的方式训练,则生成器可能学得$P_G(x|c)=P_G(x)$导致latent code直接被忽略
因此必须保证latent code和生成器分布G(z,c)之间有较大的互信息,即$I(c,G(z,c))$要大

然而由于后验概率$P(c|x)$未知,$I(c,G(z,c))$很难直接计算,因此考虑引入辅助分布$Q(c|x)$来近似$P(c|x)$

此处引入一个引理(证明见论文附录)

对于随机变量$X,Y$和函数$f(x,y)$,在一定条件下有:$E{x\sim X,y\sim Y|x}[f(x,y)]=E{x\sim X,y\sim Y|x,x’\sim X|y}[f(x’,y)]$

则有

这样我们就求出了$I(c,G(z,c))$与后验概率$P(c|x)$无关的下界$L_I(G,Q)$,只要最大化$L_I(G,Q)$即可

综上InfoGAN的目标函数为

代码实现

对于辅助分布$Q(c|x)$,我们可以直接用神经网络表示(除输出层外其余与D共用),此时有

则其代码实现为

1
2
3
4
5
6
7
8
9
class InfoGAN:
def mutual_info_loss(self, c, c_given_x):
# LI(G, Q) given by the paper
eps = 1e-9
conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

loss = conditional_entropy + entropy
return loss

对于表示类别的特殊latent code,由于one-hot处理,其分布类似于c=[0,0,..,1,…,0,0]
此时$H(c)=0$,若Q输出经过softmax处理,则$L_I$等价于交叉熵

进一步的,由如下推导

可知最大化$L_I(G,Q)$即最小化c与Q的KL散度
因此对于更一般的离散型latent code,也可以使用交叉熵作为loss

此外对于连续的latent code,可以直接假设其服从均匀/正态分布
以正态分布为例,此时Q只需两个输出,分别表示分布的均值和标准差(标准差需要exp激活保证为正数)
其loss可根据正态分布间的KL散度公式计算,也可以简单的用mse替代

Mnist示例

需注意的是,InfoGAN最终学得的类别标签与数字图像的对应关系可能是“不正确”的
因为与CGAN等不同,我们在给InfoGAN传递类别标签时并不是将已知的对应关系告诉它,而是让他自行发现标签与图像间的关系,因此其学得的对应关系会看起来是“不正确”的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
""" Information Maximizing GAN
paper: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
see: https://proceedings.neurips.cc/paper/2016/hash/7c9d0b1f96aebd7b5eca8c3edaa19ebb-Abstract.html
"""

from keras.utils import np_utils
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Flatten, BatchNormalization, Reshape, Input, \
LeakyReLU, Conv2D, Conv2DTranspose, Activation, Dropout
from keras.optimizers import Adam
# from keras.optimizers import adam_v2
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt


class InfoGAN:
def __init__(self):
self.img_row = 28
self.img_col = 28
self.channel = 1
self.img_shape = (self.img_row, self.img_col, self.channel)

self.noise_dim = 62
self.num_class = 10
self.num_latent_code = 1
self.latent_dim = self.noise_dim + self.num_class + self.num_latent_code

self.buildGAN()

def buildGenerator(self):
inputs = Input(shape=(self.latent_dim,))

x = Dense(1024)(inputs)
x = BatchNormalization(momentum=0.8)(x)
x = Activation('relu')(x)

x = Dense(7 * 7 * 256)(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation('relu')(x)

x = Reshape((7, 7, 256))(x)

x = Conv2DTranspose(256, kernel_size=3, strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation('relu')(x)

x = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation('relu')(x)

x = Conv2DTranspose(64, kernel_size=3, padding='same')(x)
x = BatchNormalization(momentum=0.8)(x)
x = Activation('relu')(x)

img = Conv2DTranspose(self.channel, kernel_size=3, padding='same', activation='tanh')(x)

return Model(inputs, img)

def buildDiscriminator(self):
input_img = Input(shape=self.img_shape)

x = Conv2D(64, kernel_size=3, strides=2, padding='same')(input_img)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.4)(x)

x = Conv2D(128, kernel_size=3, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.4)(x)

x = Conv2D(256, kernel_size=3, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.4)(x)

x = Flatten()(x)

validity = Dense(1, activation='sigmoid')(x)
label = Dense(self.num_class, activation='softmax')(x)

c2 = Dense(2)(x)
c2 = Activation(self.dev_positive)(c2)

return Model(input_img, validity), Model(input_img, [label, c2])

def dev_positive(self, x):
# the standard deviation is parameterized through an exponential transformation to ensure positivity
x = K.concatenate([x[:, 0:1], K.exp(x[:, 1:2])])
return x

def gauss_kl_divergence(self, p, q):
# KL divergence between two gaussian distributions
mu1, sigma1 = p[:, 0], p[:, 1]
mu2, sigma2 = q[:, 0], q[:, 1]
t1 = K.log(sigma2 / sigma1)
t2 = (K.square(sigma1) + K.square(mu1 - mu2)) / K.square(sigma2) / 2
kl_div = K.mean(t1 + t2 + 0.5)
return kl_div

def mutual_info_loss(self, c, c_given_x):
# LI(G, Q) given by the paper
eps = 1e-9
conditional_entropy = K.mean(- K.sum(K.log(c_given_x + eps) * c, axis=1))
entropy = K.mean(- K.sum(K.log(c + eps) * c, axis=1))

loss = conditional_entropy + entropy
return loss

def buildGAN(self):
self.generator = self.buildGenerator()
self.discriminator, self.auxiliary = self.buildDiscriminator()
self.discriminator.compile(optimizer=Adam(2e-4), loss='binary_crossentropy', metrics=['acc'])

inputs = Input(shape=(self.latent_dim,))
img = self.generator(inputs)

# setting trainable to false after compiling
# the discriminator will be frozen only when training the combined
self.discriminator.trainable = False

validity = self.discriminator(img)
label, c2 = self.auxiliary(img)

self.combined = Model(inputs, [validity, label, c2])

# for latent code of labels, LI(G, Q) is equivalent to cross entropy
self.combined.compile(
optimizer=Adam(1e-3),
loss=['binary_crossentropy', 'categorical_crossentropy', self.gauss_kl_divergence]
)

def getRandomInput(self, batch_size):
# noise with standard normal distribution,
noise = np.random.normal(0, 1, (batch_size, self.noise_dim))

# randomly generated category label (discrete latent code)
labels = np.random.randint(0, self.num_class, (batch_size,))
labels = np_utils.to_categorical(labels, num_classes=self.num_class)

# continuous latent code c2
latent_c2 = np.random.normal(0, 1 / 3, (batch_size, 1))

# concatenate noise and latent code
inputs = np.hstack((noise, labels, latent_c2))
assert inputs.shape == (batch_size, self.latent_dim)

return inputs, labels

def trainModel(self, epochs, batch_size=64):
(X_train, Y_train), (_, _) = mnist.load_data()

# Normalize to [-1, 1]
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)

valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
epoch += 1

# randomly select a batch of real images
idx = np.random.randint(0, X_train.shape[0], batch_size)
org_imgs = X_train[idx]

# generate fake images
inputs, labels = self.getRandomInput(batch_size)
genImgs = self.generator.predict(inputs)

# train the discriminator
D_loss_real = self.discriminator.train_on_batch(org_imgs, valid)
D_loss_fake = self.discriminator.train_on_batch(genImgs, fake)
D_loss = 0.5 * np.add(D_loss_real, D_loss_fake)

# train the generator
c2 = np.zeros((batch_size, 2))
c2[:, 1] = 1 / 3
G_loss = self.combined.train_on_batch(inputs, [valid, labels, c2])

print("epoch {} --- D loss: {:.4f}, acc: {:.2f}, G loss: {:.4f}".format(
epoch, D_loss[0], 100 * D_loss[1], G_loss[0]))

if epoch % 1000 == 0:
self.saveImage(epoch)

def saveImage(self, epoch):
r, c = 3, 10

noise = np.random.normal(0, 1, (r * c, self.noise_dim))

# [[0, 1, 2, ..., 9], [0, 1, 2, ..., 9], [0, 1, 2, ..., 9]]
labels = np.tile(np.arange(0, c), r)
labels = np_utils.to_categorical(labels, num_classes=self.num_class)

# [[-1, -1, ..., -1], [0, 0, ..., 0], [1, 1, ..., 1]]
c2 = np.zeros((r * c, 1))
c2[:c, :], c2[r * c - c + 1:, :] = -1, 1

inputs = np.hstack((noise, labels, c2))
assert inputs.shape == (r * c, self.latent_dim)
genImgs = self.generator.predict(inputs)

# Rescale to [0, 1]
genImgs = 0.5 * genImgs + 0.5

fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(genImgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
fig.savefig('output/%d.png' % epoch)
plt.close()


def main():
infoGAN = InfoGAN()
infoGAN.trainModel(epochs=20000, batch_size=64)


if __name__ == '__main__':
main()