# Demonstration of RLlib's ReplayBuffer workflow

from typing import Optional
import random

import numpy as np

from ray import tune
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.algorithms.dqn.dqn import DQNConfig


# __sphinx_doc_replay_buffer_type_specification__begin__
config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer}).to_dict()

another_config = (
    DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"}).to_dict()
)

yet_another_config = (
    DQNConfig()
    .training(
        replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
    )
    .to_dict()
)

validate_buffer_config(config)
validate_buffer_config(another_config)
validate_buffer_config(yet_another_config)

# After validation, all three configs yield the same effective config.
assert config == another_config == yet_another_config
# __sphinx_doc_replay_buffer_type_specification__end__


# __sphinx_doc_replay_buffer_basic_interaction__begin__
# We choose fragments because they impose no restrictions on the batches we add.
buffer = ReplayBuffer(capacity=2, storage_unit=StorageUnit.FRAGMENTS)
dummy_batch = SampleBatch({"a": [1], "b": [2]})
buffer.add(dummy_batch)
buffer.sample(2)
# Because elements can be sampled multiple times, we receive a concatenated version
# of dummy_batch: `{"a": [1, 1], "b": [2, 2]}`.
# __sphinx_doc_replay_buffer_basic_interaction__end__


# __sphinx_doc_replay_buffer_own_buffer__begin__
class LessSampledReplayBuffer(ReplayBuffer):
    @override(ReplayBuffer)
    def sample(
        self, num_items: int, evict_sampled_more_then: int = 30, **kwargs
    ) -> Optional[SampleBatchType]:
        """Evicts experiences that have been sampled >= evict_sampled_more_then times."""
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        often_sampled_idxes = list(
            filter(lambda x: self._hit_count[x] >= evict_sampled_more_then, set(idxes))
        )

        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count

        # Delete from the highest index down so the remaining indices stay valid
        # while items are removed.
        for idx in sorted(often_sampled_idxes, reverse=True):
            del self._storage[idx]
            self._hit_count = np.append(
                self._hit_count[:idx], self._hit_count[idx + 1 :]
            )

        return sample


config = (
    DQNConfig()
    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
    .environment(env="CartPole-v0")
)

tune.run(
    "DQN",
    config=config.to_dict(),
    stop={"training_iteration": 1},
)
# __sphinx_doc_replay_buffer_own_buffer__end__


# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__begin__
# This line makes our buffer store only complete episodes found in a batch.
config.training(replay_buffer_config={"storage_unit": StorageUnit.EPISODES})

less_sampled_buffer = LessSampledReplayBuffer(**config.replay_buffer_config)

# Gather some random experiences.
env = RandomEnv()
done = False
batch = SampleBatch({})
t = 0
while not done:
    obs, reward, done, info = env.step([0, 0])
    # Note that for RLlib to recognize the start and end of an episode,
    # "t" and "dones" have to properly mark the episode's trajectory.
    one_step_batch = SampleBatch(
        {"obs": [obs], "t": [t], "reward": [reward], "dones": [done]}
    )
    batch = SampleBatch.concat_samples([batch, one_step_batch])
    t += 1

less_sampled_buffer.add(batch)

# Once the episode has been sampled nine times, the tenth sample() call evicts it.
for i in range(10):
    assert len(less_sampled_buffer._storage) == 1
    less_sampled_buffer.sample(num_items=1, evict_sampled_more_then=9)

assert len(less_sampled_buffer._storage) == 0
# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__end__
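
# A small usage sketch, not part of the original example: after the eviction
# above emptied the buffer, re-adding the collected batch makes it usable again.
# This assumes `len(buffer)` reports the number of stored items, which
# LessSampledReplayBuffer.sample already relies on via `len(self)`.
less_sampled_buffer.add(batch)
assert len(less_sampled_buffer) == 1
# Sampling with the default evict_sampled_more_then=30 does not evict anything yet.
less_sampled_buffer.sample(num_items=1)
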
# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
config = {
    "env": "CartPole-v1",
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "underlying_replay_buffer_config": {
            "type": LessSampledReplayBuffer,
            # Default call arguments for the sample method of the underlying
            # buffer can be specified here.
            "evict_sampled_more_then": 20,
        },
    },
}

tune.run(
    "DQN", config=config, stop={"episode_reward_mean": 50, "training_iteration": 10}
)
# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__end__
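
# A sketch (not part of the original example) of the same underlying-buffer setup
# expressed through the DQNConfig API used earlier in this file. The name
# `equivalent_config` is purely illustrative.
equivalent_config = (
    DQNConfig()
    .environment(env="CartPole-v1")
    .training(
        replay_buffer_config={
            "type": "MultiAgentReplayBuffer",
            "underlying_replay_buffer_config": {
                "type": LessSampledReplayBuffer,
                "evict_sampled_more_then": 20,
            },
        }
    )
)
# This could then be passed to tune.run via `equivalent_config.to_dict()`,
# analogous to the earlier tune.run calls.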