From db50520b221a66849320f2893c05e3c079c685ff Mon Sep 17 00:00:00 2001 From: KimHyeon-Ji <72793869+KimHyeon-Ji@users.noreply.github.com> Date: Mon, 28 Nov 2022 22:05:55 +0900 Subject: [PATCH] Update main_regression.py --- main_regression.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/main_regression.py b/main_regression.py index 728e51a..cbcda56 100644 --- a/main_regression.py +++ b/main_regression.py @@ -1,6 +1,7 @@ import torch import pandas as pd import numpy as np +import os from sklearn.metrics import mean_squared_error, mean_absolute_error from models.train_model import Train_Test @@ -139,6 +140,9 @@ def save_model(self, best_model, best_model_path): :param best_model_path: path for saving model :type best_model_path: str """ + + # make folder to save model + os.makedirs('./ckpt', exist_ok=True) # save model torch.save(best_model.state_dict(), best_model_path) @@ -224,4 +228,4 @@ def get_dataloader(self, x_data, y_data, batch_size, shuffle): # DataLoader 구축 data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) - return data_loader \ No newline at end of file + return data_loader