4 changes: 3 additions & 1 deletion .gitignore
@@ -5,4 +5,6 @@ __pycache__
*.tfevents.*
logs
*.png
*.prof
*.prof
*.sh
.pytest_cache
1 change: 1 addition & 0 deletions Architectures.py
@@ -69,6 +69,7 @@ def preprocess_obs(obs, device):
if isinstance(obs, np.ndarray):
obs = torch.from_numpy(obs)

# Pixel observations
if obs.dtype == torch.uint8:
if len(obs.shape) == 3:
obs = obs.unsqueeze(0).to(device=device)
50 changes: 45 additions & 5 deletions BaseAgent.py
@@ -1,4 +1,5 @@
import os
import threading
import time
import numpy as np
import torch
@@ -40,6 +41,7 @@ def __init__(self,
save_checkpoints: bool = False,
seed: Optional[int] = None,
eval_callbacks: List[callable] = [],
use_threaded_eval: bool = True,
) -> None:

self.LOG_PARAMS = {
@@ -77,6 +79,7 @@ def __init__(self,
self.batch_size = batch_size

self.loggers = loggers
self.use_threaded_eval = use_threaded_eval

self.gradient_steps = gradient_steps
self.device = auto_device(device)
@@ -153,6 +156,16 @@ def learn(self, total_timesteps: int):
"""
Train the agent for total_timesteps
"""
if self.use_threaded_eval:
stop_event = threading.Event()
def evaluation_worker():
while not stop_event.is_set():
self.evaluate(n_episodes=10)
# stop_event.wait(10)

worker = threading.Thread(target=evaluation_worker)
worker.start()

# Start a timer to log fps:
init_train_time = time.thread_time_ns()
self.learn_env_steps = 0
@@ -161,6 +174,17 @@ def learn(self, total_timesteps: int):
with tqdm.tqdm(total=total_timesteps, desc="Training") as pbar:

while self.learn_env_steps < total_timesteps:
# Check if worker is alive:
if self.use_threaded_eval:
# stop_event = threading.Event()

def evaluation_worker():
# while not stop_event.is_set():
self.evaluate(n_episodes=10)

worker = threading.Thread(target=evaluation_worker)
worker.start()

state, _ = self.env.reset()

done = False
@@ -191,14 +215,32 @@ def learn(self, total_timesteps: int):
train_time = (time.thread_time_ns() - init_train_time) / 1e9
train_fps = self.log_interval / train_time
self.log_history('time/train_fps', train_fps, self.learn_env_steps)
self.avg_eval_rwd = self.evaluate()
if self.use_threaded_eval:
Collaborator review comment: looks like we are waiting for the old worker to finish, but not starting a new worker thread, when self.use_threaded_eval is true. (See the event-based worker sketch after this file's diff.)

# Restart the worker:
# stop_event.set()
worker.join() # Wait for the worker to finish
# stop_event.clear() # Clear the stop event before restarting the worker
# self.log_history('eval/avg_episode_length', n_steps / n_episodes, self.learn_env_steps)
# self.log_history('eval/time', eval_time, self.learn_env_steps)
# self.log_history('eval/fps', eval_fps, self.learn_env_steps)
worker = threading.Thread(target=evaluation_worker)

worker.start()
else:
self.avg_eval_rwd = self.evaluate(n_episodes=10)

self.log_history('eval/avg_reward', self.avg_eval_rwd, self.learn_env_steps)

init_train_time = time.thread_time_ns()
pbar.update(self.log_interval)
if self.use_threaded_eval:
stop_event.set()

if done:
self.log_history("rollout/ep_reward", self.rollout_reward, self.learn_env_steps)
self.log_history("rollout/avg_episode_length", avg_ep_len, self.learn_env_steps)


def _on_step(self) -> None:
"""
This method is called after every step in the environment
@@ -237,10 +279,8 @@ def evaluate(self, n_episodes=10) -> float:
avg_reward /= n_episodes
eval_fps = n_steps / eval_time
self.eval_time = eval_time
self.log_history('eval/avg_reward', avg_reward, self.learn_env_steps)
self.log_history('eval/avg_episode_length', n_steps / n_episodes, self.learn_env_steps)
self.log_history('eval/time', eval_time, self.learn_env_steps)
self.log_history('eval/fps', eval_fps, self.learn_env_steps)
self.avg_eval_rwd = avg_reward

for callback in self.eval_callbacks:
callback(self, end=True)
return avg_reward
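The collaborator comment above concerns the evaluation-thread lifecycle in learn(): the old worker is joined and a replacement has to be started explicitly. For reference, here is a minimal sketch of an alternative shape for the same idea: one long-lived worker paced by a threading.Event, started once and stopped once. The ThreadedEvaluator wrapper, its pause_s pacing, and DummyAgent are illustrative assumptions, not part of this PR (which inlines the worker inside BaseAgent.learn).

import threading
import time


class ThreadedEvaluator:
    """Sketch only: runs agent.evaluate() on a background thread until stop()."""

    def __init__(self, agent, n_episodes=10, pause_s=10.0):
        self.agent = agent
        self.n_episodes = n_episodes
        self.pause_s = pause_s
        self._stop = threading.Event()
        self._worker = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop.is_set():
            self.agent.evaluate(n_episodes=self.n_episodes)
            # Pause between evaluations, but wake immediately when stop() is called.
            self._stop.wait(self.pause_s)

    def start(self):
        self._worker.start()

    def stop(self):
        self._stop.set()
        self._worker.join()


if __name__ == "__main__":
    class DummyAgent:
        def evaluate(self, n_episodes=10):
            time.sleep(0.1)  # stand-in for a real evaluation rollout
            return 0.0

    evaluator = ThreadedEvaluator(DummyAgent(), pause_s=0.5)
    evaluator.start()
    time.sleep(2.0)  # stand-in for the training loop
    evaluator.stop()

Because of the GIL, the gain from backgrounding evaluation comes mainly from time spent in environment code or GPU kernels that release the GIL. The worker communicates by writing onto the agent (this PR uses self.avg_eval_rwd for that), so anything more elaborate than a single scalar would likely want a lock.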
32 changes: 18 additions & 14 deletions DQN.py
@@ -6,6 +6,7 @@
from BaseAgent import BaseAgent, get_new_params
from utils import polyak


class DQN(BaseAgent):
def __init__(self,
*args,
@@ -23,7 +24,7 @@ def __init__(self,
super().__init__(*args, **kwargs)
self.kwargs = get_new_params(self, locals())

self.algo_name = 'SQL'
self.algo_name = 'DQN'
self.gamma = gamma
self.minimum_epsilon = minimum_epsilon
self.exploration_fraction = exploration_fraction
@@ -66,15 +67,17 @@ def _on_step(self) -> None:
super()._on_step()

# Update epsilon:
self.epsilon = max(self.minimum_epsilon, (self.initial_epsilon - self.learn_env_steps / self.total_timesteps / self.exploration_fraction))
self.epsilon = max(self.minimum_epsilon,
(self.initial_epsilon -
self.learn_env_steps / self.total_timesteps / self.exploration_fraction))

if self.learn_env_steps % self.log_interval == 0:
self.log_history("train/epsilon", self.epsilon, self.learn_env_steps)

# Periodically update the target network:
if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0:
# Use Polyak averaging as specified:
polyak(self.online_qs, self.target_qs, self.polyak_tau)
polyak(self.online_qs, self.target_qs, self.polyak_tau, self.device)


def exploration_policy(self, state: np.ndarray) -> int:
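For reference, the reformatted epsilon update above is a linear decay over the first exploration_fraction of training, clamped at minimum_epsilon afterwards. A small worked sketch; the concrete initial_epsilon and minimum_epsilon values are assumptions for illustration (they are set elsewhere in the repo, not in this diff), and exploration_fraction=0.1 matches compare_eval_threaded.py below.

initial_epsilon = 1.0        # assumed starting value
minimum_epsilon = 0.05       # assumed floor
exploration_fraction = 0.1   # fraction of training spent decaying
total_timesteps = 100_000


def epsilon_at(step: int) -> float:
    return max(minimum_epsilon,
               initial_epsilon - step / total_timesteps / exploration_fraction)


print(epsilon_at(0))       # 1.0
print(epsilon_at(5_000))   # 0.5  (halfway through the exploration window)
print(epsilon_at(9_500))   # 0.05 (reaches the floor)
print(epsilon_at(50_000))  # 0.05 (clamped for the rest of training)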
@@ -97,6 +100,7 @@ def calculate_loss(self, batch):
dones = dones.float()
curr_q = self.online_qs(states).squeeze().gather(1, actions.long())
with torch.no_grad():
# TODO: push this into pre-processing to clean up the agent code:
if isinstance(self.env.observation_space, gymnasium.spaces.Discrete):
states = states.squeeze()
next_states = next_states.squeeze()
@@ -114,7 +118,7 @@ def calculate_loss(self, batch):

self.log_history("train/online_q_mean", curr_q.mean().item(), self.learn_env_steps)
# log the loss:
logger.log_history("train/loss", loss.item(), self.learn_env_steps)
self.log_history("train/loss", loss.item(), self.learn_env_steps)

return loss

@@ -125,25 +129,25 @@ def calculate_loss(self, batch):

from Logger import WandBLogger, TensorboardLogger
logger = TensorboardLogger('logs/atari')
#logger = WandBLogger(entity='jacobhadamczyk', project='test')
# mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32])#, activation=torch.nn.Mish)
# cnn = make_atari_nature_cnn(gym.make(env).action_space.n)

env = 'CartPole-v1'
agent = DQN(env,
architecture=make_mlp,
architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0],
'output_dim': gym.make(env).action_space.n,
'hidden_dims': [64, 64]},
# architecture=make_atari_nature_cnn,
# architecture_kwargs={'output_dim': gym.make(env).action_space.n},
loggers=(logger,),
learning_rate=0.001,
train_interval=1,
gradient_steps=1,
batch_size=64,
use_target_network=True,
target_update_interval=10,
polyak_tau=1.0,
learning_starts=1000,
log_interval=500,

use_target_network=False,
# target_update_interval=10,
# polyak_tau=0.0,
learning_starts=100,
log_interval=200,
use_threaded_eval=False,
)
agent.learn(total_timesteps=60_000)
agent.learn(total_timesteps=100_000)
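Both DQN and SoftQAgent now pass self.device as a fourth argument to utils.polyak, whose body is not part of this diff. As a point of reference, here is a minimal sketch of Polyak (soft) target-network averaging with that four-argument signature; the update convention (tau=1.0 reduces to a hard copy, matching the polyak_tau=1.0 used elsewhere in this PR) and the device handling are assumptions, not the repo's actual implementation.

import torch


@torch.no_grad()
def polyak(online: torch.nn.Module, target: torch.nn.Module, tau: float, device=None) -> None:
    # Soft update: target <- tau * online + (1 - tau) * target.
    # With tau = 1.0 this is a hard copy of the online weights.
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        update = tau * online_p
        if device is not None:
            update = update.to(device)
        target_p.mul_(1.0 - tau).add_(update)

Under this convention, polyak_tau=1.0 with target_update_interval=10 (as in the SoftQAgent example below) amounts to copying the online Q-network into the target network every 10 environment steps.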
18 changes: 16 additions & 2 deletions Logger.py
@@ -75,10 +75,24 @@ def __init__(self, log_dir):
self.writer = SummaryWriter(log_dir)
def log_hparams(self, hparam_dict):
for param, value in hparam_dict.items():
self.writer.add_text(param, str(value), global_step=0)
if isinstance(value, str):
self.writer.add_text(param, value, global_step=0)
else:
self.writer.add_text(param, str(value), global_step=0)
def log_history(self, param, value, step):
self.writer.add_scalar(param, value, global_step=step)
def log_video(self, video_path, name="video"):
self.writer.add_video(name, video_path)
def log_image(self, image_path, name="image"):
self.writer.add_image(name, image_path)
self.writer.add_image(name, image_path)


class NullLogger(BaseLogger):
def log_hparams(self, hparam_dict):
pass
def log_history(self, param, value, step):
pass
def log_video(self, video_path):
pass
def log_image(self, image_path):
pass
12 changes: 7 additions & 5 deletions SoftQAgent.py
@@ -4,7 +4,8 @@
import torch

from Architectures import make_mlp
from BaseAgent import BaseAgent, get_new_params, AUCCallback
from BaseAgent import BaseAgent, get_new_params
from callbacks import AUCCallback
from utils import polyak
from Logger import WandBLogger, TensorboardLogger

@@ -13,7 +14,7 @@ class SoftQAgent(BaseAgent):
def __init__(self,
*args,
gamma: float = 0.99,
beta: float = 5.0,
beta: float = 1.0,
use_target_network: bool = False,
target_update_interval: Optional[int] = None,
polyak_tau: Optional[float] = None,
@@ -115,15 +116,15 @@ def _on_step(self) -> None:
# Periodically update the target network:
if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0:
# Use Polyak averaging as specified:
polyak(self.online_softqs, self.target_softqs, self.polyak_tau)
polyak(self.online_softqs, self.target_softqs, self.polyak_tau, self.device)

super()._on_step()


if __name__ == '__main__':
import gymnasium as gym
env = gym.make('Acrobot-v1')
logger = TensorboardLogger('logs/acro')
env = gym.make('CartPole-v1')
logger = TensorboardLogger('logs/cartpole')
#logger = WandBLogger(entity='jacobhadamczyk', project='test')
mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32])
agent = SoftQAgent(env,
@@ -138,5 +139,6 @@ def _on_step(self) -> None:
target_update_interval=10,
polyak_tau=1.0,
eval_callbacks=[AUCCallback],
use_threaded_eval=False,
)
agent.learn(total_timesteps=50000)
87 changes: 87 additions & 0 deletions compare_eval_threaded.py
@@ -0,0 +1,87 @@
import time
import numpy as np
from Architectures import make_atari_nature_cnn, make_mlp
from DQN import DQN
import gymnasium as gym

from Logger import NullLogger

# env = gym.make('CartPole-v1')
n_steps = 10000
# Note: we eliminate all logging and evaluation for this comparison

# sb3_agent = sb3_DQN('MlpPolicy',
# env,
# learning_rate=0.001,
# buffer_size=n_steps,
# learning_starts=0,
# target_update_interval=10,
# )
# our_agent = our_DQN(env,
# loggers=(NullLogger(),),
# architecture = make_mlp,
# architecture_kwargs = {'input_dim': env.observation_space.shape[0],
# 'output_dim': env.action_space.n,
# 'hidden_dims': [64, 64],
# 'device': 'cpu'},
# learning_rate=0.001,
# train_interval=4,
# gradient_steps=1,
# batch_size=32,
# use_target_network=True,
# target_update_interval=10,
# buffer_size=n_steps,
# exploration_fraction=0.1,
# log_interval=n_steps,
# polyak_tau=1.0,
# device='cpu')

def time_learning(agent, n_steps):
start = time.time()
agent.learn(n_steps)
end = time.time()
return end - start



env = 'ALE/Pong-v5'

threaded_agent = DQN(env,
loggers=(NullLogger(),),
architecture=make_atari_nature_cnn,
architecture_kwargs={'output_dim': gym.make(env).action_space.n},
learning_rate=0.001,
train_interval=4,
gradient_steps=1,
batch_size=32,
use_target_network=True,
target_update_interval=100,
buffer_size=n_steps,
exploration_fraction=0.1,
log_interval=n_steps // 10,
polyak_tau=1.0,
use_threaded_eval=True)


unthreaded_agent = DQN(env,
loggers=(NullLogger(),),
architecture=make_atari_nature_cnn,
architecture_kwargs={'output_dim': gym.make(env).action_space.n},
learning_rate=0.001,
train_interval=4,
gradient_steps=1,
batch_size=32,
use_target_network=True,
target_update_interval=100,
buffer_size=n_steps,
exploration_fraction=0.1,
log_interval=n_steps // 10,
polyak_tau=1.0,
use_threaded_eval=False)


unthreaded_time = np.mean([time_learning(unthreaded_agent, n_steps) for _ in range(3)])
threaded_time = np.mean([time_learning(threaded_agent, n_steps) for _ in range(3)])

print(f"Un-threaded (standard) evaluation training took {unthreaded_time:.2f} seconds")
print(f"Threaded (new) evaluation training agent took {threaded_time:.2f} seconds")