From 68cf686046065e3a84ccb903fbfe36f62f5c80f2 Mon Sep 17 00:00:00 2001 From: Yuhao Chen Date: Sun, 2 Feb 2025 21:38:39 -0800 Subject: [PATCH 1/2] succefully build our new dataset --- v2/red_gym_env_v2.py | 1 - v2/run_pretrained_interactive.py | 28 +++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/v2/red_gym_env_v2.py b/v2/red_gym_env_v2.py index be9971d2..797ef6c7 100644 --- a/v2/red_gym_env_v2.py +++ b/v2/red_gym_env_v2.py @@ -202,7 +202,6 @@ def step(self, action): if self.save_video and self.step_count == 0: self.start_video() - self.run_action_on_emulator(action) self.append_agent_stats(action) diff --git a/v2/run_pretrained_interactive.py b/v2/run_pretrained_interactive.py index a4db9758..5fdb3f25 100644 --- a/v2/run_pretrained_interactive.py +++ b/v2/run_pretrained_interactive.py @@ -10,6 +10,13 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.callbacks import CheckpointCallback +import numpy as np +import matplotlib.pyplot as plt + +""" +Our dataset +""" +X = [] def make_env(rank, env_conf, seed=0): """ @@ -71,7 +78,16 @@ def get_most_recent_zip_with_age(folder_path): #keyboard.on_press_key("M", toggle_agent) obs, info = env.reset() + counter = 0 + x = np.zeros((3, 73, 80)) + while True: + if counter > 2: + counter = 0 + # Combine the last 3 frames + combined = np.hstack((x[0], x[1], x[2])) + X.append(combined) + action = 7 # pass action try: with open("agent_enabled.txt", "r") as f: @@ -81,7 +97,17 @@ def get_most_recent_zip_with_age(folder_path): if agent_enabled: action, _states = model.predict(obs, deterministic=False) obs, rewards, terminated, truncated, info = env.step(action) - env.render() + game_pixels = env.render() + # Combine game_pxiels with action + action = np.full((1, 80), action) + pixel_image = game_pixels.squeeze() + action_image = action.squeeze() + combined = np.vstack((pixel_image, action_image)) + x[counter] = combined + + counter += 1 + + if truncated: break env.close() From c31784248e02b8eed2400a71e98d841372229535 Mon Sep 17 00:00:00 2001 From: Yuhao Chen Date: Thu, 6 Feb 2025 13:09:55 -0800 Subject: [PATCH 2/2] commented out visualization --- v2/run_pretrained_interactive.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/v2/run_pretrained_interactive.py b/v2/run_pretrained_interactive.py index 5fdb3f25..56931b14 100644 --- a/v2/run_pretrained_interactive.py +++ b/v2/run_pretrained_interactive.py @@ -86,8 +86,13 @@ def get_most_recent_zip_with_age(folder_path): counter = 0 # Combine the last 3 frames combined = np.hstack((x[0], x[1], x[2])) + # plt.imshow(combined) + # plt.axis('off') + # plt.show() + X.append(combined) + action = 7 # pass action try: with open("agent_enabled.txt", "r") as f: @@ -107,7 +112,6 @@ def get_most_recent_zip_with_age(folder_path): counter += 1 - if truncated: break env.close()