Binary file added datasets/Ciao/dataset.pkl
Binary file added datasets/Ciao/list.pkl
Binary file added datasets/Epinions/dataset.pkl
Binary file added datasets/Epinions/list.pkl
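For context, each dataset directory's two pickles are consumed by main.py (below) as back-to-back pickle.load calls. A minimal sketch of that layout, assuming the Ciao paths from this diff; the variable names mirror main.py:

import pickle

# dataset.pkl stores three consecutively pickled splits; read order matters.
with open('datasets/Ciao/dataset.pkl', 'rb') as f:
    train_set = pickle.load(f)
    valid_set = pickle.load(f)
    test_set = pickle.load(f)

# list.pkl stores four neighbor-index lists plus a counts tuple.
with open('datasets/Ciao/list.pkl', 'rb') as f:
    u_items_list = pickle.load(f)
    u_users_list = pickle.load(f)
    u_users_items_list = pickle.load(f)
    i_users_list = pickle.load(f)
    user_count, item_count, rate_count = pickle.load(f)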
main.py (107 changes: 52 additions & 55 deletions)
@@ -1,35 +1,18 @@
#!/usr/bin/env python37
# -*- coding: utf-8 -*-
"""
Created on 30 Sep, 2019

@author: wangshuo
"""

import os
import time
import sys
import argparse
import pickle
import numpy as np
import random
from tqdm import tqdm
from os.path import join

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from torch.backends import cudnn

from tqdm import tqdm
from utils import collate_fn
from model import GraphRec
from dataloader import GRDataset

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', default='dataset/Ciao/', help='dataset directory path: datasets/Ciao/Epinions')
parser.add_argument('--dataset_path', default='./datasets/Ciao/', help='dataset directory path: ./datasets/Ciao or ./datasets/Epinions')
parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
parser.add_argument('--embed_dim', type=int, default=64, help='the dimension of embedding')
parser.add_argument('--epoch', type=int, default=30, help='the number of epochs to train for')
@@ -44,48 +27,70 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main():
print('Loading data...')
with open(args.dataset_path + 'dataset.pkl', 'rb') as f:
train_set = pickle.load(f)
valid_set = pickle.load(f)
test_set = pickle.load(f)

with open(args.dataset_path + 'list.pkl', 'rb') as f:
u_items_list = pickle.load(f)
u_users_list = pickle.load(f)
u_users_items_list = pickle.load(f)
i_users_list = pickle.load(f)
(user_count, item_count, rate_count) = pickle.load(f)
dataset_path = os.path.abspath(args.dataset_path)
print(f"Loading data from {dataset_path}")
if not os.path.exists(dataset_path):
print(f"Error: Dataset path {dataset_path} does not exist.")
sys.exit(1)

try:
dataset_file = os.path.join(dataset_path, 'dataset.pkl')
print(f"Trying to open {dataset_file}")
with open(dataset_file, 'rb') as f:
train_set = pickle.load(f)
valid_set = pickle.load(f)
test_set = pickle.load(f)
print("Successfully loaded dataset.pkl")
except FileNotFoundError as e:
print(f"Error: {e}")
return
except Exception as e:
print(f"Unexpected error while loading dataset.pkl: {e}")
return

try:
list_file = os.path.join(dataset_path, 'list.pkl')
print(f"Trying to open {list_file}")
with open(list_file, 'rb') as f:
u_items_list = pickle.load(f)
u_users_list = pickle.load(f)
u_users_items_list = pickle.load(f)
i_users_list = pickle.load(f)
(user_count, item_count, rate_count) = pickle.load(f)
print("Successfully loaded list.pkl")
except FileNotFoundError as e:
print(f"Error: {e}")
return
except Exception as e:
print(f"Unexpected error while loading list.pkl: {e}")
return

train_data = GRDataset(train_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
valid_data = GRDataset(valid_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
test_data = GRDataset(test_set, u_items_list, u_users_list, u_users_items_list, i_users_list)
train_loader = DataLoader(train_data, batch_size = args.batch_size, shuffle = True, collate_fn = collate_fn)
valid_loader = DataLoader(valid_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
test_loader = DataLoader(test_data, batch_size = args.batch_size, shuffle = False, collate_fn = collate_fn)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

model = GraphRec(user_count+1, item_count+1, rate_count+1, args.embed_dim).to(device)

if args.test:
print('Load checkpoint and testing...')
ckpt = torch.load('best_checkpoint.pth.tar')
ckpt = torch.load('best_checkpoint.pth.tar', map_location=device)
model.load_state_dict(ckpt['state_dict'])
mae, rmse = validate(test_loader, model)
print("Test: MAE: {:.4f}, RMSE: {:.4f}".format(mae, rmse))
return

optimizer = optim.RMSprop(model.parameters(), args.lr)
criterion = nn.MSELoss()
scheduler = StepLR(optimizer, step_size = args.lr_dc_step, gamma = args.lr_dc)
optimizer = torch.optim.RMSprop(model.parameters(), args.lr)
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_dc_step, gamma=args.lr_dc)

for epoch in tqdm(range(args.epoch)):
# train for one epoch
scheduler.step(epoch = epoch)
trainForEpoch(train_loader, model, optimizer, epoch, args.epoch, criterion, log_aggr = 100)

scheduler.step(epoch=epoch)
trainForEpoch(train_loader, model, optimizer, epoch, args.epoch, criterion, log_aggr=100)
mae, rmse = validate(valid_loader, model)

# store best loss and save a model checkpoint

ckpt_dict = {
'epoch': epoch + 1,
'state_dict': model.state_dict(),
@@ -102,12 +107,9 @@ def main():

print('Epoch {} validation: MAE: {:.4f}, RMSE: {:.4f}, Best MAE: {:.4f}'.format(epoch, mae, rmse, best_mae))


def trainForEpoch(train_loader, model, optimizer, epoch, num_epochs, criterion, log_aggr=1):
model.train()

sum_epoch_loss = 0

start = time.time()
for i, (uids, iids, labels, u_items, u_users, u_users_items, i_users) in tqdm(enumerate(train_loader), total=len(train_loader)):
uids = uids.to(device)
@@ -120,24 +122,20 @@ def trainForEpoch(train_loader, model, optimizer, epoch, num_epochs, criterion,

optimizer.zero_grad()
outputs = model(uids, iids, u_items, u_users, u_users_items, i_users)

loss = criterion(outputs, labels.unsqueeze(1))
loss.backward()
optimizer.step()

optimizer.step()
loss_val = loss.item()
sum_epoch_loss += loss_val

iter_num = epoch * len(train_loader) + i + 1

if i % log_aggr == 0:
print('[TRAIN] epoch %d/%d batch loss: %.4f (avg %.4f) (%.2f im/s)'
% (epoch + 1, num_epochs, loss_val, sum_epoch_loss / (i + 1),
len(uids) / (time.time() - start)))
% (epoch + 1, num_epochs, loss_val, sum_epoch_loss / (i + 1), len(uids) / (time.time() - start)))

start = time.time()


def validate(valid_loader, model):
model.eval()
errors = []
@@ -158,6 +156,5 @@ def validate(valid_loader, model):
rmse = np.sqrt(np.mean(np.power(errors, 2)))
return mae, rmse


if __name__ == '__main__':
main()
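A note on one behavioral fix above: the checkpoint is now loaded with map_location, so a checkpoint saved on a CUDA machine also loads on a CPU-only box. A minimal self-contained sketch of the pattern, using a stand-in nn.Linear in place of GraphRec and an illustrative file name:

import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in module; the save/load pattern is identical for GraphRec.
model = nn.Linear(4, 1).to(device)

# Save the way main.py does: a dict carrying 'state_dict' (plus epoch, etc.).
torch.save({'epoch': 1, 'state_dict': model.state_dict()}, 'demo_checkpoint.pth.tar')

# map_location remaps any GPU-saved tensors onto the current device at load time.
ckpt = torch.load('demo_checkpoint.pth.tar', map_location=device)
model.load_state_dict(ckpt['state_dict'])

Without map_location, torch.load tries to restore tensors onto the device they were saved from, which fails on machines without CUDA.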
requirements.txt (116 changes: 17 additions & 99 deletions)
@@ -1,99 +1,17 @@
alabaster==0.7.12
asn1crypto==0.24.0
astroid==2.3.1
attrs==19.2.0
Babel==2.7.0
backcall==0.1.0
bleach==3.1.0
certifi==2019.9.11
cffi==1.12.3
chardet==3.0.4
cloudpickle==1.2.2
cryptography==2.7
decorator==4.4.0
defusedxml==0.6.0
docutils==0.15.2
entrypoints==0.3
idna==2.8
imagesize==1.1.0
ipykernel==5.1.2
ipython==7.8.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
isort==4.3.21
jedi==0.15.1
jeepney==0.4.1
Jinja2==2.10.1
json5==0.8.5
jsonschema==3.0.2
jupyter==1.0.0
jupyter-client==5.3.3
jupyter-console==6.0.0
jupyter-core==4.5.0
jupyterlab==1.1.4
jupyterlab-server==1.0.6
keyring==18.0.0
lazy-object-proxy==1.4.2
MarkupSafe==1.1.1
mccabe==0.6.1
mistune==0.8.4
nbconvert==5.6.0
nbformat==4.4.0
notebook==6.0.1
numpy==1.17.2
numpydoc==0.9.1
packaging==19.2
pandas==0.25.1
pandocfilters==1.4.2
parso==0.5.1
pexpect==4.7.0
pickleshare==0.7.5
Pillow==6.2.0
prometheus-client==0.7.1
prompt-toolkit==2.0.10
psutil==5.6.3
ptyprocess==0.6.0
pycodestyle==2.5.0
pycparser==2.19
pyflakes==2.1.1
Pygments==2.4.2
pylint==2.4.2
pyOpenSSL==19.0.0
pyparsing==2.4.2
pyrsistent==0.15.4
PySocks==1.7.1
python-dateutil==2.8.0
pytz==2019.2
pyzmq==18.1.0
QtAwesome==0.6.0
qtconsole==4.5.5
QtPy==1.9.0
requests==2.22.0
rope==0.14.0
scipy==1.3.1
SecretStorage==3.1.1
Send2Trash==1.5.0
six==1.12.0
snowballstemmer==1.9.1
Sphinx==2.2.0
sphinxcontrib-applehelp==1.0.1
sphinxcontrib-devhelp==1.0.1
sphinxcontrib-htmlhelp==1.0.2
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.2
sphinxcontrib-serializinghtml==1.1.3
spyder==3.3.6
spyder-kernels==0.5.2
terminado==0.8.2
testpath==0.4.2
torch==1.1.0
torchvision==0.3.0
tornado==6.0.3
tqdm==4.36.1
traitlets==4.3.3
urllib3==1.24.2
wcwidth==0.1.7
webencodings==0.5.1
widgetsnbextension==3.5.1
wrapt==1.11.2
wurlitzer==1.0.3
filelock==3.15.4
fsspec==2024.6.1
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.3
numpy==1.26.4
pandas==2.2.2
python-dateutil==2.9.0.post0
pytz==2024.1
scipy==1.14.0
six==1.16.0
sympy==1.12.1
torch==2.2.2
tqdm==4.66.4
typing_extensions==4.12.2
tzdata==2024.1
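The new pin set drops the old Spyder/Jupyter toolchain and keeps only the runtime dependencies. A hypothetical post-install smoke test (not part of this PR) to confirm the pins resolved:

# Hypothetical check, not part of the PR: verify the pinned imports resolve.
import numpy
import pandas
import torch

print(torch.__version__)   # expected 2.2.2 per requirements.txt
print(numpy.__version__)   # expected 1.26.4
print(pandas.__version__)  # expected 2.2.2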