From 5541f921c79273d6d5dc428a4b64e755056ed020 Mon Sep 17 00:00:00 2001 From: Jacob Adamczyk <42879357+JacobHA@users.noreply.github.com> Date: Tue, 2 Jul 2024 11:12:33 -0400 Subject: [PATCH 1/7] aadd evaluation episodes in a new thread --- .gitignore | 3 ++- Architectures.py | 1 + BaseAgent.py | 18 +++++++++++++++++- DQN.py | 1 + Logger.py | 5 ++++- README.md | 14 ++++++++++++++ SoftQAgent.py | 5 +++-- tests/test_nature_cnn.py | 2 +- 8 files changed, 43 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 24b1455..a064cc9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__ *.tfevents.* logs *.png -*.prof \ No newline at end of file +*.prof +.pytest_cache \ No newline at end of file diff --git a/Architectures.py b/Architectures.py index 12ddd9d..6334ea5 100644 --- a/Architectures.py +++ b/Architectures.py @@ -69,6 +69,7 @@ def preprocess_obs(obs, device): if isinstance(obs, np.ndarray): obs = torch.from_numpy(obs) + # Pixel observations if obs.dtype == torch.uint8: if len(obs.shape) == 3: obs = obs.unsqueeze(0).to(device=device) diff --git a/BaseAgent.py b/BaseAgent.py index 9972065..91f515c 100644 --- a/BaseAgent.py +++ b/BaseAgent.py @@ -1,4 +1,5 @@ import os +import threading import time import numpy as np import torch @@ -153,6 +154,15 @@ def learn(self, total_timesteps: int): """ Train the agent for total_timesteps """ + stop_event = threading.Event() + def evaluation_worker(): + while not stop_event.is_set(): + self.evaluate(n_episodes=10) + # stop_event.wait(10) + + worker = threading.Thread(target=evaluation_worker) + worker.start() + # Start a timer to log fps: init_train_time = time.thread_time_ns() self.learn_env_steps = 0 @@ -191,14 +201,20 @@ def learn(self, total_timesteps: int): train_time = (time.thread_time_ns() - init_train_time) / 1e9 train_fps = self.log_interval / train_time self.log_history('time/train_fps', train_fps, self.learn_env_steps) - self.avg_eval_rwd = self.evaluate() + # Restart the worker: + self.avg_eval_rwd = worker.join() + # worker = threading.Thread(target=self.evaluate, args=(10,)) + init_train_time = time.thread_time_ns() pbar.update(self.log_interval) + stop_event.set() + # worker.join() if done: self.log_history("rollout/ep_reward", self.rollout_reward, self.learn_env_steps) self.log_history("rollout/avg_episode_length", avg_ep_len, self.learn_env_steps) + def _on_step(self) -> None: """ This method is called after every step in the environment diff --git a/DQN.py b/DQN.py index e8834e9..7e8d6b1 100644 --- a/DQN.py +++ b/DQN.py @@ -6,6 +6,7 @@ from BaseAgent import BaseAgent, get_new_params from utils import polyak + class DQN(BaseAgent): def __init__(self, *args, diff --git a/Logger.py b/Logger.py index 879688a..2adfc7e 100644 --- a/Logger.py +++ b/Logger.py @@ -75,7 +75,10 @@ def __init__(self, log_dir): self.writer = SummaryWriter(log_dir) def log_hparams(self, hparam_dict): for param, value in hparam_dict.items(): - self.writer.add_text(param, str(value), global_step=0) + if isinstance(value, str): + self.writer.add_text(param, value, global_step=0) + else: + self.writer.add_text(param, str(value), global_step=0) def log_history(self, param, value, step): self.writer.add_scalar(param, value, global_step=step) def log_video(self, video_path, name="video"): diff --git a/README.md b/README.md index 8ea8cad..e3c6355 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,21 @@ BARL: Base Agents for Reinforcement Learning This codebase provides implementations of many RL algorithms with the goal of being flexible for 
new algorithm ideas and fast experimentation. + + + ## TODO: +- [ ] Architecture kwargs (rename to model?) +- [ ] reduce branching in preprocess obs, maybe cache which preprocess for each env (map of env to function) +- [ ] add more tests +- [ ] dueling architecture +- [ ] hyperparameters +- [ ] random seeds +- [ ] double check Polyak averaging tau, timing +- [ ] Clean up eval thread + + + In the future we would like to have implemented: ### Algorithms: diff --git a/SoftQAgent.py b/SoftQAgent.py index 550de86..6a0b70b 100644 --- a/SoftQAgent.py +++ b/SoftQAgent.py @@ -4,7 +4,8 @@ import torch from Architectures import make_mlp -from BaseAgent import BaseAgent, get_new_params, AUCCallback +from BaseAgent import BaseAgent, get_new_params +from callbacks import AUCCallback from utils import polyak from Logger import WandBLogger, TensorboardLogger @@ -123,7 +124,7 @@ def _on_step(self) -> None: if __name__ == '__main__': import gymnasium as gym env = gym.make('Acrobot-v1') - logger = TensorboardLogger('logs/acro') + logger = TensorboardLogger('logs/acrb') #logger = WandBLogger(entity='jacobhadamczyk', project='test') mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32]) agent = SoftQAgent(env, diff --git a/tests/test_nature_cnn.py b/tests/test_nature_cnn.py index 81b5a27..3621ef4 100644 --- a/tests/test_nature_cnn.py +++ b/tests/test_nature_cnn.py @@ -8,7 +8,7 @@ class TestCNN(unittest.TestCase): def test_cnn_with_stacked_states(self): # Initialize the environment env_id = "ALE/Pong-v5" - env, eval_env = env_id_to_envs(env_id, render=True, is_atari=True, permute_dims=True) + env, eval_env = env_id_to_envs(env_id, render=False, is_atari=True, permute_dims=True) # Create the CNN cnn = make_atari_nature_cnn(output_dim=env.action_space.n, input_dim=(84, 84, 4)) From d38cada4cc8444a21cb62527e9e6aa69900cfa3f Mon Sep 17 00:00:00 2001 From: JacobHA Date: Tue, 2 Jul 2024 11:55:56 -0400 Subject: [PATCH 2/7] comparison profiling, polyak needs device, new NullLogger, new polyak --- DQN.py | 32 +++++++++--------- Logger.py | 13 +++++++- compare_to_sb3.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++ utils.py | 67 ++++++++++++++++++++++++------------- 4 files changed, 156 insertions(+), 40 deletions(-) create mode 100644 compare_to_sb3.py diff --git a/DQN.py b/DQN.py index e8834e9..841846f 100644 --- a/DQN.py +++ b/DQN.py @@ -74,7 +74,7 @@ def _on_step(self) -> None: # Periodically update the target network: if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0: # Use Polyak averaging as specified: - polyak(self.online_qs, self.target_qs, self.polyak_tau) + polyak(self.online_qs, self.target_qs, self.polyak_tau, self.device) def exploration_policy(self, state: np.ndarray) -> int: @@ -114,7 +114,7 @@ def calculate_loss(self, batch): self.log_history("train/online_q_mean", curr_q.mean().item(), self.learn_env_steps) # log the loss: - logger.log_history("train/loss", loss.item(), self.learn_env_steps) + self.log_history("train/loss", loss.item(), self.learn_env_steps) return loss @@ -125,25 +125,25 @@ def calculate_loss(self, batch): from Logger import WandBLogger, TensorboardLogger logger = TensorboardLogger('logs/atari') - #logger = WandBLogger(entity='jacobhadamczyk', project='test') - # mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32])#, activation=torch.nn.Mish) - # cnn = make_atari_nature_cnn(gym.make(env).action_space.n) - env = 'CartPole-v1' + + # env 
= 'CartPole-v1' agent = DQN(env, - architecture=make_mlp, - architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0], - 'output_dim': gym.make(env).action_space.n, - 'hidden_dims': [64, 64]}, + # architecture=make_mlp, + # architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0], + # 'output_dim': gym.make(env).action_space.n, + # 'hidden_dims': [64, 64]}, + architecture=make_atari_nature_cnn, + architecture_kwargs={'output_dim': gym.make(env).action_space.n}, loggers=(logger,), learning_rate=0.001, - train_interval=1, - gradient_steps=1, + train_interval=4, + gradient_steps=4, batch_size=64, use_target_network=True, - target_update_interval=10, + target_update_interval=1_000, polyak_tau=1.0, - learning_starts=1000, - log_interval=500, + learning_starts=5_000, + log_interval=2000, ) - agent.learn(total_timesteps=60_000) + agent.learn(total_timesteps=100_000) diff --git a/Logger.py b/Logger.py index 879688a..39a49b0 100644 --- a/Logger.py +++ b/Logger.py @@ -81,4 +81,15 @@ def log_history(self, param, value, step): def log_video(self, video_path, name="video"): self.writer.add_video(name, video_path) def log_image(self, image_path, name="image"): - self.writer.add_image(name, image_path) \ No newline at end of file + self.writer.add_image(name, image_path) + + +class NullLogger(BaseLogger): + def log_hparams(self, hparam_dict): + pass + def log_history(self, param, value, step): + pass + def log_video(self, video_path): + pass + def log_image(self, image_path): + pass \ No newline at end of file diff --git a/compare_to_sb3.py b/compare_to_sb3.py new file mode 100644 index 0000000..e3aa5e5 --- /dev/null +++ b/compare_to_sb3.py @@ -0,0 +1,84 @@ +import time +import numpy as np +from stable_baselines3 import DQN as sb3_DQN +from Architectures import make_atari_nature_cnn, make_mlp +from DQN import DQN as our_DQN +import gymnasium as gym + +from Logger import NullLogger + +env = gym.make('CartPole-v1') +n_steps = 10000 +# Note: we eliminate all logging and evaluation for this comparison + +sb3_agent = sb3_DQN('MlpPolicy', + env, + learning_rate=0.001, + buffer_size=n_steps, + learning_starts=0, + target_update_interval=10, + ) +our_agent = our_DQN(env, + loggers=(NullLogger(),), + architecture = make_mlp, + architecture_kwargs = {'input_dim': env.observation_space.shape[0], + 'output_dim': env.action_space.n, + 'hidden_dims': [64, 64], + 'device': 'cpu'}, + learning_rate=0.001, + train_interval=4, + gradient_steps=1, + batch_size=32, + use_target_network=True, + target_update_interval=10, + buffer_size=n_steps, + exploration_fraction=0.1, + log_interval=n_steps, + polyak_tau=1.0, + device='cpu') + +def time_learning(agent, n_steps): + start = time.time() + agent.learn(n_steps) + end = time.time() + return end - start + + + +# sb3_time = np.mean([time_learning(sb3_agent, n_steps) for _ in range(5)]) +# our_time = np.mean([time_learning(our_agent, n_steps) for _ in range(5)]) + +# print(f"SB3 took {sb3_time:.2f} seconds") +# print(f"Our agent took {our_time:.2f} seconds") + + +# Now do the same on Atari: +env = 'ALE/Pong-v5' + +sb3_agent = sb3_DQN('CnnPolicy', + env, + learning_rate=0.001, + buffer_size=n_steps, + learning_starts=0, + target_update_interval=100 + ) +our_agent = our_DQN(env, + loggers=(NullLogger(),), + architecture=make_atari_nature_cnn, + architecture_kwargs={'output_dim': gym.make(env).action_space.n}, + learning_rate=0.001, + train_interval=4, + gradient_steps=1, + batch_size=32, + use_target_network=True, + target_update_interval=100, + 
buffer_size=n_steps, + exploration_fraction=0.1, + log_interval=n_steps, + polyak_tau=1.0) + +sb3_time = np.mean([time_learning(sb3_agent, n_steps) for _ in range(3)]) +our_time = np.mean([time_learning(our_agent, n_steps) for _ in range(3)]) + +print(f"SB3 took {sb3_time:.2f} seconds") +print(f"Our agent took {our_time:.2f} seconds") \ No newline at end of file diff --git a/utils.py b/utils.py index 5efe4f2..2006885 100644 --- a/utils.py +++ b/utils.py @@ -96,29 +96,6 @@ def find_torch_modules(module, modules=None, prefix=None): return modules -def polyak(target_nets, online_nets, tau): - tau = 1 - tau - """ - Perform a Polyak (exponential moving average) update for target networks. - - Args: - online_nets (list): A list of online networks whose parameters will be used for the update. - tau (float): The update rate, typically between 0 and 1. - - Raises: - ValueError: If the number of online networks does not match the number of target networks. - """ - with torch.no_grad(): - # zip does not raise an exception if length of parameters does not match. - for new_params, target_params in zip(online_nets.parameters(), target_nets.parameters()): - # for new_param, target_param in zip_strict(new_params, target_params): - # target_param.data.mul_(tau).add_(new_param.data, alpha=1.0-tau) - #TODO: Remove dependency on stable_baselines3 by using in-place ops as above. - # zip does not raise an exception if length of parameters does not match. - for param, target_param in zip_strict(new_params, target_params): - target_param.data.mul_(1 - tau) - torch.add(target_param.data, param.data, alpha=tau, out=target_param.data) - def auto_device(device: Union[torch.device, str] = 'auto'): if device == 'auto': @@ -199,3 +176,47 @@ def atari_env_id_to_envs(env_id, render, n_envs, frameskip=1, framestack_k=None, eval_env = copy.deepcopy(env_id) return env, eval_env + + +def polyak(target_nets, online_nets, tau, device): + """ + Perform a Polyak (exponential moving average) update for target networks. + + Args: + online_nets (list): A list of online networks whose parameters will be used for the update. + tau (float): The update rate, typically between 0 and 1. + Returns: + None: operations are performed in-place. + """ + # Thanks to m-rph at + # https://github.com/DLR-RM/stable-baselines3/issues/93 + # for the fix to this function. Looks like correct device and addcmul are quite helpful. + # The only addition is the strict kwarg + one = torch.ones(1, requires_grad=False).to(device) + for param, target_param in zip(online_nets.parameters(), target_nets.parameters(), strict=True): + target_param.data.mul_(1 - tau) + target_param.data.addcmul_(param.data, one, value=tau) + + +# def polyak(target_nets, online_nets, tau): +# tau = 1 - tau +# """ +# Perform a Polyak (exponential moving average) update for target networks. + +# Args: +# online_nets (list): A list of online networks whose parameters will be used for the update. +# tau (float): The update rate, typically between 0 and 1. + +# Raises: +# ValueError: If the number of online networks does not match the number of target networks. +# """ +# with torch.no_grad(): +# # zip does not raise an exception if length of parameters does not match. +# for new_params, target_params in zip(online_nets.parameters(), target_nets.parameters()): +# # for new_param, target_param in zip_strict(new_params, target_params): +# # target_param.data.mul_(tau).add_(new_param.data, alpha=1.0-tau) +# #TODO: Remove dependency on stable_baselines3 by using in-place ops as above. 
+# # zip does not raise an exception if length of parameters does not match. +# for param, target_param in zip_strict(new_params, target_params): +# target_param.data.mul_(1 - tau) +# torch.add(target_param.data, param.data, alpha=tau, out=target_param.data) From d826bf638259c7192e165fd466b1aaefc87dff61 Mon Sep 17 00:00:00 2001 From: JacobHA Date: Tue, 2 Jul 2024 11:56:44 -0400 Subject: [PATCH 3/7] polyak --- .gitignore | 3 ++- SoftQAgent.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 24b1455..adb6448 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__ *.tfevents.* logs *.png -*.prof \ No newline at end of file +*.prof +*.sh \ No newline at end of file diff --git a/SoftQAgent.py b/SoftQAgent.py index 550de86..a70da1f 100644 --- a/SoftQAgent.py +++ b/SoftQAgent.py @@ -13,7 +13,7 @@ class SoftQAgent(BaseAgent): def __init__(self, *args, gamma: float = 0.99, - beta: float = 5.0, + beta: float = 1.0, use_target_network: bool = False, target_update_interval: Optional[int] = None, polyak_tau: Optional[float] = None, @@ -115,15 +115,15 @@ def _on_step(self) -> None: # Periodically update the target network: if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0: # Use Polyak averaging as specified: - polyak(self.online_softqs, self.target_softqs, self.polyak_tau) + polyak(self.online_softqs, self.target_softqs, self.polyak_tau, self.device) super()._on_step() if __name__ == '__main__': import gymnasium as gym - env = gym.make('Acrobot-v1') - logger = TensorboardLogger('logs/acro') + env = gym.make('CartPole-v1') + logger = TensorboardLogger('logs/cartpole') #logger = WandBLogger(entity='jacobhadamczyk', project='test') mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32]) agent = SoftQAgent(env, From 57c2b95e90a68c6e2a04e7c90e5324178a24e240 Mon Sep 17 00:00:00 2001 From: JacobHA Date: Tue, 2 Jul 2024 12:00:49 -0400 Subject: [PATCH 4/7] fix test, don't tear down prematurely --- tests/test_loggers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_loggers.py b/tests/test_loggers.py index 9310be9..09effb9 100644 --- a/tests/test_loggers.py +++ b/tests/test_loggers.py @@ -52,8 +52,6 @@ class TestTensorboardLogger(unittest.TestCase): def setUp(self): self.log_dir = 'test_log_dir' - if os.path.exists(self.log_dir): - shutil.rmtree(self.log_dir) self.tensorboard_logger = TensorboardLogger(self.log_dir) def tearDown(self): From fcd0acc3e3b5b73b36f2472e52a290227bb63469 Mon Sep 17 00:00:00 2001 From: JacobHA Date: Tue, 2 Jul 2024 12:21:27 -0400 Subject: [PATCH 5/7] adding kwarg for threading, typo --- BaseAgent.py | 2 + DQN.py | 2 +- compare_eval_threaded.py | 87 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 compare_eval_threaded.py diff --git a/BaseAgent.py b/BaseAgent.py index 9972065..d85d744 100644 --- a/BaseAgent.py +++ b/BaseAgent.py @@ -40,6 +40,7 @@ def __init__(self, save_checkpoints: bool = False, seed: Optional[int] = None, eval_callbacks: List[callable] = [], + use_threaded_eval: bool = True, ) -> None: self.LOG_PARAMS = { @@ -77,6 +78,7 @@ def __init__(self, self.batch_size = batch_size self.loggers = loggers + self.use_threaded_eval = use_threaded_eval self.gradient_steps = gradient_steps self.device = auto_device(device) diff --git a/DQN.py b/DQN.py index 841846f..10ac46c 100644 --- a/DQN.py +++ b/DQN.py @@ -23,7 +23,7 @@ def __init__(self, 
super().__init__(*args, **kwargs) self.kwargs = get_new_params(self, locals()) - self.algo_name = 'SQL' + self.algo_name = 'DQN' self.gamma = gamma self.minimum_epsilon = minimum_epsilon self.exploration_fraction = exploration_fraction diff --git a/compare_eval_threaded.py b/compare_eval_threaded.py new file mode 100644 index 0000000..9932865 --- /dev/null +++ b/compare_eval_threaded.py @@ -0,0 +1,87 @@ +import time +import numpy as np +from Architectures import make_atari_nature_cnn, make_mlp +from DQN import DQN +import gymnasium as gym + +from Logger import NullLogger + +# env = gym.make('CartPole-v1') +n_steps = 10000 +# Note: we eliminate all logging and evaluation for this comparison + +# sb3_agent = sb3_DQN('MlpPolicy', +# env, +# learning_rate=0.001, +# buffer_size=n_steps, +# learning_starts=0, +# target_update_interval=10, +# ) +# our_agent = our_DQN(env, +# loggers=(NullLogger(),), +# architecture = make_mlp, +# architecture_kwargs = {'input_dim': env.observation_space.shape[0], +# 'output_dim': env.action_space.n, +# 'hidden_dims': [64, 64], +# 'device': 'cpu'}, +# learning_rate=0.001, +# train_interval=4, +# gradient_steps=1, +# batch_size=32, +# use_target_network=True, +# target_update_interval=10, +# buffer_size=n_steps, +# exploration_fraction=0.1, +# log_interval=n_steps, +# polyak_tau=1.0, +# device='cpu') + +def time_learning(agent, n_steps): + start = time.time() + agent.learn(n_steps) + end = time.time() + return end - start + + + +env = 'ALE/Pong-v5' + +threaded_agent = DQN(env, + loggers=(NullLogger(),), + architecture=make_atari_nature_cnn, + architecture_kwargs={'output_dim': gym.make(env).action_space.n}, + learning_rate=0.001, + train_interval=4, + gradient_steps=1, + batch_size=32, + use_target_network=True, + target_update_interval=100, + buffer_size=n_steps, + exploration_fraction=0.1, + log_interval=n_steps // 10, + polyak_tau=1.0, + use_threaded_eval=True) + + +unthreaded_agent = DQN(env, + loggers=(NullLogger(),), + architecture=make_atari_nature_cnn, + architecture_kwargs={'output_dim': gym.make(env).action_space.n}, + learning_rate=0.001, + train_interval=4, + gradient_steps=1, + batch_size=32, + use_target_network=True, + target_update_interval=100, + buffer_size=n_steps, + exploration_fraction=0.1, + log_interval=n_steps // 10, + polyak_tau=1.0, + use_threaded_eval=False) + + +unthreaded_time = np.mean([time_learning(unthreaded_agent, n_steps) for _ in range(3)]) +threaded_time = np.mean([time_learning(threaded_agent, n_steps) for _ in range(3)]) + +print(f"Un-threaded (standard) evaluation training took {unthreaded_time:.2f} seconds") +print(f"Threaded (new) evaluation training agent took {threaded_time:.2f} seconds") \ No newline at end of file From a6a79d2ba05623165ce83588fd77a16dcf82d5cb Mon Sep 17 00:00:00 2001 From: JacobHA Date: Tue, 2 Jul 2024 12:24:44 -0400 Subject: [PATCH 6/7] added conditionals for if the agent's eval should be threaded --- BaseAgent.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/BaseAgent.py b/BaseAgent.py index d006f56..6fc00b8 100644 --- a/BaseAgent.py +++ b/BaseAgent.py @@ -156,14 +156,15 @@ def learn(self, total_timesteps: int): """ Train the agent for total_timesteps """ - stop_event = threading.Event() - def evaluation_worker(): - while not stop_event.is_set(): - self.evaluate(n_episodes=10) - # stop_event.wait(10) + if self.use_threaded_eval: + stop_event = threading.Event() + def evaluation_worker(): + while not stop_event.is_set(): + self.evaluate(n_episodes=10) 
+ # stop_event.wait(10) - worker = threading.Thread(target=evaluation_worker) - worker.start() + worker = threading.Thread(target=evaluation_worker) + worker.start() # Start a timer to log fps: init_train_time = time.thread_time_ns() @@ -203,14 +204,17 @@ def evaluation_worker(): train_time = (time.thread_time_ns() - init_train_time) / 1e9 train_fps = self.log_interval / train_time self.log_history('time/train_fps', train_fps, self.learn_env_steps) - # Restart the worker: - self.avg_eval_rwd = worker.join() - # worker = threading.Thread(target=self.evaluate, args=(10,)) + if self.use_threaded_eval: + # Restart the worker: + self.avg_eval_rwd = worker.join() + # worker = threading.Thread(target=self.evaluate, args=(10,)) + else: + self.avg_eval_rwd = self.evaluate(n_episodes=10) init_train_time = time.thread_time_ns() pbar.update(self.log_interval) - - stop_event.set() + if self.use_threaded_eval: + stop_event.set() # worker.join() if done: self.log_history("rollout/ep_reward", self.rollout_reward, self.learn_env_steps) From 63c3ebbe9e2590fa7bc5204a3df1a0fe9840ccab Mon Sep 17 00:00:00 2001 From: Jacob Adamczyk <42879357+JacobHA@users.noreply.github.com> Date: Tue, 2 Jul 2024 23:18:48 -0400 Subject: [PATCH 7/7] trying to rewrite threading for proper logging --- BaseAgent.py | 32 +++++++++++++++++++++++++------- DQN.py | 35 +++++++++++++++++++---------------- SoftQAgent.py | 1 + 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/BaseAgent.py b/BaseAgent.py index 6fc00b8..44a09c4 100644 --- a/BaseAgent.py +++ b/BaseAgent.py @@ -174,6 +174,17 @@ def evaluation_worker(): with tqdm.tqdm(total=total_timesteps, desc="Training") as pbar: while self.learn_env_steps < total_timesteps: + # Check if worker is alive: + if self.use_threaded_eval: + # stop_event = threading.Event() + + def evaluation_worker(): + # while not stop_event.is_set(): + self.evaluate(n_episodes=10) + + worker = threading.Thread(target=evaluation_worker) + worker.start() + state, _ = self.env.reset() done = False @@ -206,16 +217,25 @@ def evaluation_worker(): self.log_history('time/train_fps', train_fps, self.learn_env_steps) if self.use_threaded_eval: # Restart the worker: - self.avg_eval_rwd = worker.join() - # worker = threading.Thread(target=self.evaluate, args=(10,)) + # stop_event.set() + worker.join() # Wait for the worker to finish + # stop_event.clear() # Clear the stop event before restarting the worker + # self.log_history('eval/avg_episode_length', n_steps / n_episodes, self.learn_env_steps) + # self.log_history('eval/time', eval_time, self.learn_env_steps) + # self.log_history('eval/fps', eval_fps, self.learn_env_steps) + worker = threading.Thread(target=evaluation_worker) + + worker.start() else: self.avg_eval_rwd = self.evaluate(n_episodes=10) + + self.log_history('eval/avg_reward', self.avg_eval_rwd, self.learn_env_steps) init_train_time = time.thread_time_ns() pbar.update(self.log_interval) if self.use_threaded_eval: stop_event.set() - # worker.join() + if done: self.log_history("rollout/ep_reward", self.rollout_reward, self.learn_env_steps) self.log_history("rollout/avg_episode_length", avg_ep_len, self.learn_env_steps) @@ -259,10 +279,8 @@ def evaluate(self, n_episodes=10) -> float: avg_reward /= n_episodes eval_fps = n_steps / eval_time self.eval_time = eval_time - self.log_history('eval/avg_reward', avg_reward, self.learn_env_steps) - self.log_history('eval/avg_episode_length', n_steps / n_episodes, self.learn_env_steps) - self.log_history('eval/time', eval_time, self.learn_env_steps) - 
self.log_history('eval/fps', eval_fps, self.learn_env_steps) + self.avg_eval_rwd = avg_reward + for callback in self.eval_callbacks: callback(self, end=True) return avg_reward diff --git a/DQN.py b/DQN.py index ba99ac7..2987903 100644 --- a/DQN.py +++ b/DQN.py @@ -67,7 +67,9 @@ def _on_step(self) -> None: super()._on_step() # Update epsilon: - self.epsilon = max(self.minimum_epsilon, (self.initial_epsilon - self.learn_env_steps / self.total_timesteps / self.exploration_fraction)) + self.epsilon = max(self.minimum_epsilon, + (self.initial_epsilon - + self.learn_env_steps / self.total_timesteps / self.exploration_fraction)) if self.learn_env_steps % self.log_interval == 0: self.log_history("train/epsilon", self.epsilon, self.learn_env_steps) @@ -98,6 +100,7 @@ def calculate_loss(self, batch): dones = dones.float() curr_q = self.online_qs(states).squeeze().gather(1, actions.long()) with torch.no_grad(): + # TODO: push this into pre-processing to clean up the agent code: if isinstance(self.env.observation_space, gymnasium.spaces.Discrete): states = states.squeeze() next_states = next_states.squeeze() @@ -127,24 +130,24 @@ def calculate_loss(self, batch): from Logger import WandBLogger, TensorboardLogger logger = TensorboardLogger('logs/atari') - # env = 'CartPole-v1' + env = 'CartPole-v1' agent = DQN(env, - # architecture=make_mlp, - # architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0], - # 'output_dim': gym.make(env).action_space.n, - # 'hidden_dims': [64, 64]}, - architecture=make_atari_nature_cnn, - architecture_kwargs={'output_dim': gym.make(env).action_space.n}, + architecture=make_mlp, + architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0], + 'output_dim': gym.make(env).action_space.n, + 'hidden_dims': [64, 64]}, + # architecture=make_atari_nature_cnn, + # architecture_kwargs={'output_dim': gym.make(env).action_space.n}, loggers=(logger,), learning_rate=0.001, - train_interval=4, - gradient_steps=4, + train_interval=1, + gradient_steps=1, batch_size=64, - use_target_network=True, - target_update_interval=1_000, - polyak_tau=1.0, - learning_starts=5_000, - log_interval=2000, - + use_target_network=False, + # target_update_interval=10, + # polyak_tau=0.0, + learning_starts=100, + log_interval=200, + use_threaded_eval=False, ) agent.learn(total_timesteps=100_000) diff --git a/SoftQAgent.py b/SoftQAgent.py index def0bcf..573ec76 100644 --- a/SoftQAgent.py +++ b/SoftQAgent.py @@ -139,5 +139,6 @@ def _on_step(self) -> None: target_update_interval=10, polyak_tau=1.0, eval_callbacks=[AUCCallback], + use_threaded_eval=False, ) agent.learn(total_timesteps=50000)
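
Note on the recurring change in this series: patches 1, 5, 6, and 7 converge on an evaluation worker that runs evaluation episodes in a background thread, is joined at each log interval so the averaged reward can be logged from the main loop, and is then restarted. Below is a minimal, self-contained Python sketch of that pattern under assumed names (a toy Agent with a stubbed evaluate() and learn(); not the actual BaseAgent code). Thread.join() always returns None, so the worker must store its result on the agent, which is what patch 7 does via self.avg_eval_rwd.

import threading
import time
import random


class Agent:
    """Toy agent illustrating the threaded-evaluation pattern in this series."""

    def __init__(self, log_interval=1000):
        self.log_interval = log_interval
        self.avg_eval_rwd = None

    def evaluate(self, n_episodes=10):
        # Stand-in for rolling out the policy on a separate eval env;
        # the result is stored on the agent because join() cannot return it.
        time.sleep(0.01)
        self.avg_eval_rwd = sum(random.random() for _ in range(n_episodes)) / n_episodes

    def learn(self, total_timesteps):
        worker = threading.Thread(target=self.evaluate, daemon=True)
        worker.start()
        for step in range(1, total_timesteps + 1):
            # environment step + gradient update would go here
            if step % self.log_interval == 0:
                worker.join()  # wait for the in-flight evaluation to finish
                print(f"step {step}: eval/avg_reward = {self.avg_eval_rwd:.3f}")
                # restart the worker for the next logging interval
                worker = threading.Thread(target=self.evaluate, daemon=True)
                worker.start()
        worker.join()


if __name__ == '__main__':
    Agent(log_interval=1000).learn(5000)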