# Demonstration of RLlib's ReplayBuffer workflow

from typing import Optional
import random

import numpy as np

from ray import air, tune
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.algorithms.dqn.dqn import DQNConfig


# __sphinx_doc_replay_buffer_type_specification__begin__
# The replay buffer type can be given as the class itself ...
config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer}).to_dict()

# ... as the class name ...
another_config = (
    DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"}).to_dict()
)

# ... or as the fully qualified class path.
yet_another_config = (
    DQNConfig()
    .training(
        replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
    )
    .to_dict()
)

validate_buffer_config(config)
validate_buffer_config(another_config)
validate_buffer_config(yet_another_config)

# After validation, all three configs yield the same effective config.
assert config == another_config == yet_another_config
# __sphinx_doc_replay_buffer_type_specification__end__
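
# A small follow-up (sketch): the validated config is a plain dict, so the
# effective buffer settings can be inspected directly. The exact normalized
# value of "type" can differ between Ray versions, which is why the example
# above compares the whole configs rather than a specific value.
print(config["replay_buffer_config"]["type"])
print(config["replay_buffer_config"].get("capacity"))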


# __sphinx_doc_replay_buffer_basic_interaction__begin__
# We choose the FRAGMENTS storage unit because it places no restrictions on the
# batches we add.
buffer = ReplayBuffer(capacity=2, storage_unit=StorageUnit.FRAGMENTS)
dummy_batch = SampleBatch({"a": [1], "b": [2]})
buffer.add(dummy_batch)
buffer.sample(2)
# Because elements can be sampled multiple times, we receive a concatenated
# version of dummy_batch: `{"a": [1, 1], "b": [2, 2]}`.
# __sphinx_doc_replay_buffer_basic_interaction__end__
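
# A short extension of the interaction above (sketch): the base ReplayBuffer's
# capacity is counted in timesteps, so adding more one-step batches than the
# capacity of 2 should overwrite the oldest entries (FIFO) rather than grow the
# buffer.
for i in range(4):
    buffer.add(SampleBatch({"a": [i], "b": [i]}))
assert len(buffer) <= 2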


# __sphinx_doc_replay_buffer_own_buffer__begin__
class LessSampledReplayBuffer(ReplayBuffer):
    @override(ReplayBuffer)
    def sample(
        self, num_items: int, evict_sampled_more_then: int = 30, **kwargs
    ) -> Optional[SampleBatchType]:
        """Evicts experiences sampled at least `evict_sampled_more_then` times."""
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        often_sampled_idxes = list(
            filter(lambda x: self._hit_count[x] >= evict_sampled_more_then, set(idxes))
        )

        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count

        # Delete from the back so earlier deletions do not shift later indices.
        for idx in sorted(often_sampled_idxes, reverse=True):
            del self._storage[idx]
            self._hit_count = np.append(
                self._hit_count[:idx], self._hit_count[idx + 1 :]
            )

        return sample


config = (
    DQNConfig()
    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
    .environment(env="CartPole-v0")
)

tune.Tuner(
    "DQN",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={"training_iteration": 1},
    ),
).fit()
# __sphinx_doc_replay_buffer_own_buffer__end__
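
# A standalone check of the eviction behavior (sketch), independent of the Tune
# run above: with a threshold of 3, the fourth sample() call hits an entry whose
# hit count has reached the threshold and should evict it from storage.
demo_buffer = LessSampledReplayBuffer(capacity=10, storage_unit=StorageUnit.FRAGMENTS)
demo_buffer.add(SampleBatch({"obs": [0], "reward": [0.0]}))
for _ in range(4):
    demo_buffer.sample(num_items=1, evict_sampled_more_then=3)
assert len(demo_buffer) == 0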


# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__begin__
# This makes our buffer store only the complete episodes found in an added batch.
config.training(replay_buffer_config={"storage_unit": StorageUnit.EPISODES})

less_sampled_buffer = LessSampledReplayBuffer(**config.replay_buffer_config)

# Gather some random experiences.
env = RandomEnv()
done = False
batch = SampleBatch({})
t = 0
while not done:
    obs, reward, done, info = env.step([0, 0])
    # Note that for RLlib to recognize the start and end of an episode,
    # "t" and "dones" have to properly mark the episode's trajectory.
    one_step_batch = SampleBatch(
        {"obs": [obs], "t": [t], "reward": [reward], "dones": [done]}
    )
    batch = concat_samples([batch, one_step_batch])
    t += 1

less_sampled_buffer.add(batch)
for _ in range(10):
    assert len(less_sampled_buffer._storage) == 1
    less_sampled_buffer.sample(num_items=1, evict_sampled_more_then=9)

assert len(less_sampled_buffer._storage) == 0
# __sphinx_doc_replay_buffer_advanced_usage_storage_unit__end__
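
# For contrast (sketch): with the default TIMESTEPS storage unit, add() slices
# the same batch into single timesteps, so the buffer should hold one entry per
# timestep instead of one entry per episode (assuming the episode fits within
# the buffer's capacity).
timestep_buffer = LessSampledReplayBuffer(
    capacity=1000, storage_unit=StorageUnit.TIMESTEPS
)
timestep_buffer.add(batch)
print(len(timestep_buffer), batch.count)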


# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__
config = {
    "env": "CartPole-v1",
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "underlying_replay_buffer_config": {
            "type": LessSampledReplayBuffer,
            # Here we can specify the default call argument for the sample
            # method of the underlying buffer.
            "evict_sampled_more_then": 20,
        },
    },
}

tune.Tuner(
    "DQN",
    param_space=config,
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 50, "training_iteration": 10}
    ),
).fit()
# __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__end__
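
# Mirroring the type-specification section at the top of this file (sketch): the
# outer buffer type can also be given as the class object instead of its name.
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

config["replay_buffer_config"]["type"] = MultiAgentReplayBuffer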