From b116b3d870342f9e0cd716c0c006be30e7d257e3 Mon Sep 17 00:00:00 2001 From: Suresh Kumar Maddala Date: Sun, 15 Jun 2025 22:45:51 -0500 Subject: [PATCH 1/2] update code to download datasets in the python3 --- utils/create_data.py | 76 ++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/utils/create_data.py b/utils/create_data.py index 07e23ce..c8ff3e6 100644 --- a/utils/create_data.py +++ b/utils/create_data.py @@ -22,12 +22,11 @@ from __future__ import division from __future__ import print_function -from io import BytesIO +from io import BytesIO, StringIO import os import pickle -import StringIO +import urllib.request as urllib2 import tarfile -import urllib2 import keras.backend as K from keras.datasets import cifar10 @@ -37,16 +36,16 @@ import numpy as np import pandas as pd from sklearn.datasets import fetch_20newsgroups_vectorized -from sklearn.datasets import fetch_mldata +from sklearn.datasets import fetch_openml from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_iris -import sklearn.datasets.rcv1 +from sklearn.datasets import fetch_rcv1 from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import TfidfTransformer from absl import app from absl import flags -from tensorflow import gfile +from tensorflow.io import gfile flags.DEFINE_string('save_dir', '/tmp/data', 'Where to save outputs') @@ -189,7 +188,7 @@ def download_file(url): for member in tar_dir.getnames(): if '_batch' in member: filestream = tar_dir.extractfile(member).read() - batch = pickle.load(StringIO.StringIO(filestream)) + batch = pickle.load(BytesIO(filestream), encoding='latin1') if X is None: X = np.array(batch['data'], dtype=np.uint8) y = np.array(batch['labels']) @@ -205,9 +204,9 @@ def get_mldata(dataset): save_dir = FLAGS.save_dir filename = os.path.join(save_dir, dataset[1]+'.pkl') - if not gfile.Exists(save_dir): - gfile.MkDir(save_dir) - if not
gfile.Exists(filename): + if not gfile.exists(save_dir): + gfile.makedirs(save_dir) + if not gfile.exists(filename): if dataset[0][-3:] == 'csv': data = get_csv_data(dataset[0]) elif dataset[0] == 'breast_cancer': @@ -217,19 +216,12 @@ def get_mldata(dataset): elif dataset[0] == 'newsgroup': # Removing header information to make sure that no newsgroup identifying # information is included in data - data = fetch_20newsgroups_vectorized(subset='all', remove=('headers')) + data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',)) tfidf = TfidfTransformer(norm='l2') X = tfidf.fit_transform(data.data) data.data = X elif dataset[0] == 'rcv1': - sklearn.datasets.rcv1.URL = ( - 'http://www.ai.mit.edu/projects/jmlr/papers/' - 'volume5/lewis04a/a13-vector-files/lyrl2004_vectors') - sklearn.datasets.rcv1.URL_topics = ( - 'http://www.ai.mit.edu/projects/jmlr/papers/' - 'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz') - data = sklearn.datasets.fetch_rcv1( - data_home='/tmp') + data = fetch_rcv1(data_home='/tmp') elif dataset[0] == 'wikipedia_attack': data = get_wikipedia_talk_data() elif dataset[0] == 'cifar10': @@ -238,7 +230,7 @@ def get_mldata(dataset): data = get_keras_data(dataset[0]) else: try: - data = fetch_mldata(dataset[0]) + data = fetch_openml(dataset[0]) except: raise Exception('ERROR: failed to fetch data from mldata.org') X = data.data @@ -255,21 +247,35 @@ def main(argv): del argv # Unused. # First entry of tuple is mldata.org name, second is the name that we'll use # to reference the data. 
- datasets = [('mnist (original)', 'mnist'), ('australian', 'australian'), - ('heart', 'heart'), ('breast_cancer', 'breast_cancer'), - ('iris', 'iris'), ('vehicle', 'vehicle'), ('wine', 'wine'), - ('waveform ida', 'waveform'), ('german ida', 'german'), - ('splice ida', 'splice'), ('ringnorm ida', 'ringnorm'), - ('twonorm ida', 'twonorm'), ('diabetes_scale', 'diabetes'), - ('mushrooms', 'mushrooms'), ('letter', 'letter'), ('dna', 'dna'), - ('banana-ida', 'banana'), ('letter', 'letter'), ('dna', 'dna'), - ('newsgroup', 'newsgroup'), ('cifar10', 'cifar10'), - ('cifar10_keras', 'cifar10_keras'), - ('cifar100_keras', 'cifar100_keras'), - ('cifar100_coarse_keras', 'cifar100_coarse_keras'), - ('mnist_keras', 'mnist_keras'), - ('wikipedia_attack', 'wikipedia_attack'), - ('rcv1', 'rcv1')] + datasets = [ + ('mnist_784', 'mnist'), + ('australian', 'australian'), + ('heart', 'heart'), + ('breast-cancer', 'breast_cancer'), + ('iris', 'iris'), + ('vehicle', 'vehicle'), + ('wine', 'wine'), + ('waveform-5000', 'waveform'), + # ('31', 'german'), + ('splice', 'splice'), + ('ringnorm', 'ringnorm'), + ('twonorm', 'twonorm'), + ('diabetes', 'diabetes'), + ('mushroom', 'mushrooms'), + ('letter', 'letter'), + ('dna', 'dna'), + ('banana', 'banana'), + ('letter', 'letter'), + ('dna', 'dna'), + ('newsgroup', 'newsgroup'), + ('cifar10', 'cifar10'), + ('cifar10_keras', 'cifar10_keras'), + ('cifar100_keras', 'cifar100_keras'), + ('cifar100_coarse_keras', 'cifar100_coarse_keras'), + ('mnist_keras', 'mnist_keras'), + ('wikipedia_attack', 'wikipedia_attack'), + ('rcv1', 'rcv1') + ] if FLAGS.datasets: subset = FLAGS.datasets.split(',') From 030f45410d90585aeab2174e34f79c0e5b592832 Mon Sep 17 00:00:00 2001 From: Suresh Kumar Maddala Date: Sun, 15 Jun 2025 23:08:50 -0500 Subject: [PATCH 2/2] changes to download data in python3 --- utils/create_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/create_data.py b/utils/create_data.py index c8ff3e6..3ce7d8a 100644 ---
a/utils/create_data.py +++ b/utils/create_data.py @@ -256,7 +256,7 @@ def main(argv): ('vehicle', 'vehicle'), ('wine', 'wine'), ('waveform-5000', 'waveform'), - # ('31', 'german'), + # ('german ida', 'german'), ('splice', 'splice'), ('ringnorm', 'ringnorm'), ('twonorm', 'twonorm'),