diff --git a/datasets/Ciao/dataset.pkl b/datasets/Ciao/dataset.pkl
new file mode 100644
index 0000000..6e663d2
Binary files /dev/null and b/datasets/Ciao/dataset.pkl differ
diff --git a/datasets/Ciao/list.pkl b/datasets/Ciao/list.pkl
new file mode 100644
index 0000000..60813d4
Binary files /dev/null and b/datasets/Ciao/list.pkl differ
diff --git a/datasets/Epinions/dataset.pkl b/datasets/Epinions/dataset.pkl
new file mode 100644
index 0000000..f0cac1e
Binary files /dev/null and b/datasets/Epinions/dataset.pkl differ
diff --git a/datasets/Epinions/list.pkl b/datasets/Epinions/list.pkl
new file mode 100644
index 0000000..a6b3b61
Binary files /dev/null and b/datasets/Epinions/list.pkl differ
diff --git a/main.py b/main.py
index 54cf6dd..7e12a2d 100644
--- a/main.py
+++ b/main.py
@@ -1,35 +1,18 @@
-#!/usr/bin/env python37
-# -*- coding: utf-8 -*-
-"""
-Created on 30 Sep, 2019
-
-@author: wangshuo
-"""
-
 import os
 import time
+import sys
 import argparse
 import pickle
 import numpy as np
-import random
-from tqdm import tqdm
-from os.path import join
-
 import torch
-from torch import nn
 from torch.utils.data import DataLoader
-import torch.nn.functional as F
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from torch.autograd import Variable
-from torch.backends import cudnn
-
+from tqdm import tqdm
 from utils import collate_fn
 from model import GraphRec
 from dataloader import GRDataset
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--dataset_path', default='dataset/Ciao/', help='dataset directory path: datasets/Ciao/Epinions')
+parser.add_argument('--dataset_path', default='./datasets/Ciao/', help='dataset directory path: ./datasets/Ciao or ./datasets/Epinions')
 parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
 parser.add_argument('--embed_dim', type=int, default=64, help='the dimension of embedding')
 parser.add_argument('--epoch', type=int, default=30, help='the number of epochs to train for')
@@ -44,48 +27,70 @@
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 def main():
-    print('Loading data...')
-    with open(args.dataset_path + 'dataset.pkl', 'rb') as f:
-        train_set = pickle.load(f)
-        valid_set = pickle.load(f)
-        test_set = pickle.load(f)
-
-    with open(args.dataset_path + 'list.pkl', 'rb') as f:
-        u_items_list = pickle.load(f)
-        u_users_list = pickle.load(f)
-        u_users_items_list = pickle.load(f)
-        i_users_list = pickle.load(f)
-        (user_count, item_count, rate_count) = pickle.load(f)
+    dataset_path = os.path.abspath(args.dataset_path)
+    print(f"Loading data from {dataset_path}")
+    if not os.path.exists(dataset_path):
+        print(f"Error: Dataset path {dataset_path} does not exist.")
+        sys.exit(1)
+    try:
+        dataset_file = os.path.join(dataset_path, 'dataset.pkl')
+        print(f"Trying to open {dataset_file}")
+        with open(dataset_file, 'rb') as f:
+            train_set = pickle.load(f)
+            valid_set = pickle.load(f)
+            test_set = pickle.load(f)
+        print("Successfully loaded dataset.pkl")
+    except FileNotFoundError as e:
+        print(f"Error: {e}")
+        return
+    except Exception as e:
+        print(f"Unexpected error while loading dataset.pkl: {e}")
+        return
+
+    try:
+        list_file = os.path.join(dataset_path, 'list.pkl')
+        print(f"Trying to open {list_file}")
+        with open(list_file, 'rb') as f:
+            u_items_list = pickle.load(f)
+            u_users_list = pickle.load(f)
+            u_users_items_list = pickle.load(f)
+            i_users_list = pickle.load(f)
+            (user_count, item_count, rate_count) = pickle.load(f)
+        print("Successfully loaded list.pkl")
+    except FileNotFoundError as e:
+        print(f"Error: {e}")
+        return
+    except Exception as e:
+        print(f"Unexpected error while loading list.pkl: {e}")
+        return
+
 
     train_data = GRDataset(train_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
     valid_data = GRDataset(valid_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
     test_data = GRDataset(test_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
-    train_loader = DataLoader(train_data, batch_size = args.batch_size, shuffle = True, collate_fn = collate_fn)
-    valid_loader = DataLoader(valid_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
-    test_loader = DataLoader(test_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
+    train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
+    valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
+    test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
 
     model = GraphRec(user_count+1, item_count+1, rate_count+1, args.embed_dim).to(device)
 
     if args.test:
         print('Load checkpoint and testing...')
-        ckpt = torch.load('best_checkpoint.pth.tar')
+        ckpt = torch.load('best_checkpoint.pth.tar', map_location=device)
         model.load_state_dict(ckpt['state_dict'])
         mae, rmse = validate(test_loader, model)
         print("Test: MAE: {:.4f}, RMSE: {:.4f}".format(mae, rmse))
         return
 
-    optimizer = optim.RMSprop(model.parameters(), args.lr)
-    criterion = nn.MSELoss()
-    scheduler = StepLR(optimizer, step_size = args.lr_dc_step, gamma = args.lr_dc)
+    optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
+    criterion = torch.nn.MSELoss()
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_dc_step, gamma=args.lr_dc)
 
     for epoch in tqdm(range(args.epoch)):
-        # train for one epoch
-        scheduler.step(epoch = epoch)
-        trainForEpoch(train_loader, model, optimizer, epoch, args.epoch, criterion, log_aggr = 100)
-
+        trainForEpoch(train_loader, model, optimizer, epoch, args.epoch, criterion, log_aggr=100)
+        scheduler.step()  # PyTorch >= 1.1: step after the epoch's optimizer updates; the deprecated epoch kwarg is dropped
         mae, rmse = validate(valid_loader, model)
-
-        # store best loss and save a model checkpoint
+
         ckpt_dict = {
             'epoch': epoch + 1,
             'state_dict': model.state_dict(),
@@ -102,12 +107,9 @@ def main():
 
         print('Epoch {} validation: MAE: {:.4f}, RMSE: {:.4f}, Best MAE: {:.4f}'.format(epoch, mae, rmse, best_mae))
 
-
 def trainForEpoch(train_loader, model, optimizer, epoch, num_epochs, criterion, log_aggr=1):
     model.train()
-
    sum_epoch_loss = 0
-
    start = time.time()
    for i, (uids, iids, labels, u_items, u_users, u_users_items, i_users) in tqdm(enumerate(train_loader), total=len(train_loader)):
        uids = uids.to(device)
@@ -120,11 +122,9 @@ def trainForEpoch(train_loader, model, optimizer, epoch, num_epochs, criterion,
         optimizer.zero_grad()
         outputs = model(uids, iids, u_items, u_users, u_users_items, i_users)
-
         loss = criterion(outputs, labels.unsqueeze(1))
         loss.backward()
-
-        optimizer.step()
-
+        optimizer.step()
         loss_val = loss.item()
         sum_epoch_loss += loss_val
@@ -132,12 +132,10 @@
 
         if i % log_aggr == 0:
             print('[TRAIN] epoch %d/%d batch loss: %.4f (avg %.4f) (%.2f im/s)'
-                % (epoch + 1, num_epochs, loss_val, sum_epoch_loss / (i + 1),
-                len(uids) / (time.time() - start)))
+                % (epoch + 1, num_epochs, loss_val, sum_epoch_loss / (i + 1), len(uids) / (time.time() - start)))
 
         start = time.time()
 
-
 def validate(valid_loader, model):
     model.eval()
     errors = []
@@ -158,6 +156,5 @@ def validate(valid_loader, model):
     rmse = np.sqrt(np.mean(np.power(errors, 2)))
     return mae, rmse
 
-
 if __name__ == '__main__':
     main()
diff --git a/requirements.txt b/requirements.txt
index d587495..6590a67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,99 +1,17 @@
-alabaster==0.7.12
-asn1crypto==0.24.0
-astroid==2.3.1
-attrs==19.2.0
-Babel==2.7.0
-backcall==0.1.0
-bleach==3.1.0
-certifi==2019.9.11
-cffi==1.12.3
-chardet==3.0.4
-cloudpickle==1.2.2
-cryptography==2.7
-decorator==4.4.0
-defusedxml==0.6.0
-docutils==0.15.2
-entrypoints==0.3
-idna==2.8
-imagesize==1.1.0
-ipykernel==5.1.2
-ipython==7.8.0
-ipython-genutils==0.2.0
-ipywidgets==7.5.1
-isort==4.3.21
-jedi==0.15.1
-jeepney==0.4.1
-Jinja2==2.10.1
-json5==0.8.5
-jsonschema==3.0.2
-jupyter==1.0.0
-jupyter-client==5.3.3
-jupyter-console==6.0.0
-jupyter-core==4.5.0
-jupyterlab==1.1.4
-jupyterlab-server==1.0.6
-keyring==18.0.0
-lazy-object-proxy==1.4.2
-MarkupSafe==1.1.1
-mccabe==0.6.1
-mistune==0.8.4
-nbconvert==5.6.0
-nbformat==4.4.0
-notebook==6.0.1
-numpy==1.17.2
-numpydoc==0.9.1
-packaging==19.2
-pandas==0.25.1
-pandocfilters==1.4.2
-parso==0.5.1
-pexpect==4.7.0
-pickleshare==0.7.5
-Pillow==6.2.0
-prometheus-client==0.7.1
-prompt-toolkit==2.0.10
-psutil==5.6.3
-ptyprocess==0.6.0
-pycodestyle==2.5.0
-pycparser==2.19
-pyflakes==2.1.1
-Pygments==2.4.2
-pylint==2.4.2
-pyOpenSSL==19.0.0
-pyparsing==2.4.2
-pyrsistent==0.15.4
-PySocks==1.7.1
-python-dateutil==2.8.0
-pytz==2019.2
-pyzmq==18.1.0
-QtAwesome==0.6.0
-qtconsole==4.5.5
-QtPy==1.9.0
-requests==2.22.0
-rope==0.14.0
-scipy==1.3.1
-SecretStorage==3.1.1
-Send2Trash==1.5.0
-six==1.12.0
-snowballstemmer==1.9.1
-Sphinx==2.2.0
-sphinxcontrib-applehelp==1.0.1
-sphinxcontrib-devhelp==1.0.1
-sphinxcontrib-htmlhelp==1.0.2
-sphinxcontrib-jsmath==1.0.1
-sphinxcontrib-qthelp==1.0.2
-sphinxcontrib-serializinghtml==1.1.3
-spyder==3.3.6
-spyder-kernels==0.5.2
-terminado==0.8.2
-testpath==0.4.2
-torch==1.1.0
-torchvision==0.3.0
-tornado==6.0.3
-tqdm==4.36.1
-traitlets==4.3.3
-urllib3==1.24.2
-wcwidth==0.1.7
-webencodings==0.5.1
-widgetsnbextension==3.5.1
-wrapt==1.11.2
-wurlitzer==1.0.3
+filelock==3.15.4
+fsspec==2024.6.1
+Jinja2==3.1.4
+MarkupSafe==2.1.5
+mpmath==1.3.0
+networkx==3.3
+numpy==1.26.4
+pandas==2.2.2
+python-dateutil==2.9.0.post0
+pytz==2024.1
+scipy==1.14.0
+six==1.16.0
+sympy==1.12.1
+torch==2.2.2
+tqdm==4.66.4
+typing_extensions==4.12.2
+tzdata==2024.1
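
A note on the pickle layout the rewritten loader assumes: main.py issues several sequential pickle.load calls per file, so dataset.pkl and list.pkl must have been produced by matching sequential pickle.dump calls in the same order. A minimal sketch of a compatible writer, with purely illustrative toy splits (the real objects come from the preprocessing step, not from this snippet):

import os
import pickle

# toy (user_id, item_id, rating) triples; illustrative only
train_set = [(0, 1, 4.0), (1, 2, 3.0)]
valid_set = [(0, 2, 5.0)]
test_set = [(1, 1, 2.0)]

os.makedirs('datasets/Ciao', exist_ok=True)
with open('datasets/Ciao/dataset.pkl', 'wb') as f:
    pickle.dump(train_set, f)  # read back first by main.py
    pickle.dump(valid_set, f)  # read back second
    pickle.dump(test_set, f)   # read back third

list.pkl follows the same pattern with five dumps: u_items_list, u_users_list, u_users_items_list, i_users_list, and finally the (user_count, item_count, rate_count) tuple.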
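Similarly, the --test path now passes map_location=device to torch.load, which lets a checkpoint trained on GPU be evaluated on a CPU-only machine. A minimal round-trip sketch with a stand-in linear model (GraphRec itself needs the dataset counts to construct, so the model here is illustrative only):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(4, 1)  # stand-in for GraphRec; illustrative only

# save in the same shape as main.py's ckpt_dict
torch.save({'epoch': 1, 'state_dict': model.state_dict()}, 'best_checkpoint.pth.tar')

# map_location remaps tensors saved on another device onto `device`,
# so the same file loads whether or not CUDA is available
ckpt = torch.load('best_checkpoint.pth.tar', map_location=device)
model.load_state_dict(ckpt['state_dict'])
print('checkpoint from epoch', ckpt['epoch'])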