diff --git a/v2/red_gym_env_v2.py b/v2/red_gym_env_v2.py index be9971d2..797ef6c7 100644 --- a/v2/red_gym_env_v2.py +++ b/v2/red_gym_env_v2.py @@ -202,7 +202,6 @@ def step(self, action): if self.save_video and self.step_count == 0: self.start_video() - self.run_action_on_emulator(action) self.append_agent_stats(action) diff --git a/v2/run_pretrained_interactive.py b/v2/run_pretrained_interactive.py index a4db9758..56931b14 100644 --- a/v2/run_pretrained_interactive.py +++ b/v2/run_pretrained_interactive.py @@ -10,6 +10,13 @@ from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from stable_baselines3.common.utils import set_random_seed from stable_baselines3.common.callbacks import CheckpointCallback +import numpy as np +import matplotlib.pyplot as plt + +""" +Our dataset +""" +X = [] def make_env(rank, env_conf, seed=0): """ @@ -71,7 +78,21 @@ def get_most_recent_zip_with_age(folder_path): #keyboard.on_press_key("M", toggle_agent) obs, info = env.reset() + counter = 0 + x = np.zeros((3, 73, 80)) + while True: + if counter > 2: + counter = 0 + # Combine the last 3 frames + combined = np.hstack((x[0], x[1], x[2])) + # plt.imshow(combined) + # plt.axis('off') + # plt.show() + + X.append(combined) + + action = 7 # pass action try: with open("agent_enabled.txt", "r") as f: @@ -81,7 +102,16 @@ def get_most_recent_zip_with_age(folder_path): if agent_enabled: action, _states = model.predict(obs, deterministic=False) obs, rewards, terminated, truncated, info = env.step(action) - env.render() + game_pixels = env.render() + # Combine game_pixels with action + action = np.full((1, 80), action) + pixel_image = game_pixels.squeeze() + action_image = action.squeeze() + combined = np.vstack((pixel_image, action_image)) + x[counter] = combined + + counter += 1 + if truncated: break env.close()