Description
In the initialization code of the DQNAgent, line 386 has:
```python
self._replay_scheme = 'uniform'
```
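This assignment runs after the RainbowAgent has already set its own `_replay_scheme`, because `JaxRainbowAgent.__init__` stores the scheme and only then calls the parent constructor: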
```python
vmax = float(vmax)
self._num_atoms = num_atoms
# If vmin is not specified, set it to -vmax similar to C51.
vmin = vmin if vmin else -vmax
self._support = jnp.linspace(vmin, vmax, num_atoms)
self._replay_scheme = replay_scheme
super(JaxRainbowAgent, self).__init__(
    num_actions=num_actions,
    observation_shape=observation_shape,
    observation_dtype=observation_dtype,
    stack_size=stack_size,
    network=functools.partial(network, num_atoms=num_atoms),
    gamma=gamma,
    update_horizon=update_horizon,
    min_replay_history=min_replay_history,
    update_period=update_period,
    target_update_period=target_update_period,
    epsilon_fn=epsilon_fn,
    epsilon_train=epsilon_train,
    epsilon_eval=epsilon_eval,
    epsilon_decay_period=epsilon_decay_period,
    optimizer=optimizer,
    seed=seed,
    summary_writer=summary_writer,
    summary_writing_frequency=summary_writing_frequency,
    allow_partial_reload=allow_partial_reload,
)
```
Thus, even when gin is configured to use a prioritized replay buffer, the scheme is overwritten and a uniform replay buffer is used instead.
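The ordering problem can be reproduced without Dopamine at all. The sketch below uses two hypothetical stand-in classes, `DQNLikeAgent` and `RainbowLikeAgent` (not Dopamine code), that mimic only the relevant pattern: the subclass sets `_replay_scheme` and then calls a parent constructor that unconditionally assigns `'uniform'`.

```python
class DQNLikeAgent:
  """Hypothetical stand-in for the DQN agent's __init__ (not Dopamine code)."""

  def __init__(self):
    # Mirrors the unconditional assignment in the DQN agent's __init__.
    self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
  """Hypothetical stand-in for JaxRainbowAgent's __init__."""

  def __init__(self, replay_scheme='prioritized'):
    # The subclass stores its configured scheme first...
    self._replay_scheme = replay_scheme
    # ...then the parent constructor runs and overwrites it.
    super().__init__()


agent = RainbowLikeAgent(replay_scheme='prioritized')
print(agent._replay_scheme)  # prints "uniform", not "prioritized"
```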
This can be fixed by changing the DQNAgent code to:
```python
if not hasattr(self, '_replay_scheme'):
  self._replay_scheme = 'uniform'
```
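Applied to the same hypothetical stand-in classes (again, not Dopamine code), the `hasattr` guard keeps `'uniform'` as the DQN default while preserving whatever scheme a subclass set before calling the parent constructor:

```python
class DQNLikeAgent:
  """Hypothetical stand-in for the DQN agent, with the proposed guard."""

  def __init__(self):
    # Only fall back to uniform replay if a subclass has not chosen a scheme.
    if not hasattr(self, '_replay_scheme'):
      self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
  """Hypothetical stand-in for JaxRainbowAgent."""

  def __init__(self, replay_scheme='prioritized'):
    self._replay_scheme = replay_scheme
    super().__init__()


print(RainbowLikeAgent()._replay_scheme)  # prints "prioritized"
print(DQNLikeAgent()._replay_scheme)      # prints "uniform"
```

Re-assigning the attribute after `super().__init__()` in the subclass would not necessarily be enough if the parent constructor already uses the scheme (for example, when building the replay buffer), which is presumably why the guard belongs in the parent.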
If this is indeed a bug and not an implementation error on my part, I can open a PR with this fix.