diff --git a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb index 570e6154f..53571ab4d 100644 --- a/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb +++ b/docs/tutorials/5a_train_preference_comparisons_with_cnn.ipynb @@ -151,7 +151,8 @@ "metadata": {}, "outputs": [], "source": [ - "from imitation.rewards.reward_nets import ShapedRewardNet, cnn_transpose\n", + "from imitation.util.networks import cnn_transpose\n", + "from imitation.rewards.reward_nets import ShapedRewardNet\n", "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n", "\n", "\n", diff --git a/setup.py b/setup.py index 151f17ea9..5d9270996 100644 --- a/setup.py +++ b/setup.py @@ -14,11 +14,13 @@ IS_NOT_WINDOWS = os.name != "nt" PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"] -ATARI_REQUIRE = [ +IMAGE_ENV_REQUIRE = [ "opencv-python", "ale-py==0.7.4", "pillow", "autorom[accept-rom-license]~=0.4.2", + "procgen==0.10.7", + "gym3@git+https://github.com/openai/gym3.git#4c3824680eaf9dd04dce224ee3d4856429878226", # noqa: E501 ] PYTYPE = ["pytype==2022.7.26"] if IS_NOT_WINDOWS else [] STABLE_BASELINES3 = "stable-baselines3>=1.6.1" @@ -61,7 +63,7 @@ "pre-commit>=2.20.0", ] + PARALLEL_REQUIRE - + ATARI_REQUIRE + + IMAGE_ENV_REQUIRE + PYTYPE ) DOCS_REQUIRE = [ @@ -74,7 +76,7 @@ "sphinx-github-changelog~=1.2.0", "myst-nb==0.16.0", "ipykernel~=6.15.2", -] + ATARI_REQUIRE +] + IMAGE_ENV_REQUIRE def get_readme() -> str: @@ -231,7 +233,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "mujoco": [ "gym[classic_control,mujoco]" + GYM_VERSION_SPECIFIER, ], - "atari": ATARI_REQUIRE, + "image_envs": IMAGE_ENV_REQUIRE, }, entry_points={ "console_scripts": [ diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 8729d3c30..667aa1ba7 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ 
b/src/imitation/algorithms/adversarial/common.py @@ -118,6 +118,7 @@ def __init__( disc_opt_kwargs: Optional[Mapping] = None, gen_train_timesteps: Optional[int] = None, gen_replay_buffer_capacity: Optional[int] = None, + transpose_obs: bool = False, custom_logger: Optional[logger.HierarchicalLogger] = None, init_tensorboard: bool = False, init_tensorboard_graph: bool = False, @@ -161,6 +162,9 @@ def __init__( the generator that can be stored). By default this is equal to `gen_train_timesteps`, meaning that we sample only from the most recent batch of generator samples. + transpose_obs: Whether observations will need to be transposed from (h,w,c) + format to be manually fed into the policy. Should usually be True for + image environments, and usually be False otherwise. custom_logger: Where to log to; if None (default), creates a new logger. init_tensorboard: If True, makes various discriminator TensorBoard summaries. @@ -221,6 +225,7 @@ def __init__( self._summary_writer = thboard.SummaryWriter(str(summary_dir)) self.venv_buffering = wrappers.BufferingWrapper(self.venv) + self.transpose_obs = transpose_obs if debug_use_ground_truth: # Would use an identity reward fn here, but RewardFns can't see rewards. @@ -481,18 +486,19 @@ def _get_log_policy_act_prob( Returns: A batch of log policy action probabilities. """ + obs_th_ = networks.cnn_transpose(obs_th) if self.transpose_obs else obs_th if isinstance(self.policy, policies.ActorCriticPolicy): # policies.ActorCriticPolicy has a concrete implementation of # evaluate_actions to generate log_policy_act_prob given obs and actions. _, log_policy_act_prob_th, _ = self.policy.evaluate_actions( - obs_th, + obs_th_, acts_th, ) elif isinstance(self.policy, sac_policies.SACPolicy): gen_algo_actor = self.policy.actor assert gen_algo_actor is not None # generate log_policy_act_prob from SAC actor. 
- mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th) + mean_actions, log_std, _ = gen_algo_actor.get_action_dist_params(obs_th_) distribution = gen_algo_actor.action_dist.proba_distribution( mean_actions, log_std, diff --git a/src/imitation/policies/base.py b/src/imitation/policies/base.py index 3101cf2c7..0b6b3fd4f 100644 --- a/src/imitation/policies/base.py +++ b/src/imitation/policies/base.py @@ -1,12 +1,13 @@ """Custom policy classes and convenience methods.""" import abc -from typing import Type +from typing import Optional, Tuple, Type import gym import numpy as np import torch as th from stable_baselines3.common import policies, torch_layers +from stable_baselines3.common.distributions import Distribution from stable_baselines3.sac import policies as sac_policies from torch import nn @@ -88,6 +89,103 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs, net_arch=[1024, 1024]) +class CnnPolicy(policies.ActorCriticCnnPolicy): + """A CNN Actor-Critic policy. + + This policy optionally transposes its observation inputs. Note that if this is done, + the policy expects the observation space to be a Box with values ranging from 0 to + 255. Methods are copy-pasted from StableBaselines 3's ActorCriticPolicy, with an + initial check whether or not to transpose an observation input. 
+ """ + + def __init__(self, *args, transpose_input: bool = False, **kwargs): + """Builds CnnPolicy; arguments passed to `ActorCriticCnnPolicy`.""" + self.transpose_input = transpose_input + if self.transpose_input: + kwargs.update( + { + "observation_space": self.transpose_space( + kwargs["observation_space"], + ), + }, + ) + super().__init__(*args, **kwargs) + + def transpose_space(self, observation_space: gym.spaces.Box) -> gym.spaces.Box: + if not isinstance(observation_space, gym.spaces.Box): + raise TypeError("This code assumes that observation spaces are gym Boxes.") + if not ( + np.all(observation_space.low == 0) and np.all(observation_space.high == 255) + ): + error_msg = ( + "This code assumes the observation space values range from " + + "0 to 255." + ) + raise ValueError(error_msg) + h, w, c = observation_space.shape + new_shape = (c, h, w) + return gym.spaces.Box( + low=0, + high=255, + shape=new_shape, + dtype=observation_space.dtype, + ) + + def forward( + self, + obs: th.Tensor, + deterministic: bool = False, + ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # Preprocess the observation if needed + features = self.extract_features(obs_) + latent_pi, latent_vf = self.mlp_extractor(features) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def evaluate_actions( + self, + obs: th.Tensor, + actions: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + # Preprocess the observation if needed + features = self.extract_features(obs_) + latent_pi, latent_vf = self.mlp_extractor(features) + distribution = 
self._get_action_dist_from_latent(latent_pi) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def get_distribution(self, obs: th.Tensor) -> Distribution: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + features = self.extract_features(obs_) + latent_pi = self.mlp_extractor.forward_actor(features) + return self._get_action_dist_from_latent(latent_pi) + + def predict_values(self, obs: th.Tensor) -> th.Tensor: + if self.transpose_input: + obs_ = networks.cnn_transpose(obs) + else: + obs_ = obs + features = self.extract_features(obs_) + latent_vf = self.mlp_extractor.forward_critic(features) + return self.value_net(latent_vf) + + class NormalizeFeaturesExtractor(torch_layers.FlattenExtractor): """Feature extractor that flattens then normalizes input.""" diff --git a/src/imitation/rewards/reward_nets.py b/src/imitation/rewards/reward_nets.py index 4e6c747e3..5134cbbc7 100644 --- a/src/imitation/rewards/reward_nets.py +++ b/src/imitation/rewards/reward_nets.py @@ -568,10 +568,12 @@ def forward( """ inputs = [] if self.use_state: - state_ = cnn_transpose(state) if self.hwc_format else state + state_ = networks.cnn_transpose(state) if self.hwc_format else state inputs.append(state_) if self.use_next_state: - next_state_ = cnn_transpose(next_state) if self.hwc_format else next_state + next_state_ = ( + networks.cnn_transpose(next_state) if self.hwc_format else next_state + ) inputs.append(next_state_) inputs_concat = th.cat(inputs, dim=1) @@ -597,16 +599,6 @@ def forward( return rewards -def cnn_transpose(tens: th.Tensor) -> th.Tensor: - """Transpose a (b,h,w,c)-formatted tensor to (b,c,h,w) format.""" - if len(tens.shape) == 4: - return th.permute(tens, (0, 3, 1, 2)) - else: - raise ValueError( - f"Invalid input: len(tens.shape) = {len(tens.shape)} != 4.", - ) - - class NormalizedRewardNet(PredictProcessedWrapper): """A reward net that normalizes the 
output of its base network.""" @@ -872,7 +864,7 @@ def __init__( ) def forward(self, state: th.Tensor) -> th.Tensor: - state_ = cnn_transpose(state) if self.hwc_format else state + state_ = networks.cnn_transpose(state) if self.hwc_format else state return self._potential_net(state_) diff --git a/src/imitation/scripts/common/common.py b/src/imitation/scripts/common/common.py index 72d44f2f4..6cbf0d510 100644 --- a/src/imitation/scripts/common/common.py +++ b/src/imitation/scripts/common/common.py @@ -3,10 +3,11 @@ import contextlib import logging import pathlib -from typing import Any, Generator, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, Generator, Mapping, Optional, Sequence, Tuple, Union import numpy as np import sacred +from gym import Env from stable_baselines3.common import vec_env from imitation.data import types @@ -35,6 +36,7 @@ def config(): num_vec = 8 # number of environments in VecEnv parallel = True # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps = None # Set to positive int to limit episode horizons + post_wrappers = [] # Sequence of wrappers to apply to each env in the VecEnv env_make_kwargs = {} # The kwargs passed to `spec.make`. locals() # quieten flake8 @@ -141,6 +143,7 @@ def make_venv( parallel: bool, log_dir: str, max_episode_steps: int, + post_wrappers: Optional[Sequence[Callable[[Env, int], Env]]], env_make_kwargs: Mapping[str, Any], **kwargs, ) -> Generator[vec_env.VecEnv, None, None]: @@ -154,6 +157,8 @@ def make_venv( max_episode_steps: If not None, then a TimeLimit wrapper is applied to each environment to artificially limit the maximum number of timesteps in an episode. + post_wrappers: If specified, iteratively wraps each environment with each + of the wrappers specified in the sequence. log_dir: Logs episode return statistics to a subdirectory 'monitor`. env_make_kwargs: The kwargs passed to `spec.make` of a gym environment. kwargs: Passed through to `util.make_vec_env`. 
@@ -169,8 +174,9 @@ def make_venv( rng=rng, n_envs=num_vec, parallel=parallel, - max_episode_steps=max_episode_steps, log_dir=log_dir, + max_episode_steps=max_episode_steps, + post_wrappers=post_wrappers, env_make_kwargs=env_make_kwargs, **kwargs, ) diff --git a/src/imitation/scripts/common/reward.py b/src/imitation/scripts/common/reward.py index c40d3751f..7930c1526 100644 --- a/src/imitation/scripts/common/reward.py +++ b/src/imitation/scripts/common/reward.py @@ -60,6 +60,11 @@ def reward_ensemble(): locals() +@reward_ingredient.named_config +def cnn_reward(): + net_cls = reward_nets.CnnRewardNet # noqa: F841 + + @reward_ingredient.config_hook def config_hook(config, command_name, logger): """Sets default values for `net_cls` and `net_kwargs`.""" @@ -71,7 +76,10 @@ def config_hook(config, command_name, logger): default_net = reward_nets.BasicShapedRewardNet res["net_cls"] = default_net - if "normalize_input_layer" not in config["reward"]["net_kwargs"]: + if ( + "normalize_input_layer" not in config["reward"]["net_kwargs"] + and config["reward"]["net_cls"] != reward_nets.CnnRewardNet + ): res["net_kwargs"] = {"normalize_input_layer": networks.RunningNorm} if "net_cls" in res and issubclass(res["net_cls"], reward_nets.RewardEnsemble): diff --git a/src/imitation/scripts/config/eval_policy.py b/src/imitation/scripts/config/eval_policy.py index 9bc8e29a6..907e2fa63 100644 --- a/src/imitation/scripts/config/eval_policy.py +++ b/src/imitation/scripts/config/eval_policy.py @@ -116,6 +116,89 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@eval_policy_ex.named_config +def coinrun(): + common = dict(env_name="procgen:procgen-coinrun-v0") + + +@eval_policy_ex.named_config +def maze(): + common = dict(env_name="procgen:procgen-maze-v0") + + +@eval_policy_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@eval_policy_ex.named_config +def bossfight(): + common = 
dict(env_name="procgen:procgen-bossfight-v0") + + +@eval_policy_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@eval_policy_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@eval_policy_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@eval_policy_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@eval_policy_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@eval_policy_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@eval_policy_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@eval_policy_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@eval_policy_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@eval_policy_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@eval_policy_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@eval_policy_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + + @eval_policy_ex.named_config def fast(): common = dict(env_name="seals/CartPole-v0", num_vec=1, parallel=False) diff --git a/src/imitation/scripts/config/train_adversarial.py b/src/imitation/scripts/config/train_adversarial.py index 3183ac9f6..596b8b041 100644 --- a/src/imitation/scripts/config/train_adversarial.py +++ b/src/imitation/scripts/config/train_adversarial.py @@ -1,6 +1,9 @@ """Configuration for imitation.scripts.train_adversarial.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.rewards import reward_nets from imitation.scripts.common import common, 
demonstrations, expert, reward, rl, train @@ -172,6 +175,35 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Atari configs + + +@train_adversarial_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + algorithm_kwargs = dict(transpose_obs=True) + + +@train_adversarial_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + algorithm_kwargs = dict(transpose_obs=True) + + # Debug configs diff --git a/src/imitation/scripts/config/train_imitation.py b/src/imitation/scripts/config/train_imitation.py index c2466a936..43a9c099e 100644 --- a/src/imitation/scripts/config/train_imitation.py +++ b/src/imitation/scripts/config/train_imitation.py @@ -2,7 +2,12 @@ import sacred import torch as th +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common import torch_layers +from stable_baselines3.common.atari_wrappers import AtariWrapper +from imitation.policies import base from imitation.scripts.common import common from imitation.scripts.common import demonstrations as demos_common from imitation.scripts.common import expert, train @@ -105,6 +110,46 @@ def seals_humanoid(): common = dict(env_name="seals/Humanoid-v0") +@train_imitation_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + train = dict( + 
policy_kwargs=dict(transpose_input=True), + ) + transpose_obs = True + + +@train_imitation_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + train = dict( + policy_kwargs=dict(transpose_input=True), + ) + transpose_obs = True + + +@train_imitation_ex.named_config +def cnn(): + train = dict( + policy_cls=base.CnnPolicy, + policy_kwargs=dict(features_extractor_class=torch_layers.NatureCNN), + ) + + @train_imitation_ex.named_config def fast(): dagger = dict(total_timesteps=50) diff --git a/src/imitation/scripts/config/train_preference_comparisons.py b/src/imitation/scripts/config/train_preference_comparisons.py index ba4e9483c..735541fc3 100644 --- a/src/imitation/scripts/config/train_preference_comparisons.py +++ b/src/imitation/scripts/config/train_preference_comparisons.py @@ -1,6 +1,9 @@ """Configuration for imitation.scripts.train_preference_comparisons.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.algorithms import preference_comparisons from imitation.scripts.common import common, reward, rl, train @@ -115,6 +118,30 @@ def seals_mountain_car(): common = dict(env_name="seals/MountainCar-v0") +@train_preference_comparisons_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + + +@train_preference_comparisons_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + 
lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + + @train_preference_comparisons_ex.named_config def fast(): # Minimize the amount of computation. Useful for test cases. diff --git a/src/imitation/scripts/config/train_rl.py b/src/imitation/scripts/config/train_rl.py index b9ede3165..87e32dd1c 100644 --- a/src/imitation/scripts/config/train_rl.py +++ b/src/imitation/scripts/config/train_rl.py @@ -1,6 +1,9 @@ """Configuration settings for train_rl, training a policy with RL.""" import sacred +from gym.wrappers import TimeLimit +from seals.util import AutoResetWrapper +from stable_baselines3.common.atari_wrappers import AtariWrapper from imitation.scripts.common import common, rl, train @@ -52,6 +55,30 @@ def acrobot(): common = dict(env_name="Acrobot-v1") +@train_rl_ex.named_config +def asteroids(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100_000), + ], + ) + + +@train_rl_ex.named_config +def asteroids_short_episodes(): + common = dict( + env_name="AsteroidsNoFrameskip-v4", + post_wrappers=[ + lambda env, _: AutoResetWrapper(env), + lambda env, _: AtariWrapper(env, terminal_on_life_loss=False), + lambda env, _: TimeLimit(env, max_episode_steps=100), + ], + ) + + @train_rl_ex.named_config def ant(): common = dict(env_name="Ant-v2") @@ -132,6 +159,89 @@ def seals_walker(): common = dict(env_name="seals/Walker2d-v0") +# Procgen configs + + +@train_rl_ex.named_config +def coinrun(): + common = dict(env_name="procgen:procgen-coinrun-v0") + + +@train_rl_ex.named_config +def maze(): + common = dict(env_name="procgen:procgen-maze-v0") + + +@train_rl_ex.named_config +def bigfish(): + common = dict(env_name="procgen:procgen-bigfish-v0") + + +@train_rl_ex.named_config +def bossfight(): + common = 
dict(env_name="procgen:procgen-bossfight-v0") + + +@train_rl_ex.named_config +def caveflyer(): + common = dict(env_name="procgen:procgen-caveflyer-v0") + + +@train_rl_ex.named_config +def chaser(): + common = dict(env_name="procgen:procgen-chaser-v0") + + +@train_rl_ex.named_config +def climber(): + common = dict(env_name="procgen:procgen-climber-v0") + + +@train_rl_ex.named_config +def dodgeball(): + common = dict(env_name="procgen:procgen-dodgeball-v0") + + +@train_rl_ex.named_config +def fruitbot(): + common = dict(env_name="procgen:procgen-fruitbot-v0") + + +@train_rl_ex.named_config +def heist(): + common = dict(env_name="procgen:procgen-heist-v0") + + +@train_rl_ex.named_config +def jumper(): + common = dict(env_name="procgen:procgen-jumper-v0") + + +@train_rl_ex.named_config +def leaper(): + common = dict(env_name="procgen:procgen-leaper-v0") + + +@train_rl_ex.named_config +def miner(): + common = dict(env_name="procgen:procgen-miner-v0") + + +@train_rl_ex.named_config +def ninja(): + common = dict(env_name="procgen:procgen-ninja-v0") + + +@train_rl_ex.named_config +def plunder(): + common = dict(env_name="procgen:procgen-plunder-v0") + + +@train_rl_ex.named_config +def starpilot(): + common = dict(env_name="procgen:procgen-starpilot-v0") + + # Debug configs diff --git a/src/imitation/scripts/train_imitation.py b/src/imitation/scripts/train_imitation.py index 09393366e..7eed14664 100644 --- a/src/imitation/scripts/train_imitation.py +++ b/src/imitation/scripts/train_imitation.py @@ -71,6 +71,7 @@ def train_imitation( dagger: Mapping[str, Any], use_dagger: bool, agent_path: Optional[str], + transpose_obs: bool = False, ) -> Mapping[str, Mapping[str, float]]: """Runs DAgger (if `use_dagger`) or BC (otherwise) training. @@ -82,6 +83,9 @@ def train_imitation( agent_path: Path to serialized policy. If provided, then load the policy from this path. Otherwise, make a new policy. Specify only if policy_cls and policy_kwargs are not specified. 
+ transpose_obs: Whether observations will need to be transposed from (h,w,c) + format to be fed into the policy. Should usually be True for image + environments, and usually be False otherwise. Returns: Statistics for rollouts from the trained policy and demonstration data. @@ -96,9 +100,15 @@ def train_imitation( if not use_dagger or dagger["use_offline_rollouts"]: expert_trajs = demonstrations.get_expert_trajectories() + if transpose_obs: + # this modification only affects the observation space the BC trainer + # expects to deal with + bc_trainer_venv = vec_env.vec_transpose.VecTransposeImage(venv) + else: + bc_trainer_venv = venv bc_trainer = bc_algorithm.BC( - observation_space=venv.observation_space, - action_space=venv.action_space, + observation_space=bc_trainer_venv.observation_space, + action_space=bc_trainer_venv.action_space, policy=imit_policy, demonstrations=expert_trajs, custom_logger=custom_logger, diff --git a/src/imitation/scripts/train_rl.py b/src/imitation/scripts/train_rl.py index fb7959592..35d9bcd61 100644 --- a/src/imitation/scripts/train_rl.py +++ b/src/imitation/scripts/train_rl.py @@ -13,6 +13,7 @@ import warnings from typing import Any, Mapping, Optional +from sacred.config.custom_containers import ReadOnlyDict from sacred.observers import FileStorageObserver from stable_baselines3.common import callbacks from stable_baselines3.common.vec_env import VecNormalize @@ -21,7 +22,8 @@ from imitation.policies import serialize from imitation.rewards.reward_wrapper import RewardVecEnvWrapper from imitation.rewards.serialize import load_reward -from imitation.scripts.common import common, rl, train +from imitation.scripts.common import common as scripts_common +from imitation.scripts.common import rl, train from imitation.scripts.config.train_rl import train_rl_ex @@ -40,6 +42,7 @@ def train_rl( policy_save_interval: int, policy_save_final: bool, agent_path: Optional[str], + common: ReadOnlyDict, ) -> Mapping[str, float]: """Trains an expert 
policy from scratch and saves the rollouts and policy. @@ -82,19 +85,22 @@ def train_rl( policy_save_final: If True, then save the policy right after training is finished. agent_path: Path to load warm-started agent. + common: Config of the `common` ingredient; supplies `post_wrappers`. Returns: The return value of `rollout_stats()` using the final policy. """ - rng = common.make_rng() - custom_logger, log_dir = common.setup_logging() + rng = scripts_common.make_rng() + custom_logger, log_dir = scripts_common.setup_logging() rollout_dir = log_dir / "rollouts" policy_dir = log_dir / "policies" rollout_dir.mkdir(parents=True, exist_ok=True) policy_dir.mkdir(parents=True, exist_ok=True) - post_wrappers = [lambda env, idx: wrappers.RolloutInfoWrapper(env)] - with common.make_venv(post_wrappers=post_wrappers) as venv: + all_post_wrappers = common["post_wrappers"] + [ + lambda env, idx: wrappers.RolloutInfoWrapper(env), + ] + with scripts_common.make_venv(post_wrappers=all_post_wrappers) as venv: callback_objs = [] if reward_type is not None: reward_fn = load_reward( diff --git a/src/imitation/util/networks.py b/src/imitation/util/networks.py index c27aea2cd..7ead022b3 100644 --- a/src/imitation/util/networks.py +++ b/src/imitation/util/networks.py @@ -201,6 +201,16 @@ def update_stats(self, batch: th.Tensor) -> None: self.num_batches += 1 # type: ignore[misc] +def cnn_transpose(tens: th.Tensor) -> th.Tensor: + """Transpose a (b,h,w,c)-formatted tensor to (b,c,h,w) format.""" + if len(tens.shape) == 4: + return th.permute(tens, (0, 3, 1, 2)) + else: + raise ValueError( + f"Invalid input: len(tens.shape) = {len(tens.shape)} != 4.", + ) + + def build_mlp( in_size: int, hid_sizes: Iterable[int], diff --git a/tests/rewards/test_reward_nets.py b/tests/rewards/test_reward_nets.py index c65332d0c..fd42b272b 100644 --- a/tests/rewards/test_reward_nets.py +++ b/tests/rewards/test_reward_nets.py @@ -214,10 +214,10 @@ def test_cnn_transpose_input_validation(dimensions: int): tens = 
th.zeros(shape) if dimensions == 4: # should succeed - reward_nets.cnn_transpose(tens) + networks.cnn_transpose(tens) else: # should fail with pytest.raises(ValueError, match="Invalid input: "): - reward_nets.cnn_transpose(tens) + networks.cnn_transpose(tens) def _sample(space, n): diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 226b6b3c2..5622a3e0c 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -68,6 +68,10 @@ PENDULUM_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/pendulum_0/" PENDULUM_TEST_ROLLOUT_PATH = PENDULUM_TEST_DATA_PATH / "rollouts/final.npz" +ASTEROIDS_TEST_DATA_PATH = TEST_DATA_PATH / "expert_models/asteroids_short_episodes_0/" +ASTEROIDS_TEST_ROLLOUT_PATH = ASTEROIDS_TEST_DATA_PATH / "rollouts/final.pkl" +ASTEROIDS_TEST_POLICY_PATH = ASTEROIDS_TEST_DATA_PATH / "policies/final" + OLD_FMT_ROLLOUT_TEST_DATA_PATH = TEST_DATA_PATH / "old_format_rollout.pkl" @@ -111,6 +115,8 @@ def test_main_console(script_mod): RL_SAC_NAMED_CONFIGS = ["rl.sac", "train.sac"] +ASTEROIDS_CNN_POLICY_CONFIG = ["asteroids_short_episodes", "train.cnn"] + @pytest.fixture( params=[ @@ -254,6 +260,20 @@ def test_train_preference_comparisons_reward_named_config(tmpdir, named_configs) assert isinstance(run.result, dict) +def test_train_preference_comparisons_image_env(tmpdir): + config_updates = dict(common=dict(log_root=tmpdir)) + run = train_preference_comparisons.train_preference_comparisons_ex.run( + named_configs=( + ["reward.cnn_reward"] + + ASTEROIDS_CNN_POLICY_CONFIG + + ALGO_FAST_CONFIGS["preference_comparison"] + ), + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + def test_train_dagger_main(tmpdir): with pytest.warns(None) as record: run = train_imitation.train_imitation_ex.run( @@ -301,6 +321,21 @@ def test_train_dagger_warmstart(tmpdir): assert isinstance(run_warmstart.result, dict) +def test_train_dagger_with_image_env(tmpdir): + run = 
train_imitation.train_imitation_ex.run( + command_name="dagger", + named_configs=( + ["asteroids_short_episodes", "cnn"] + ALGO_FAST_CONFIGS["imitation"] + ), + config_updates=dict( + common=dict(log_root=tmpdir), + demonstrations=dict(rollout_path=ASTEROIDS_TEST_ROLLOUT_PATH), + ), + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + def test_train_bc_main_with_none_demonstrations_raises_value_error(tmpdir): with pytest.raises(ValueError, match=".*n_expert_demos.*rollout_path.*"): train_imitation.train_imitation_ex.run( @@ -427,6 +462,16 @@ def test_train_rl_sac(tmpdir): assert isinstance(run.result, dict) +def test_train_rl_image_env(tmpdir): + config_updates = dict(common=dict(log_root=tmpdir)) + run = train_rl.train_rl_ex.run( + named_configs=ASTEROIDS_CNN_POLICY_CONFIG + ALGO_FAST_CONFIGS["rl"], + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + assert isinstance(run.result, dict) + + # check if platform is macos EVAL_POLICY_CONFIGS: List[Dict] = [ @@ -601,6 +646,27 @@ def test_train_adversarial_algorithm_value_error(tmpdir): ) +@pytest.mark.parametrize("command", ("airl", "gail")) +def test_train_adversarial_image_env(tmpdir, command): + """Smoke test for imitation.scripts.train_adversarial on atari.""" + named_configs = ( + ASTEROIDS_CNN_POLICY_CONFIG + + ALGO_FAST_CONFIGS["adversarial"] + + ["reward.cnn_reward"] + ) + config_updates = { + "common": dict(log_root=tmpdir), + "demonstrations": dict(rollout_path=ASTEROIDS_TEST_ROLLOUT_PATH), + } + run = train_adversarial.train_adversarial_ex.run( + command_name=command, + named_configs=named_configs, + config_updates=config_updates, + ) + assert run.status == "COMPLETED" + _check_train_ex_result(run.result) + + def test_transfer_learning(tmpdir: str) -> None: """Transfer learning smoke test.