Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions utils/create_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
from __future__ import division
from __future__ import print_function

from io import BytesIO
from io import BytesIO, StringIO
import os
import pickle
import StringIO
import urllib.request as urllib2
import tarfile
import urllib2

import keras.backend as K
from keras.datasets import cifar10
Expand All @@ -37,16 +36,16 @@
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups_vectorized
from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
import sklearn.datasets.rcv1
from sklearn.datasets import fetch_rcv1
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer

from absl import app
from absl import flags
from tensorflow import gfile
from tensorflow.io import gfile

flags.DEFINE_string('save_dir', '/tmp/data',
'Where to save outputs')
Expand Down Expand Up @@ -189,7 +188,7 @@ def download_file(url):
for member in tar_dir.getnames():
if '_batch' in member:
filestream = tar_dir.extractfile(member).read()
batch = pickle.load(StringIO.StringIO(filestream))
batch = pickle.load(BytesIO(filestream), encoding='latin1')
if X is None:
X = np.array(batch['data'], dtype=np.uint8)
y = np.array(batch['labels'])
Expand All @@ -205,9 +204,9 @@ def get_mldata(dataset):
save_dir = FLAGS.save_dir
filename = os.path.join(save_dir, dataset[1]+'.pkl')

if not gfile.Exists(save_dir):
gfile.MkDir(save_dir)
if not gfile.Exists(filename):
if not gfile.exists(save_dir):
gfile.makedirs(save_dir)
if not gfile.exists(filename):
if dataset[0][-3:] == 'csv':
data = get_csv_data(dataset[0])
elif dataset[0] == 'breast_cancer':
Expand All @@ -217,19 +216,12 @@ def get_mldata(dataset):
elif dataset[0] == 'newsgroup':
# Removing header information to make sure that no newsgroup identifying
# information is included in data
data = fetch_20newsgroups_vectorized(subset='all', remove=('headers'))
data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',))
tfidf = TfidfTransformer(norm='l2')
X = tfidf.fit_transform(data.data)
data.data = X
elif dataset[0] == 'rcv1':
sklearn.datasets.rcv1.URL = (
'http://www.ai.mit.edu/projects/jmlr/papers/'
'volume5/lewis04a/a13-vector-files/lyrl2004_vectors')
sklearn.datasets.rcv1.URL_topics = (
'http://www.ai.mit.edu/projects/jmlr/papers/'
'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz')
data = sklearn.datasets.fetch_rcv1(
data_home='/tmp')
data = fetch_rcv1(data_home='/tmp')
elif dataset[0] == 'wikipedia_attack':
data = get_wikipedia_talk_data()
elif dataset[0] == 'cifar10':
Expand All @@ -238,7 +230,7 @@ def get_mldata(dataset):
data = get_keras_data(dataset[0])
else:
try:
data = fetch_mldata(dataset[0])
data = fetch_openml(dataset[0])
except:
raise Exception('ERROR: failed to fetch data from mldata.org')
X = data.data
Expand All @@ -255,21 +247,35 @@ def main(argv):
del argv # Unused.
# First entry of tuple is the OpenML dataset name (this code now uses
# sklearn's fetch_openml, since mldata.org is defunct); second is the name
# that we'll use to reference the data.
datasets = [('mnist (original)', 'mnist'), ('australian', 'australian'),
('heart', 'heart'), ('breast_cancer', 'breast_cancer'),
('iris', 'iris'), ('vehicle', 'vehicle'), ('wine', 'wine'),
('waveform ida', 'waveform'), ('german ida', 'german'),
('splice ida', 'splice'), ('ringnorm ida', 'ringnorm'),
('twonorm ida', 'twonorm'), ('diabetes_scale', 'diabetes'),
('mushrooms', 'mushrooms'), ('letter', 'letter'), ('dna', 'dna'),
('banana-ida', 'banana'), ('letter', 'letter'), ('dna', 'dna'),
('newsgroup', 'newsgroup'), ('cifar10', 'cifar10'),
('cifar10_keras', 'cifar10_keras'),
('cifar100_keras', 'cifar100_keras'),
('cifar100_coarse_keras', 'cifar100_coarse_keras'),
('mnist_keras', 'mnist_keras'),
('wikipedia_attack', 'wikipedia_attack'),
('rcv1', 'rcv1')]
datasets = [
('mnist_784', 'mnist'),
('australian', 'australian'),
('heart', 'heart'),
('breast-cancer', 'breast_cancer'),
('iris', 'iris'),
('vehicle', 'vehicle'),
('wine', 'wine'),
('waveform-5000', 'waveform'),
# ('german ida', 'german'),
('splice', 'splice'),
('ringnorm', 'ringnorm'),
('twonorm', 'twonorm'),
('diabetes', 'diabetes'),
('mushroom', 'mushrooms'),
('letter', 'letter'),
('dna', 'dna'),
('banana', 'banana'),
('letter', 'letter'),
('dna', 'dna'),
('newsgroup', 'newsgroup'),
('cifar10', 'cifar10'),
('cifar10_keras', 'cifar10_keras'),
('cifar100_keras', 'cifar100_keras'),
('cifar100_coarse_keras', 'cifar100_coarse_keras'),
('mnist_keras', 'mnist_keras'),
('wikipedia_attack', 'wikipedia_attack'),
('rcv1', 'rcv1')
]

if FLAGS.datasets:
subset = FLAGS.datasets.split(',')
Expand Down