
DQN wrapped Agents always use uniform replay buffer #230

@Davidb8


In the DQNAgent's __init__, line 386 unconditionally sets:

    self._replay_scheme = 'uniform'

JaxRainbowAgent sets its own _replay_scheme before it calls the parent constructor, so this assignment runs afterwards and overwrites it:

    vmax = float(vmax)
    self._num_atoms = num_atoms
    # If vmin is not specified, set it to -vmax similar to C51.
    vmin = vmin if vmin else -vmax
    self._support = jnp.linspace(vmin, vmax, num_atoms)
    self._replay_scheme = replay_scheme

    super(JaxRainbowAgent, self).__init__(
        num_actions=num_actions,
        observation_shape=observation_shape,
        observation_dtype=observation_dtype,
        stack_size=stack_size,
        network=functools.partial(network, num_atoms=num_atoms),
        gamma=gamma,
        update_horizon=update_horizon,
        min_replay_history=min_replay_history,
        update_period=update_period,
        target_update_period=target_update_period,
        epsilon_fn=epsilon_fn,
        epsilon_train=epsilon_train,
        epsilon_eval=epsilon_eval,
        epsilon_decay_period=epsilon_decay_period,
        optimizer=optimizer,
        seed=seed,
        summary_writer=summary_writer,
        summary_writing_frequency=summary_writing_frequency,
        allow_partial_reload=allow_partial_reload,
    )

Thus, even when replay_scheme is set to 'prioritized' through Gin, the DQN constructor overwrites that setting and the agent ends up using a uniform replay buffer.
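The override can be reproduced outside Dopamine with a minimal sketch. DQNLikeAgent and RainbowLikeAgent below are hypothetical stand-ins, not the real Dopamine classes; they assume only the constructor ordering described above (the subclass stores the attribute, then calls a parent constructor that assigns it unconditionally).

```python
# Hypothetical stand-ins for the Dopamine agents; only the constructor
# ordering described in this issue is reproduced.


class DQNLikeAgent:
    """Plays the role of the DQN base agent."""

    def __init__(self):
        # Unconditional default, like the assignment at line 386.
        self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
    """Plays the role of the Rainbow agent."""

    def __init__(self, replay_scheme='prioritized'):
        # The subclass stores its configured setting first ...
        self._replay_scheme = replay_scheme
        # ... then the parent constructor overwrites it.
        super().__init__()


agent = RainbowLikeAgent(replay_scheme='prioritized')
print(agent._replay_scheme)  # prints "uniform", not "prioritized"
```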

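This can be fixed by changing that line in the DQNAgent constructor so that it only sets the default when a subclass has not already set the attribute:

    if not hasattr(self, '_replay_scheme'):
      self._replay_scheme = 'uniform'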
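A minimal sketch of the proposed guard, again using hypothetical stand-in classes rather than the actual Dopamine code: the default is applied only when no subclass has assigned _replay_scheme before calling the parent constructor.

```python
# Hypothetical stand-ins for the Dopamine agents, with the proposed
# hasattr guard applied to the base constructor.


class DQNLikeAgent:
    """Plays the role of the DQN base agent."""

    def __init__(self):
        # Only fall back to the default when a subclass has not already
        # chosen a replay scheme before calling this constructor.
        if not hasattr(self, '_replay_scheme'):
            self._replay_scheme = 'uniform'


class RainbowLikeAgent(DQNLikeAgent):
    """Plays the role of the Rainbow agent."""

    def __init__(self, replay_scheme='prioritized'):
        # Set before super().__init__(), so the guard leaves it alone.
        self._replay_scheme = replay_scheme
        super().__init__()


print(RainbowLikeAgent()._replay_scheme)           # prioritized
print(RainbowLikeAgent('uniform')._replay_scheme)  # uniform
print(DQNLikeAgent()._replay_scheme)               # uniform
```

With the guard, a plain DQN agent still defaults to 'uniform', while any subclass that assigns _replay_scheme before calling super().__init__() keeps its own value.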

If this is indeed a bug and not a mistake on my end, I can open a PR with this fix.
