1 change: 1 addition & 0 deletions .gitignore
@@ -245,3 +245,4 @@ c++/mcts
backup/
old/
saves/
.idea/
67 changes: 67 additions & 0 deletions alphatsp/experiments/supervised.py
@@ -0,0 +1,67 @@
import alphatsp.tsp
import alphatsp.util

import alphatsp.solvers.policy_solvers
from alphatsp.solvers.example_generators import NNExampleGenerator
from alphatsp.solvers.policy_networks import SupervisedPolicyNetworkTrainer

import torch
import numpy as np

import copy

from torch.multiprocessing import Process, Manager


def run(args):

# setup
N, D = args.N, args.D
n_examples = args.n_train_examples
n_threads = args.n_threads
n_examples_per_thread = n_examples//n_threads

# create policy network
policy_network = alphatsp.util.get_policy_network(args.policy_network)

# generate examples
print("Generating examples and training...")

manager = Manager()
train_queue = manager.Queue()
shared_dict = manager.dict()

shared_dict["success"] = False

producers = []
for _ in range(n_threads):
producers.append(Process(target=generate_examples, args=(n_examples_per_thread, train_queue, args)))

for p in producers:
p.start()

c = Process(target=train, args=(policy_network, train_queue, shared_dict, args))
c.start()

for p in producers:
p.join()
train_queue.put(None)

c.join()

status = shared_dict["success"]
if not status:
print("Experiment failed.")
return -1

def generate_examples(n_examples, train_queue, args):
generator = NNExampleGenerator(train_queue, args)
generator.generate_examples(n_examples)
return

def train(policy_network, train_queue, shared_dict, args):
trainer = SupervisedPolicyNetworkTrainer(policy_network, train_queue)
trainer.train_all()
shared_dict["model"] = copy.deepcopy(trainer.model.cpu())
shared_dict["success"] = True
return
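A minimal driver sketch for this experiment (illustrative, not part of the PR): run() reads N, D, n_train_examples, n_threads, and policy_network from args, and NNExampleGenerator additionally reads graph_construction; the concrete values and the "gcn"/"grow" keys below are assumptions.

from types import SimpleNamespace
from alphatsp.experiments import supervised

args = SimpleNamespace(
	N=20,                       # TSP nodes per instance
	D=2,                        # coordinate dimension
	n_train_examples=10000,     # total examples across all producer processes
	n_threads=4,                # number of producer processes
	policy_network="gcn",       # key for alphatsp.util.get_policy_network (assumed)
	graph_construction="grow",  # key for alphatsp.util.get_graph_constructor (assumed)
)

if __name__ == "__main__":
	supervised.run(args)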
80 changes: 80 additions & 0 deletions alphatsp/logger.py
@@ -0,0 +1,80 @@
import os
import datetime
import csv
import shutil
import torch
import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
plt.style.use("seaborn")

class Logger:

	def __init__(self, args=None, enabled=True):
self.logging = enabled
if not self.logging:
return

self.dt = datetime.datetime.now().strftime("%m%d_%H%M")
self.path = f"./saves/{self.dt}"

if not os.path.exists(self.path):
os.makedirs(self.path)

self.losses = []
self.eval = []

		self.main_log_fn = os.path.join(self.path, "log.txt")
		if os.path.exists("args.py"):
			shutil.copy2("args.py", self.path)

def save_model(self, model, iterations):
if not self.logging: return
if isinstance(iterations, int):
fn = os.path.join(self.path, f"policynet_{epoch:07d}.pth")
else:
fn = os.path.join(self.path, f"policynet_{epoch}.pth")
torch.save(model.state_dict(), fn)
self.print(f"Saved model to: {fn}\n")

def print(self, *x):
print(*x)
self.log(*x)

def log(self, *x):
if not self.logging: return
with open(self.main_log_fn, "a") as f:
print(*x, file=f, flush=True)

def log_loss(self, l):
self.losses.append(l)

def log_eval(self, data):
if not self.logging: return
self.eval.append(data)

def save(self):
if not self.logging: return

with open(os.path.join(self.path, "loss.csv"), "w") as f:
csvwriter = csv.DictWriter(f, ["it", "loss"])
csvwriter.writeheader()
for it, loss in enumerate(self.losses):
row = {"it": it, "loss": loss}
csvwriter.writerow(row)

with open(os.path.join(self.path, "eval.csv"), "w") as f:
cols = ["it"] + sorted(list(set(self.eval[0].keys()) - set(["it"])))
csvwriter = csv.DictWriter(f, cols)
csvwriter.writeheader()
for row in self.eval_scores:
csvwriter.writerow(row)

plt.clf()

plt.plot(self.losses)
plt.xlabel("iterations")
plt.ylabel("training loss")
plt.savefig(os.path.join(self.path, "losses.png"))
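A quick usage sketch for this Logger (illustrative): constructing it creates ./saves/<timestamp>/ and copies args.py there when one exists in the working directory.

logger = Logger(enabled=True)
logger.print("starting run")                      # echoed to stdout and appended to log.txt
for it in range(100):
	logger.log_loss(1.0 / (it + 1))               # collected for loss.csv and losses.png
	logger.log_eval({"it": it, "tour_len": 4.2})  # one eval.csv row; dict keys become columns
logger.save()                                     # writes loss.csv, eval.csv, and losses.png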
39 changes: 39 additions & 0 deletions alphatsp/solvers/example_generators.py
@@ -1,6 +1,10 @@
import torch
import copy
import random
from alphatsp.tsp import TSP
from alphatsp.solvers.mcts import MCTSNode, MCTSTree
from alphatsp.solvers import heuristics
from alphatsp.util import get_graph_constructor

class MCTSExampleGenerator:

@@ -82,3 +86,38 @@ def solve(self):
mcts_payoff = self.tsp.tour_length(mcts_tour)

return mcts_tour, mcts_payoff

class NNExampleGenerator:

def __init__(self, example_queue, args):
self.args = args
self.graph_constructor = get_graph_constructor(args.graph_construction)
self.example_queue = example_queue
self.n_samples = max(args.N//10, 1)

def generate_examples(self, n_examples):

for _ in range(n_examples//self.n_samples):

# generate tsp
tsp = TSP(self.args.N, self.args.D)

# solve
tour, tour_len = heuristics.nearest_greedy(tsp)

# generate examples
remaining = set(range(self.args.N))
for i in sorted(random.sample(range(self.args.N-1), self.n_samples)):

				partial_tour = tour[:i+1]  # include node i so the label below, tour[i+1], is the true next step
remaining = remaining - set(partial_tour)
r = sorted(list(remaining))

graph = self.graph_constructor(tsp, partial_tour, r)

example = {
"graph": graph,
"choice": r.index(tour[i+1]),
"pred_value": tour_len
}
self.example_queue.put(example)
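NNExampleGenerator is the producer half of a queue pipeline: each example is a plain dict of picklable values so it can cross process boundaries. A sketch of the consuming side (SupervisedPolicyNetworkTrainer below plays this role; the consume helper here is hypothetical):

from torch_geometric.data import Data

def consume(example_queue):
	while True:
		example = example_queue.get()
		if example is None:  # sentinel enqueued once all producers have finished
			break
		graph = Data(**example["graph"])  # rebuild a PyG Data object from the plain dict
		choice = example["choice"]        # index of the true next node among the sorted remaining nodes
		value = example["pred_value"]     # greedy tour length, used as the value-head target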
8 changes: 4 additions & 4 deletions alphatsp/solvers/graph_construction.py
@@ -17,7 +17,7 @@ def construct_graph_grow(tsp, tour, remaining):

x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1)

graph = Data(x=x, pos=points, edge_index=edges, y=choices)
graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices}
return graph

def construct_graph_prune(tsp, tour, remaining):
@@ -40,7 +40,7 @@ def construct_graph_prune(tsp, tour, remaining):

x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1)

graph = Data(x=x, pos=points, edge_index=edges, y=choices)
graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices}
return graph

def construct_graph_prune_weighted(tsp, tour, remaining):
@@ -71,5 +71,5 @@ def construct_graph_prune_weighted(tsp, tour, remaining):

x = torch.cat([points, choices.unsqueeze(-1).to(dtype=torch.float)], dim=-1)

graph = Data(x=x, pos=points, edge_index=edges, edge_attr=edge_lengths, y=choices)
return graph
graph = {"x": x, "pos": points, "edge_index": edges, "edge_attr": edge_lengths, "y": choices}
return graph
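Returning a plain dict instead of a Data object is presumably what lets these graphs travel through the torch.multiprocessing queues used elsewhere in this PR, since dicts of tensors pickle reliably; the trainers rebuild the Data object on arrival. A round-trip sketch (tsp, tour, remaining, and device are assumed in scope):

from torch_geometric.data import Data

graph_dict = construct_graph_grow(tsp, tour, remaining)  # plain dict of tensors, safe to put on a Queue
data = Data(**graph_dict).to(device)                     # reconstructed on the consumer side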
77 changes: 68 additions & 9 deletions alphatsp/solvers/policy_networks.py
@@ -6,6 +6,12 @@
from torch_geometric.nn import GCNConv, global_mean_pool, ARMAConv, XConv, SAGEConv
from torch_geometric.data import Data, DataLoader

from alphatsp.logger import Logger

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class GCNPolicyNetwork(nn.Module):
def __init__(self, d=3):
super(GCNPolicyNetwork, self).__init__()
@@ -26,7 +32,7 @@ def forward(self, graph):
choice = torch.masked_select(c.squeeze(), choices)
choice = F.softmax(choice, dim=0)

v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long))
v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device))
value = self.fc(v)

return choice, value
@@ -59,12 +65,12 @@ def __init__(self, d=3):
num_layers=2,
shared_weights=True,
dropout=0.1,
act=None)
act=None).to(device)

self.fc = nn.Linear(16, 1)
self.fc = nn.Linear(16, 1).to(device)

def forward(self, graph):
x, edges, choices = graph.x, graph.edge_index, graph.y
x, edges, choices = graph['x'], graph['edge_index'], graph['y']

x = self.conv1(x, edges)
x = F.relu(x)
@@ -75,7 +81,7 @@ def forward(self, graph):
choice = torch.masked_select(c.squeeze(), choices)
choice = F.softmax(choice, dim=0)

v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long))
v = global_mean_pool(x, torch.zeros(x.size(0), dtype=torch.long, device=x.device))
value = self.fc(v)

return choice, value
Expand All @@ -100,7 +106,7 @@ def forward(self, graph):
choice = torch.masked_select(c.squeeze(), choices)
choice = F.softmax(choice, dim=0)

v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long))
v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device))
value = self.fc(v)

return choice, value
@@ -125,7 +131,7 @@ def forward(self, graph):
choice = torch.masked_select(c.squeeze(), choices)
choice = F.softmax(choice, dim=0)

v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long))
v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device))
value = self.fc(v)

return choice, value
@@ -150,7 +156,7 @@ def forward(self, graph):
choice = torch.masked_select(c.squeeze(), choices)
choice = F.softmax(choice, dim=0)

v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long))
v = global_mean_pool(x, torch.zeros(graph.num_nodes, dtype=torch.long, device=x.device))
value = self.fc(v)

return choice, value
@@ -174,6 +180,9 @@ def train_example(self):
if example is None: return -1
graph, choice_probs, value = example["graph"], example["choice_probs"], example["pred_value"]

graph = Data(**graph)
graph = graph.to(device)

pred_choices, pred_value = self.model(graph)
loss = self.loss_fn(pred_choices, choice_probs) + (0.2 * self.loss_fn(pred_value, value))

@@ -194,4 +203,54 @@ def train_all(self):
return 0

def save_model(self):
torch.save(self.model.state_dict(), f"saves/policynet_{self.n_examples_used:06d}.pth")
torch.save(self.model.state_dict(), f"saves/policynet_{self.n_examples_used:06d}.pth")

class SupervisedPolicyNetworkTrainer:

def __init__(self, model, example_queue):

self.model = model.to(device)
self.value_loss_fn = nn.MSELoss()
self.choice_loss_fn = nn.CrossEntropyLoss()
self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=1e-5)

self.example_queue = example_queue
self.n_examples_used = 0

self.logger = Logger()

	def train_all(self):
		while True:
			return_code = self.train_example()  # blocks on the queue until an example or the None sentinel arrives
			if return_code == -1:
				self.logger.save()
				self.logger.save_model(self.model, "final")
				return
			if self.n_examples_used % 1000 == 0 and self.n_examples_used != 0:
				self.logger.print(f"iter={self.n_examples_used}, avg_loss={sum(self.logger.losses[-100:])/100:.4f}")
			if self.n_examples_used % 10000 == 0 and self.n_examples_used != 0:
				self.logger.save_model(self.model, self.n_examples_used)

def train_example(self):
self.model.train()

example = self.example_queue.get()
if example is None: return -1
graph, choice, value = example["graph"], example["choice"], example["pred_value"]

graph = Data(**graph)
graph = graph.to(device)

pred_choices, pred_value = self.model(graph)
		choice = torch.tensor([choice], dtype=torch.long, device=device)
		value = torch.tensor([value], dtype=torch.float, device=device)  # cast to float32 to match the model's output dtype
		pred_choices, pred_value = pred_choices.unsqueeze(0), pred_value.squeeze(0)  # [1, K] row and [1] value
		loss = self.choice_loss_fn(pred_choices, choice) + 0.2 * self.value_loss_fn(pred_value, value)

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

self.logger.log_loss(loss.item())
self.n_examples_used += 1
return 0
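The shape bookkeeping in train_example is easy to get wrong, so here is the same loss computation on dummy tensors (illustrative only; K stands for the number of remaining nodes). One caveat worth flagging: nn.CrossEntropyLoss expects unnormalized logits, while the networks above already apply a softmax, so this pipeline effectively applies log-softmax twice.

import torch
import torch.nn as nn

K = 5                                        # remaining (unvisited) nodes for one example
pred_choices = torch.rand(K).softmax(dim=0)  # policy head output: a distribution of shape [K]
pred_value = torch.rand(1, 1)                # value head output: shape [1, 1]

choice = torch.tensor([2], dtype=torch.long)  # true next-node index, shape [1]
value = torch.tensor([3.7])                   # target tour length, shape [1]

choice_loss = nn.CrossEntropyLoss()(pred_choices.unsqueeze(0), choice)  # [1, K] row vs [1] target
value_loss = nn.MSELoss()(pred_value.squeeze(0), value)                 # [1] vs [1]
loss = choice_loss + 0.2 * value_loss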