7 changes: 6 additions & 1 deletion README.md
@@ -1,10 +1,15 @@
# Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
## [Alec Radford](https://github.com/newmu), [Luke Metz](https://github.com/lukemetz), [Soumith Chintala](https://github.com/soumith)

All images in this paper are generated by a neural network. They are NOT REAL.

Full paper here: [http://arxiv.org/abs/1511.06434](http://arxiv.org/abs/1511.06434)

### Other implementations of DCGAN
* [Torch](https://github.com/soumith/dcgan.torch)
* [Chainer](https://github.com/mattya/chainer-DCGAN)
* [TensorFlow](https://github.com/carpedm20/DCGAN-tensorflow)

## Summary of DCGAN
We
- stabilize Generative Adversarial Networks with some architectural constraints
84 changes: 84 additions & 0 deletions imagenet/load_pretrained.py
@@ -0,0 +1,84 @@
import sys
sys.path.append('..')

import numpy as np
import theano
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv

from lib import costs
from lib import inits
from lib import updates
from lib import activations
from lib.vis import color_grid_vis
from lib.rng import py_rng, np_rng
from lib.ops import batchnorm, conv_cond_concat, deconv, dropout, l2normalize
from lib.metrics import nnc_score, nnd_score
from lib.theano_utils import floatX, sharedX, intX
from lib.data_utils import OneHot, shuffle, iter_data, center_crop, patch

from sklearn.externals import joblib

"""
This example loads the 32x32 imagenet model used in the paper,
generates 400 random samples, and sorts them according to the
discriminator's probability of being real and renders them to
the file samples.png
"""

nz = 256
nc = 3
npx = 32
ngf = 128
ndf = 128

relu = activations.Rectify()
sigmoid = activations.Sigmoid()
lrelu = activations.LeakyRectify()
tanh = activations.Tanh()

model_path = '../models/imagenet_gan_pretrain_128f_relu_lrelu_7l_3x3_256z/'
gen_params = [sharedX(p) for p in joblib.load(model_path+'30_gen_params.jl')]
discrim_params = [sharedX(p) for p in joblib.load(model_path+'30_discrim_params.jl')]

def gen(Z, w, g, b, w2, g2, b2, w3, g3, b3, w4, g4, b4, w5, g5, b5, w6, g6, b6, wx):
    # Project Z to a 4x4 spatial feature map, then alternate stride-2
    # (upsampling) and stride-1 fractionally-strided convolutions up to 32x32.
    h = relu(batchnorm(T.dot(Z, w), g=g, b=b))
    h = h.reshape((h.shape[0], ngf*4, 4, 4))
    h2 = relu(batchnorm(deconv(h, w2, subsample=(2, 2), border_mode=(1, 1)), g=g2, b=b2))
    h3 = relu(batchnorm(deconv(h2, w3, subsample=(1, 1), border_mode=(1, 1)), g=g3, b=b3))
    h4 = relu(batchnorm(deconv(h3, w4, subsample=(2, 2), border_mode=(1, 1)), g=g4, b=b4))
    h5 = relu(batchnorm(deconv(h4, w5, subsample=(1, 1), border_mode=(1, 1)), g=g5, b=b5))
    h6 = relu(batchnorm(deconv(h5, w6, subsample=(2, 2), border_mode=(1, 1)), g=g6, b=b6))
    x = tanh(deconv(h6, wx, subsample=(1, 1), border_mode=(1, 1)))
    return x

def discrim(X, w, w2, g2, b2, w3, g3, b3, w4, g4, b4, w5, g5, b5, w6, g6, b6, wy):
    # Mirror of the generator: stride-2 convolutions downsample the 32x32
    # input to a 4x4 map, which is flattened into a real/fake probability.
    h = lrelu(dnn_conv(X, w, subsample=(1, 1), border_mode=(1, 1)))
    h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(1, 1)), g=g2, b=b2))
    h3 = lrelu(batchnorm(dnn_conv(h2, w3, subsample=(1, 1), border_mode=(1, 1)), g=g3, b=b3))
    h4 = lrelu(batchnorm(dnn_conv(h3, w4, subsample=(2, 2), border_mode=(1, 1)), g=g4, b=b4))
    h5 = lrelu(batchnorm(dnn_conv(h4, w5, subsample=(1, 1), border_mode=(1, 1)), g=g5, b=b5))
    h6 = lrelu(batchnorm(dnn_conv(h5, w6, subsample=(2, 2), border_mode=(1, 1)), g=g6, b=b6))
    h6 = T.flatten(h6, 2)
    y = sigmoid(T.dot(h6, wy))
    return y

def inverse_transform(X):
    # Map tanh outputs from [-1, 1] back to [0, 1] and NCHW to NHWC for rendering.
    X = (X.reshape(-1, nc, npx, npx).transpose(0, 2, 3, 1)+1.)/2.
    return X

Z = T.matrix()
X = T.tensor4()

gX = gen(Z, *gen_params)
dX = discrim(X, *discrim_params)

_gen = theano.function([Z], gX)
_discrim = theano.function([X], dX)

sample_zmb = floatX(np_rng.uniform(-1., 1., size=(400, nz)))  # 400 codes drawn from U(-1, 1)
samples = _gen(sample_zmb)
scores = _discrim(samples)
sort = np.argsort(scores.flatten())[::-1]  # most "real"-looking samples first
samples = samples[sort]
color_grid_vis(inverse_transform(samples), (20, 20), 'samples.png')
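
Aside: the compiled `_gen` function above can also be used for the Z-space walks described in the paper. A minimal sketch (not part of this diff, reusing the names defined in load_pretrained.py):

```python
# Hypothetical example: linearly interpolate between two points in Z-space
# and render the decoded samples. Assumes _gen, np_rng, nz, floatX,
# inverse_transform, and color_grid_vis from load_pretrained.py are in scope.
import numpy as np

z0 = np_rng.uniform(-1., 1., size=(1, nz))
z1 = np_rng.uniform(-1., 1., size=(1, nz))
alphas = np.linspace(0., 1., 10).reshape(-1, 1)
z_interp = floatX(z0*(1.-alphas) + z1*alphas)  # (10, nz) batch of interpolated codes
interp = _gen(z_interp)                        # decode each code to an image
color_grid_vis(inverse_transform(interp), (1, 10), 'interp.png')
```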
Binary files not shown.
43 changes: 43 additions & 0 deletions svhn/load.py
@@ -0,0 +1,43 @@
import sys
sys.path.append('..')

import os
import numpy as np
from scipy.io import loadmat

from lib.data_utils import shuffle
from lib.config import data_dir

def svhn(extra=False):
    # SVHN .mat files store images as (32, 32, 3, N); transpose to (N, 3, 32, 32).
    # Labels are 1..10 with 10 meaning digit '0'; shift to zero-indexed 0..9.
    data = loadmat(os.path.join(data_dir, 'train_32x32.mat'))
    trX = data['X'].transpose(3, 2, 0, 1)
    trY = data['y'].flatten()-1
    data = loadmat(os.path.join(data_dir, 'test_32x32.mat'))
    teX = data['X'].transpose(3, 2, 0, 1)
    teY = data['y'].flatten()-1
    if extra:
        data = loadmat(os.path.join(data_dir, 'extra_32x32.mat'))
        exX = data['X'].transpose(3, 2, 0, 1)
        exY = data['y'].flatten()-1
        return trX, exX, teX, trY, exY, teY
    return trX, teX, trY, teY

def svhn_with_valid_set(extra=False):
    # Hold out the first 10k shuffled training examples as a validation set.
    if extra:
        trX, exX, teX, trY, exY, teY = svhn(extra=extra)
    else:
        trX, teX, trY, teY = svhn(extra=extra)
    trX, trY = shuffle(trX, trY)
    vaX = trX[:10000]
    vaY = trY[:10000]
    trX = trX[10000:]
    trY = trY[10000:]
    if extra:
        # trS flags the source of each example: 1 = train split, 0 = extra split.
        trS = np.asarray([1 for _ in range(len(trY))] + [0 for _ in range(len(exY))])
        trX = np.concatenate([trX, exX], axis=0)
        trY = np.concatenate([trY, exY], axis=0)
        trX, trY, trS = shuffle(trX, trY, trS)
        return trX, vaX, teX, trY, vaY, teY, trS
    else:
        return trX, vaX, teX, trY, vaY, teY
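
A quick usage sketch (hypothetical, assuming the SVHN `.mat` files have been downloaded into `data_dir`; split sizes follow the standard SVHN train/test sets):

```python
# SVHN train has 73,257 images and test has 26,032; after holding out
# 10,000 validation examples the shapes should be roughly:
trX, vaX, teX, trY, vaY, teY = svhn_with_valid_set(extra=False)
print trX.shape  # (63257, 3, 32, 32)
print vaX.shape  # (10000, 3, 32, 32)
print teX.shape  # (26032, 3, 32, 32)
```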
176 changes: 176 additions & 0 deletions svhn/svhn_semisup_analysis.py
@@ -0,0 +1,176 @@
import sys
sys.path.append('..')

import os
import json
from time import time
import numpy as np
from tqdm import tqdm
from sklearn.externals import joblib

from sklearn import metrics
from sklearn.linear_model import LogisticRegression as LR
from sklearn.svm import LinearSVC as LSVC

import theano
import theano.tensor as T
from theano.sandbox.cuda.dnn import dnn_conv, dnn_pool

from lib import activations
from lib import updates
from lib import inits
from lib.vis import color_grid_vis
from lib.rng import py_rng, np_rng
from lib.ops import batchnorm, conv_cond_concat, deconv, dropout
from lib.theano_utils import floatX, sharedX
from lib.data_utils import OneHot, shuffle, iter_data
from lib.metrics import nnc_score, nnd_score
from lib.costs import MSE, CCE

from load import svhn_with_valid_set

relu = activations.Rectify()
lrelu = activations.LeakyRectify(leak=0.2)
sigmoid = activations.Sigmoid()

trX, vaX, teX, trY, vaY, teY = svhn_with_valid_set(extra=False)

# Scale pixel values from [0, 255] to [-1, 1] to match the GAN's tanh range.
vaX = floatX(vaX)/127.5-1.
trX = floatX(trX)/127.5-1.
teX = floatX(teX)/127.5-1.

X = T.tensor4()

desc = 'svhn_unsup_all_conv_dcgan_100z_gaussian_lr_0.0005_64mb'
epoch = 200
params = [sharedX(p) for p in joblib.load('../models/%s/%d_discrim_params.jl'%(desc, epoch))]
print desc.upper()
print 'epoch %d'%epoch

def mean_and_var(X):
    # Per-channel mean and variance over the batch and spatial axes.
    u = T.mean(X, axis=[0, 2, 3])
    s = T.mean(T.sqr(X - u.dimshuffle('x', 0, 'x', 'x')), axis=[0, 2, 3])
    return u, s

def bnorm_statistics(X, w, w2, g2, b2, w3, g3, b3, wy):
    # Forward pass that also returns the pre-batchnorm activation statistics,
    # so they can be averaged over the training set and frozen for inference.
    h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)))

    h2 = dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2))
    h2_u, h2_s = mean_and_var(h2)
    h2 = lrelu(batchnorm(h2, g=g2, b=b2))

    h3 = dnn_conv(h2, w3, subsample=(2, 2), border_mode=(2, 2))
    h3_u, h3_s = mean_and_var(h3)
    h3 = lrelu(batchnorm(h3, g=g3, b=b3))

    h_us = [h2_u, h3_u]
    h_ss = [h2_s, h3_s]
    return h_us, h_ss

def infer_bnorm_stats(X, nbatch=128):
    # Average the per-minibatch batchnorm statistics over the whole dataset.
    U = [np.zeros(128, dtype=theano.config.floatX), np.zeros(256, dtype=theano.config.floatX)]
    S = [np.zeros(128, dtype=theano.config.floatX), np.zeros(256, dtype=theano.config.floatX)]
    n = 0
    for xmb in iter_data(X, size=nbatch):
        stats = _bnorm_stats(floatX(xmb))
        umb = stats[:2]
        smb = stats[2:]
        for i, u in enumerate(umb):
            U[i] += u
        for i, s in enumerate(smb):
            S[i] += s
        n += 1
    U = [u/n for u in U]
    S = [s/n for s in S]
    return U, S

def model(X,
          h2_u, h3_u,
          h2_s, h3_s,
          w, w2, g2, b2, w3, g3, b3, wy
          ):
    # Feature extractor: run the discriminator with the frozen batchnorm
    # statistics, max-pool each layer's output to a 4x4 grid, and concatenate.
    h = lrelu(dnn_conv(X, w, subsample=(2, 2), border_mode=(2, 2)))
    h2 = lrelu(batchnorm(dnn_conv(h, w2, subsample=(2, 2), border_mode=(2, 2)), g=g2, b=b2, u=h2_u, s=h2_s))
    h3 = lrelu(batchnorm(dnn_conv(h2, w3, subsample=(2, 2), border_mode=(2, 2)), g=g3, b=b3, u=h3_u, s=h3_s))
    h = T.flatten(dnn_pool(h, (4, 4), (4, 4), mode='max'), 2)
    h2 = T.flatten(dnn_pool(h2, (2, 2), (2, 2), mode='max'), 2)
    h3 = T.flatten(dnn_pool(h3, (1, 1), (1, 1), mode='max'), 2)
    f = T.concatenate([h, h2, h3], axis=1)
    return [f]

X = T.tensor4()

h_us, h_ss = bnorm_statistics(X, *params)
_bnorm_stats = theano.function([X], h_us + h_ss)

trU, trS = infer_bnorm_stats(trX)

HUs = [sharedX(u) for u in trU]
HSs = [sharedX(s) for s in trS]

targs = [X]+HUs+HSs+params
f = model(*targs)
_features = theano.function([X], f)

def features(X, nbatch=128):
    # Extract features for a dataset in minibatches and stack them.
    Xfs = []
    for xmb in iter_data(X, size=nbatch):
        fmbs = _features(floatX(xmb))
        for fmb in fmbs:
            Xfs.append(fmb)
    return np.concatenate(Xfs, axis=0)

cs = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01]
vaXt = features(vaX)
mean_va_accs = []
for c in cs:
    tr_accs = []
    va_accs = []
    te_accs = []
    for _ in tqdm(range(10), leave=False, ncols=80):
        # Sample 100 labeled examples per class (1000 total), following the
        # paper's semi-supervised SVHN protocol.
        idxs = np.arange(len(trX))
        classes_idxs = [idxs[trY==y] for y in range(10)]
        sampled_idxs = [py_rng.sample(class_idxs, 100) for class_idxs in classes_idxs]
        sampled_idxs = np.asarray(sampled_idxs).flatten()

        trXt = features(trX[sampled_idxs])

        model = LSVC(C=c)
        model.fit(trXt[:1000], trY[sampled_idxs])
        tr_pred = model.predict(trXt)
        va_pred = model.predict(vaXt)
        tr_acc = metrics.accuracy_score(trY[sampled_idxs], tr_pred[:1000])
        va_acc = metrics.accuracy_score(vaY, va_pred)
        # Store error rates (%), not accuracies, so lower is better below.
        tr_accs.append(100*(1-tr_acc))
        va_accs.append(100*(1-va_acc))
    mean_va_accs.append(np.mean(va_accs))
    print 'c: %.4f train: %.4f %.4f valid: %.4f %.4f'%(c, np.mean(tr_accs), np.std(tr_accs)*1.96, np.mean(va_accs), np.std(va_accs)*1.96)
# mean_va_accs holds mean validation *error* per C, so argmin picks the best C.
best_va_idx = np.argmin(mean_va_accs)
best_va_c = cs[best_va_idx]
print 'best c: %.4f'%best_va_c
teXt = features(teX)

tr_accs = []
va_accs = []
te_accs = []
for _ in tqdm(range(100), leave=False, ncols=80):
    # Re-run the 1000-label experiment 100 times with the selected C and
    # report mean error +/- 1.96*std on train/valid/test.
    idxs = np.arange(len(trX))
    classes_idxs = [idxs[trY==y] for y in range(10)]
    sampled_idxs = [py_rng.sample(class_idxs, 100) for class_idxs in classes_idxs]
    sampled_idxs = np.asarray(sampled_idxs).flatten()

    trXt = features(trX[sampled_idxs])

    model = LSVC(C=best_va_c)
    model.fit(trXt[:1000], trY[sampled_idxs])
    tr_pred = model.predict(trXt)
    va_pred = model.predict(vaXt)
    te_pred = model.predict(teXt)
    tr_acc = metrics.accuracy_score(trY[sampled_idxs], tr_pred[:1000])
    va_acc = metrics.accuracy_score(vaY, va_pred)
    te_acc = metrics.accuracy_score(teY, te_pred)
    tr_accs.append(100*(1-tr_acc))
    va_accs.append(100*(1-va_acc))
    te_accs.append(100*(1-te_acc))
print 'train: %.4f %.4f valid: %.4f %.4f test: %.4f %.4f'%(np.mean(tr_accs), np.std(tr_accs)*1.96, np.mean(va_accs), np.std(va_accs)*1.96, np.mean(te_accs), np.std(te_accs)*1.96)
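
The `%.4f %.4f` pairs printed above are mean error and `1.96*std` across the repeated runs. A minimal sketch of that reporting convention (hypothetical error rates; assumes run-to-run errors are roughly normal, so `1.96*std` approximates a 95% interval half-width):

```python
import numpy as np

def report(errors):
    # Mean error (%) and approximate 95% half-width across repeated runs.
    errors = np.asarray(errors)
    return np.mean(errors), 1.96*np.std(errors)

mean_err, half = report([22.18, 22.51, 21.97])  # hypothetical per-run errors
print 'error: %.4f +/- %.4f'%(mean_err, half)
```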