diff --git a/.gitignore b/.gitignore index 9006844..8292d48 100644 --- a/.gitignore +++ b/.gitignore @@ -245,3 +245,4 @@ c++/mcts backup/ old/ saves/ +.idea/ diff --git a/alphatsp/experiments/supervised.py b/alphatsp/experiments/supervised.py new file mode 100644 index 0000000..a83c168 --- /dev/null +++ b/alphatsp/experiments/supervised.py @@ -0,0 +1,67 @@ +import alphatsp.tsp +import alphatsp.util + +import alphatsp.solvers.policy_solvers +from alphatsp.solvers.example_generators import NNExampleGenerator +from alphatsp.solvers.policy_networks import SupervisedPolicyNetworkTrainer + +import torch +import numpy as np + +import copy + +from torch.multiprocessing import Process, Manager + + +def run(args): + + # setup + N, D = args.N, args.D + n_examples = args.n_train_examples + n_threads = args.n_threads + n_examples_per_thread = n_examples//n_threads + + # create policy network + policy_network = alphatsp.util.get_policy_network(args.policy_network) + + # generate examples + print("Generating examples and training...") + + manager = Manager() + train_queue = manager.Queue() + shared_dict = manager.dict() + + shared_dict["success"] = False + + producers = [] + for _ in range(n_threads): + producers.append(Process(target=generate_examples, args=(n_examples_per_thread, train_queue, args))) + + for p in producers: + p.start() + + c = Process(target=train, args=(policy_network, train_queue, shared_dict, args)) + c.start() + + for p in producers: + p.join() + train_queue.put(None) + + c.join() + + status = shared_dict["success"] + if not status: + print("Experiment failed.") + return -1 + +def generate_examples(n_examples, train_queue, args): + generator = NNExampleGenerator(train_queue, args) + generator.generate_examples(n_examples) + return + +def train(policy_network, train_queue, shared_dict, args): + trainer = SupervisedPolicyNetworkTrainer(policy_network, train_queue) + trainer.train_all() + shared_dict["model"] = copy.deepcopy(trainer.model.cpu()) + shared_dict["success"] = True + return diff --git a/alphatsp/logger.py b/alphatsp/logger.py new file mode 100644 index 0000000..f3a3251 --- /dev/null +++ b/alphatsp/logger.py @@ -0,0 +1,80 @@ +import os +import datetime +import csv +import shutil +import torch +import numpy as np +import pandas as pd + +import matplotlib +matplotlib.use("agg") +import matplotlib.pyplot as plt +plt.style.use("seaborn") + +class Logger: + + def __init__(self, args, enabled=True): + self.logging = enabled + if not self.logging: + return + + self.dt = datetime.datetime.now().strftime("%m%d_%H%M") + self.path = f"./saves/{self.dt}" + + if not os.path.exists(self.path): + os.makedirs(self.path) + + self.losses = [] + self.eval = [] + + self.main_log_fn = os.path.join(self.path, "log.txt") + shutil.copy2("args.py", self.path) + + def save_model(self, model, iterations): + if not self.logging: return + if isinstance(iterations, int): + fn = os.path.join(self.path, f"policynet_{epoch:07d}.pth") + else: + fn = os.path.join(self.path, f"policynet_{epoch}.pth") + torch.save(model.state_dict(), fn) + self.print(f"Saved model to: {fn}\n") + + def print(self, *x): + print(*x) + self.log(*x) + + def log(self, *x): + if not self.logging: return + with open(self.main_log_fn, "a") as f: + print(*x, file=f, flush=True) + + def log_loss(self, l): + self.losses.append(l) + + def log_eval(self, data): + if not self.logging: return + self.eval.append(data) + + def save(self): + if not self.logging: return + + with open(os.path.join(self.path, "loss.csv"), "w") as f: + csvwriter = csv.DictWriter(f, ["it", "loss"]) + csvwriter.writeheader() + for it, loss in enumerate(self.losses): + row = {"it": it, "loss": loss} + csvwriter.writerow(row) + + with open(os.path.join(self.path, "eval.csv"), "w") as f: + cols = ["it"] + sorted(list(set(self.eval[0].keys()) - set(["it"]))) + csvwriter = csv.DictWriter(f, cols) + csvwriter.writeheader() + for row in self.eval_scores: + csvwriter.writerow(row) + + plt.clf() + + plt.plot(self.losses) + plt.xlabel("iterations") + plt.ylabel("training loss") + plt.savefig(os.path.join(self.path, "losses.png")) diff --git a/alphatsp/solvers/example_generators.py b/alphatsp/solvers/example_generators.py index d28529f..0359628 100644 --- a/alphatsp/solvers/example_generators.py +++ b/alphatsp/solvers/example_generators.py @@ -1,6 +1,10 @@ import torch import copy +import random +from alphatsp.tsp import TSP from alphatsp.solvers.mcts import MCTSNode, MCTSTree +from alphatsp.solvers import heuristics +from alphatsp.util import get_graph_constructor class MCTSExampleGenerator: @@ -82,3 +86,38 @@ def solve(self): mcts_payoff = self.tsp.tour_length(mcts_tour) return mcts_tour, mcts_payoff + +class NNExampleGenerator: + + def __init__(self, example_queue, args): + self.args = args + self.graph_constructor = get_graph_constructor(args.graph_construction) + self.example_queue = example_queue + self.n_samples = max(args.N//10, 1) + + def generate_examples(self, n_examples): + + for _ in range(n_examples//self.n_samples): + + # generate tsp + tsp = TSP(self.args.N, self.args.D) + + # solve + tour, tour_len = heuristics.nearest_greedy(tsp) + + # generate examples + remaining = set(range(self.args.N)) + for i in sorted(random.sample(range(self.args.N-1), self.n_samples)): + + partial_tour = tour[:i] + remaining = remaining - set(partial_tour) + r = sorted(list(remaining)) + + graph = self.graph_constructor(tsp, partial_tour, r) + + example = { + "graph": graph, + "choice": r.index(tour[i+1]), + "pred_value": tour_len + } + self.example_queue.put(example) diff --git a/alphatsp/solvers/graph_construction.py b/alphatsp/solvers/graph_construction.py index ce5fba6..82217b0 100644 --- a/alphatsp/solvers/graph_construction.py +++ b/alphatsp/solvers/graph_construction.py @@ -17,7 +17,7 @@ def construct_graph_grow(tsp, tour, remaining): x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1) - graph = Data(x=x, pos=points, edge_index=edges, y=choices) + graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices} return graph def construct_graph_prune(tsp, tour, remaining): @@ -40,7 +40,7 @@ def construct_graph_prune(tsp, tour, remaining): x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1) - graph = Data(x=x, pos=points, edge_index=edges, y=choices) + graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices} return graph def construct_graph_prune_weighted(tsp, tour, remaining): @@ -71,5 +71,5 @@ def construct_graph_prune_weighted(tsp, tour, remaining): x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1) - graph = Data(x=x, pos=points, edge_index=edges, edge_attr=edge_lengths, y=choices) - return graph \ No newline at end of file + graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices} + return graph diff --git a/alphatsp/solvers/policy_networks.py b/alphatsp/solvers/policy_networks.py index 969bfb4..5d0dcd9 100644 --- a/alphatsp/solvers/policy_networks.py +++ b/alphatsp/solvers/policy_networks.py @@ -6,6 +6,12 @@ from torch_geometric.nn import GCNConv, global_mean_pool, ARMAConv, XConv, SAGEConv from torch_geometric.data import Data, DataLoader +from alphatsp.logger import Logger + +if torch.cuda.is_available(): device = torch.device("cuda:0") +else: device = torch.device("cpu") + + class GCNPolicyNetwork(nn.Module): def __init__(self, d=3): super(GCNPolicyNetwork, self).__init__() @@ -26,7 +32,7 @@ def forward(self, graph): choice = torch.masked_select(c.squeeze(), choices) choice = F.softmax(choice, dim=0) - v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long)) + v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device)) value = self.fc(v) return choice, value @@ -59,12 +65,12 @@ def __init__(self, d=3): num_layers=2, shared_weights=True, dropout=0.1, - act=None) + act=None).to(device) - self.fc = nn.Linear(16, 1) + self.fc = nn.Linear(16, 1).to(device) def forward(self, graph): - x, edges, choices = graph.x, graph.edge_index, graph.y + x, edges, choices = graph['x'], graph['edge_index'], graph['y'] x = self.conv1(x, edges) x = F.relu(x) @@ -75,7 +81,7 @@ def forward(self, graph): choice = torch.masked_select(c.squeeze(), choices) choice = F.softmax(choice, dim=0) - v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long)) + v = global_mean_pool(x, torch.zeros(x.size(0), dtype=torch.long, device=x.device)) value = self.fc(v) return choice, value @@ -100,7 +106,7 @@ def forward(self, graph): choice = torch.masked_select(c.squeeze(), choices) choice = F.softmax(choice, dim=0) - v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long)) + v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device)) value = self.fc(v) return choice, value @@ -125,7 +131,7 @@ def forward(self, graph): choice = torch.masked_select(c.squeeze(), choices) choice = F.softmax(choice, dim=0) - v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long)) + v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device)) value = self.fc(v) return choice, value @@ -150,7 +156,7 @@ def forward(self, graph): choice = torch.masked_select(c.squeeze(), choices) choice = F.softmax(choice, dim=0) - v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long)) + v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device)) value = self.fc(v) return choice, value @@ -174,6 +180,9 @@ def train_example(self): if example is None: return -1 graph, choice_probs, value = example["graph"], example["choice_probs"], example["pred_value"] + graph = Data(**graph) + graph = graph.to(device) + pred_choices, pred_value = self.model(graph) loss = self.loss_fn(pred_choices, choice_probs) + (0.2 * self.loss_fn(pred_value, value)) @@ -194,4 +203,54 @@ def train_all(self): return 0 def save_model(self): - torch.save(self.model.state_dict(), f"saves/policynet_{self.n_examples_used:06d}.pth") \ No newline at end of file + torch.save(self.model.state_dict(), f"saves/policynet_{self.n_examples_used:06d}.pth") + +class SupervisedPolicyNetworkTrainer: + + def __init__(self, model, example_queue): + + self.model = model.to(device) + self.value_loss_fn = nn.MSELoss() + self.choice_loss_fn = nn.CrossEntropyLoss() + self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=1e-5) + + self.example_queue = example_queue + self.n_examples_used = 0 + + self.logger = Logger() + + def train_all(self): + while True: + if not self.example_queue.empty(): + return_code = self.train_example() + if self.n_examples_used%1000 == 0 and self.n_examples_used!=0: + self.logger.print(f"iter={self.n_examples_used}, avg_loss={sum(self.logger.losses[-100:])/100:.4f}") + if self.n_examples_used%10000 == 0 and self.n_examples_used!=0: + self.logger.save_model(self.model, self.n_examples_used) + if return_code == -1: + self.logger.save() + self.logger.save_model(self.model, "final") + return + + def train_example(self): + self.model.train() + + example = self.example_queue.get() + if example is None: return -1 + graph, choice, value = example["graph"], example["choice"], example["pred_value"] + + graph = Data(**graph) + graph = graph.to(device) + + pred_choices, pred_value = self.model(graph) + choice, value = torch.tensor([choice], device=device), torch.tensor([value], device=device) + pred_choices, pred_value = pred_choices.unsqueeze(0).to(device), pred_value.squeeze(0).to(device) + loss = self.choice_loss_fn(pred_choices, choice) + 0.2 * self.value_loss_fn(pred_value, value) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + self.logger.log_loss(loss.item()) + self.n_examples_used += 1 + return 0 diff --git a/c++/MCTSNode.cpp b/c++/MCTSNode.cpp deleted file mode 100644 index 239dd3c..0000000 --- a/c++/MCTSNode.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "MCTSNode.h" -#include -#include -#include -#include - -float tour_len(std::vector tour, std::vector> points); - -MCTSNode::MCTSNode(int n) { - this->parent = nullptr; - this->visits = 0; - this->total_score = 0.0; - this->avg_score = 0.0; - this->n = n; - this->tour = std::vector(); - this->tour.push_back(0); - this->remaining = std::set(); - for (int i = 1; i < n; i++) - this->remaining.insert(i); -} - -MCTSNode::MCTSNode(MCTSNode* p, std::vector tour, std::set remaining, int n) { - this->parent = p; - this->visits = 0; - this->total_score = 0.0; - this->avg_score = 0.0; - this->n = n; - this->tour = tour; - this->remaining = remaining; -} - -std::random_device MCTSNode::rd = std::random_device(); -std::mt19937 MCTSNode::g = std::mt19937(MCTSNode::rd()); - -bool MCTSNode::has_children() { - return this->children.size() > 0; -} - -bool MCTSNode::is_leaf() { - return this->tour.size() == this->n; -} - -bool MCTSNode::is_expanded() { - return this->children.size() == this->remaining.size(); -} - -std::vector MCTSNode::get_tour() { - std::vector t(this->tour); - t.push_back(t[0]); - return t; -} - -std::shared_ptr MCTSNode::best_child_score() { - float best_score = -1; - std::shared_ptr best_node(nullptr); - for (std::shared_ptr n : this->children) { - if (n->avg_score > best_score) { - best_score = n->avg_score; - best_node = n; - } - } - return best_node; -} - -std::shared_ptr MCTSNode::best_child_visits() { - float best_score = -1; - std::shared_ptr best_node(nullptr); - for (std::shared_ptr n : this->children) { - if (n->visits > best_score) { - best_score = n->visits; - best_node = n; - } - } - return best_node; -} - -std::shared_ptr MCTSNode::best_child_uct() { - float best_score = -1; - std::shared_ptr best_node(nullptr); - for (std::shared_ptr n : this->children) { - float score = n->avg_score + std::sqrt(2 * std::log(this->visits) / n->visits); - if (score > best_score) { - best_score = score; - best_node = n; - } - } - return best_node; -} - -std::shared_ptr MCTSNode::expand() { - std::uniform_int_distribution<> dis(0, this->remaining.size()-1); - auto it(this->remaining.begin()); - advance(it, dis(g)); - int k = *it; - - std::vector next_tour(this->tour); - next_tour.push_back(k); - - std::set next_remaining(this->remaining); - next_remaining.erase(k); - - std::shared_ptr m = std::make_shared(this, next_tour, next_remaining, this->n); - this->children.push_back(m); - - return m; -} - -void MCTSNode::backprop(float reward) { - this->visits += 1; - this->total_score += reward; - this->avg_score = this->total_score / (float)(this->visits); - if (this->parent != nullptr) { - this->parent->backprop(reward); - } -} - -float MCTSNode::simulate(std::vector> points) { - - // 1. randomly permute remaining nodes - std::vector r(this->remaining.begin(), this->remaining.end()); - std::shuffle(r.begin(), r.end(), this->g); - - // 2. merge current tour with permuted remaining nodes - std::vector sim_tour(this->tour); - sim_tour.insert(sim_tour.end(), r.begin(), r.end()); - sim_tour.push_back(sim_tour[0]); - - // 3. compute the length of the new tour and return - float len = tour_len(sim_tour, points); - return len; - -} - -float tour_len(std::vector tour, std::vector> points) { - float len = 0; - int d = points[0].size(); - int n = points.size(); - for (int i = 1; i < n+1; i++) { - float edge_len = 0; - for (int j = 0; j < d; j++) { - float diff = points[tour[i]][j] - points[tour[i-1]][j]; - diff = diff * diff; - edge_len += diff; - } - edge_len = std::sqrt(edge_len); - len += edge_len; - } - return len; -} diff --git a/c++/MCTSNode.h b/c++/MCTSNode.h deleted file mode 100644 index e655488..0000000 --- a/c++/MCTSNode.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef MCTSNODEH -#define MCTSNODEH - -#include -#include -#include - -class MCTSNode { -private: - MCTSNode* parent; - std::vector tour; - std::set remaining; - int visits; - float total_score; - float avg_score; - int n; - static std::random_device rd; - static std::mt19937 g; -public: - MCTSNode(int n); - MCTSNode(MCTSNode* p, std::vector tour, std::set remaining, int n); - std::shared_ptr expand(); - void backprop(float reward); - float simulate(std::vector> points); - std::vector get_tour(); - bool has_children(); - bool is_leaf(); - bool is_expanded(); - std::shared_ptr best_child_score(); - std::shared_ptr best_child_visits(); - std::shared_ptr best_child_uct(); - std::vector> children; -}; - -#endif \ No newline at end of file diff --git a/c++/Makefile b/c++/Makefile deleted file mode 100644 index 0c60cd6..0000000 --- a/c++/Makefile +++ /dev/null @@ -1,13 +0,0 @@ -all: mcts - -mcts: mcts.o mctsnode.o - g++ --std=c++17 -O3 -o mcts mcts.o mctsnode.o - -mcts.o: mcts.cpp - g++ --std=c++17 -O3 -c -o mcts.o mcts.cpp - -mctsnode.o: MCTSNode.cpp MCTSNode.h - g++ --std=c++17 -O3 -c -o mctsnode.o MCTSNode.cpp - -clean: - rm *.o mcts \ No newline at end of file diff --git a/c++/mcts.cpp b/c++/mcts.cpp deleted file mode 100644 index 939b214..0000000 --- a/c++/mcts.cpp +++ /dev/null @@ -1,217 +0,0 @@ -#include -#include -#include -#include -#include "MCTSNode.h" - -std::shared_ptr mcts(std::shared_ptr rootnode, std::vector> points, int iterations); -float compute_tour_length(std::vector tour, std::vector> points); -std::vector greedy(std::vector> points); -void random_tours(std::vector> points); - -int main() { - - // 1. Create TSP instance - - int n = 60; - int d = 2; - int iterations = 1000; - - std::vector> points; - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(0.0, 1.0); - - for (int i = 0; i < n; i++) { - std::vector p = {(float)dis(gen), (float)dis(gen)}; - points.push_back(p); - } - - // 2. Construct MCTS tree - - std::shared_ptr rootnode = std::make_shared(n); - std::shared_ptr node(rootnode); - - // 3. Run MCTS at each level of the tree - - while (!node->is_leaf()) { - node = mcts(node, points, iterations); - } - - // 4. Display result - - std::vector optimal_tour(node->get_tour()); - float optimal_tour_length = compute_tour_length(optimal_tour, points); - - for (int i = 0; i < optimal_tour.size(); i++) { - std::cout << optimal_tour[i]; - if (i != optimal_tour.size() - 1) { - std::cout << " -> "; - } - } - std::cout << std::endl; - std::cout << "Tour length: " << optimal_tour_length << std::endl; - - // 5. Run greedy - - std::vector greedy_tour = greedy(points); - float greedy_tour_length = compute_tour_length(greedy_tour, points); - - for (int i = 0; i < greedy_tour.size(); i++) { - std::cout << greedy_tour[i]; - if (i != greedy_tour.size() - 1) { - std::cout << " -> "; - } - } - std::cout << std::endl; - std::cout << "Greedy tour length: " << greedy_tour_length << std::endl; - - // 6. Random tours - random_tours(points); - - // 7. Return - return 0; -} - -std::shared_ptr mcts(std::shared_ptr rootnode, std::vector> points, int iterations) { - - int n = points.size(); - - // 1. Begin search - for (int it=0; it < iterations; it++) { - - std::shared_ptr node(rootnode); - - // 2. Descend - while (!node->is_leaf()) { - if (!node->is_expanded()) { - node = node->expand(); - break; - } else { - node = node->best_child_uct(); - } - } - - // 3. Simulate - float tour_len = node->simulate(points); - float reward = ((2.0 * n) - tour_len) / (2.0 * n); - - // 4. Backprop - node->backprop(reward); - - } - - // 5. Select and return best child node - return rootnode->best_child_score(); - -} - -float compute_tour_length(std::vector tour, std::vector> points) { - float len = 0; - int d = points[0].size(); - int n = points.size(); - for (int i = 1; i < n+1; i++) { - float edge_len = 0; - for (int j = 0; j < d; j++) { - float diff = points[tour[i]][j] - points[tour[i-1]][j]; - diff = diff * diff; - edge_len += diff; - } - edge_len = std::sqrt(edge_len); - len += edge_len; - } - return len; -} - -std::vector greedy(std::vector> points) { - - // 1. Get start node - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, points.size()-1); - int start = dis(gen); - - // 2. Start tour - std::vector tour = {start}; - - // 3. Compute remaining - std::set remaining; - for (int i = 0; i < points.size(); i++) { - if (i != start) { - remaining.insert(i); - } - } - - // 4. Build tour - std::vector pt1 = points[start]; - while (!remaining.empty()) { - - // 4.1 Compute min distance - int next_node = -1; - float min_dist = std::numeric_limits::max(); - for (int ind2 : remaining) { - - std::vector pt2 = points[ind2]; - - float edge_len = 0; - for (int j = 0; j < pt1.size(); j++) { - float diff = pt1[j] - pt2[j]; - diff = diff * diff; - edge_len += diff; - } - edge_len = std::sqrt(edge_len); - - if (edge_len < min_dist) { - min_dist = edge_len; - next_node = ind2; - } - - } - - // 4.2 Add to tour, remove from remaining - pt1 = points[next_node]; - tour.push_back(next_node); - remaining.erase(next_node); - - } - - // 5. Complete tour - tour.push_back(tour[0]); - - // 6. Return tour - return tour; - -} - -void random_tours(std::vector> points) { - - int iterations = 100000; - - std::random_device rd; - std::mt19937 gen(rd()); - - std::vector tour; - for (int i = 0; i < points.size(); i++) { - tour.push_back(i); - } - - float total_len = 0; - float best_len = std::numeric_limits::max(); - - for (int i = 0; i < iterations; i++) { - std::vector t(tour); - std::shuffle(t.begin(), t.end(), gen); - t.push_back(t[0]); - float l = compute_tour_length(t, points); - total_len += l; - if (l < best_len) - best_len = l; - } - - float avg_len = total_len / (float)iterations; - - std::cout << "Random avg length: " << avg_len << std::endl; - std::cout << "Random best length: " << best_len << std::endl; - -} diff --git a/main.py b/main.py index bdc09dd..4ab9b18 100644 --- a/main.py +++ b/main.py @@ -1,40 +1,17 @@ import argparse -import multiprocessing as mp from args import Args -from alphatsp.experiments import ( - nearestneighbor, - mcts, - exact, - gurobi, - insertion, - policy, - parallel, - selfplay -) +import importlib + +import torch.multiprocessing as mp +mp.set_sharing_strategy("file_system") def main(args): - a = Args() - if args.experiment == "nearestneighbor": - nearestneighbor.run(a) - elif args.experiment == "mcts": - mcts.run(a) - elif args.experiment == "exact": - exact.run(a) - elif args.experiment == "gurobi": - gurobi.run(a) - elif args.experiment == "insertion": - insertion.run(a) - elif args.experiment == "policy": - policy.run(a) - elif args.experiment == "parallel": - parallel.run(a) - elif args.experiment == "selfplay": - selfplay.run(a) - else: - raise ValueError("Invalid experiment selection.") + config = Args() + experiment = importlib.import_module(f"alphatsp.experiments.{args.experiment}") + experiment.run(config) if __name__ == "__main__": - mp.set_start_method('spawn', force=True) + mp.set_start_method("spawn", force=True) parser = argparse.ArgumentParser() parser.add_argument("--experiment", type=str, required=True, help="experiment name") args = parser.parse_args() diff --git a/requirements.txt b/requirements.txt index 5a68f7e..983afdd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,10 @@ torch numpy matplotlib torch_geometric +torch_sparse +torch_scatter +torch_cluster scipy cython tqdm -pyconcorde \ No newline at end of file +pyconcorde