[RLlib] Replay Buffer API and Ape-X. (#24506)

Artur Niederfahrenhorst 2022-05-17 13:43:49 +02:00 committed by GitHub
parent c74886a55e
commit fb2915d26a
100 changed files with 653 additions and 1999 deletions

View file

@ -43,8 +43,10 @@ run_experiments(
"config": {
"num_workers": 3,
"num_gpus": 0,
"buffer_size": 10000,
"learning_starts": 0,
"replay_buffer_config": {
"capacity": 10000,
"learning_starts": 0,
},
"rollout_fragment_length": 1,
"train_batch_size": 1,
"min_iter_time_s": 10,

View file

@ -16,11 +16,12 @@ apex-breakoutnoframeskip-v4:
lr: .0001
adam_epsilon: .00015
hiddens: [512]
buffer_size: 1000000
replay_buffer_config:
capacity: 1000000
prioritized_replay_alpha: 0.5
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

View file

@ -24,11 +24,12 @@ cql-halfcheetahbulletenv-v0:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 256
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

View file

@ -24,18 +24,19 @@ ddpg-hopperbulletenv-v0:
min_sample_timesteps_per_reporting: 1000
target_network_update_freq: 0
tau: 0.001
buffer_size: 10000
prioritized_replay: True
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
replay_buffer_config:
capacity: 10000
type: MultiAgentPrioritizedReplayBuffer
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
learning_starts: 500
clip_rewards: false
actor_lr: 0.001
critic_lr: 0.001
use_huber: true
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 48
num_gpus: 1

View file

@ -18,13 +18,14 @@ dqn-breakoutnoframeskip-v4:
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
replay_buffer_config:
capacity: 1000000
learning_starts: 20000
prioritized_replay_alpha: 0.5
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.5
min_sample_timesteps_per_reporting: 10000

View file

@ -21,11 +21,12 @@ sac-halfcheetahbulletenv-v0:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 10000
replay_buffer_config:
learning_starts: 10000
type: MultiAgentPrioritizedReplayBuffer
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -9,6 +9,7 @@ td3-halfcheetahbulletenv-v0:
time_total_s: 7200
config:
num_gpus: 1
learning_starts: 10000
replay_buffer_config:
learning_starts: 10000
exploration_config:
random_timesteps: 10000

View file

@ -171,9 +171,9 @@ sac-repeat-after-me-env:
repeat_delay: 0
num_gpus: 2
num_workers: 0
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
initial_alpha: 0.001
prioritized_replay: true
# Double batch size (2 GPUs).
train_batch_size: 512
@ -191,11 +191,11 @@ sac-repeat-after-me-env-continuous:
config:
continuous: true
repeat_delay: 0
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
num_gpus: 2
num_workers: 0
initial_alpha: 0.001
prioritized_replay: true
# Double batch size (2 GPUs).
train_batch_size: 512

View file

@ -1030,7 +1030,7 @@ py_test(
"--env", "Pendulum-v1",
"--run", "APEX_DDPG",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"replay_buffer_config\": {\"learning_starts\": 100}, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--ray-num-cpus", "4",
]
)
@ -1087,7 +1087,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)
@ -1099,7 +1099,7 @@ py_test(
"--env", "PongDeterministic-v4",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"replay_buffer_config\": {\"capacity\": 10000, \"learning_starts\": 10000}, \"rollout_fragment_length\": 4, \"target_network_update_freq\": 1000, \"gamma\": 0.99}'"
]
)
@ -1385,27 +1385,6 @@ py_test(
srcs = ["evaluation/tests/test_episode.py"]
)
# --------------------------------------------------------------------
# Optimizers and Memories
# rllib/execution/
#
# Tag: execution
# --------------------------------------------------------------------
py_test(
name = "test_segment_tree",
tags = ["team:ml", "execution"],
size = "small",
srcs = ["execution/tests/test_segment_tree.py"]
)
py_test(
name = "test_prioritized_replay_buffer",
tags = ["team:ml", "execution"],
size = "small",
srcs = ["execution/tests/test_prioritized_replay_buffer.py"]
)
# --------------------------------------------------------------------
# Models and Distributions
# rllib/models/

View file

@ -4,6 +4,9 @@ from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
DDPGConfig().to_dict(), # see also the options in ddpg.py, which are also supported
@ -17,42 +20,50 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"n_step": 3,
"num_gpus": 0,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"replay_buffer_config": {
"capacity": 2000000,
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
"min_sample_timesteps_per_reporting": 25000,
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)
class ApexDDPGTrainer(DDPGTrainer):
class ApexDDPGTrainer(DDPGTrainer, ApexTrainer):
@classmethod
@override(DDPGTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DDPG_DEFAULT_CONFIG
@override(DDPGTrainer)
def setup(self, config: PartialTrainerConfigDict):
return ApexTrainer.setup(self, config)
@override(DDPGTrainer)
def training_iteration(self) -> ResultDict:
"""Use APEX-DQN's training iteration function."""
return ApexTrainer.training_iteration(self)
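A toy sketch of the delegation pattern above: the Ape-X DDPG trainer keeps DDPG's defaults and policies but explicitly forwards `setup()` and `training_iteration()` to the Ape-X implementation via multiple inheritance (class names below are illustrative stand-ins, not RLlib classes):

class BaseAlgo:
    def setup(self, config):
        print("plain single-learner setup")

    def training_iteration(self):
        return {"source": "base"}


class ApexExecution:
    def setup(self, config):
        print("Ape-X setup: replay shards, learner thread, async sampling")

    def training_iteration(self):
        return {"source": "apex"}


class ApexBaseAlgo(BaseAlgo, ApexExecution):
    # Mirror ApexDDPGTrainer: keep BaseAlgo's interface but run the
    # Ape-X style setup and training loop.
    def setup(self, config):
        return ApexExecution.setup(self, config)

    def training_iteration(self):
        return ApexExecution.training_iteration(self)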
@staticmethod
@override(DDPGTrainer)
def execution_plan(

View file

@ -7,6 +7,7 @@ from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import Deprecated
logger = logging.getLogger(__name__)
@ -99,9 +100,11 @@ class DDPGConfig(SimpleQConfig):
# Common DDPG buffer parameters.
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
@ -110,6 +113,8 @@ class DDPGConfig(SimpleQConfig):
"prioritized_replay_eps": 1e-6,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
# .training()
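Put together, the resulting default DDPG buffer settings look roughly like the dict below; note that prioritization is now chosen by picking a buffer type that supports it rather than via a separate `prioritized_replay` flag (values copied from the defaults shown above):

replay_buffer_config = {
    "type": "MultiAgentPrioritizedReplayBuffer",  # prioritized by buffer type
    "capacity": 50000,
    "prioritized_replay_alpha": 0.6,  # how strongly priorities skew sampling
    "prioritized_replay_beta": 0.4,   # importance-sampling correction strength
    "prioritized_replay_eps": 1e-6,   # added to TD errors when updating priorities
    "learning_starts": 1500,
    "worker_side_prioritization": False,
}

# Uniform replay is selected by switching the buffer type, not a flag:
uniform_replay_buffer_config = {
    "type": "MultiAgentReplayBuffer",
    "capacity": 50000,
    "learning_starts": 1500,
}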

View file

@ -37,7 +37,6 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
},
# other changes & things we want to keep fixed:
# larger actor learning rate, no l2 regularisation, no Huber loss, etc.
"learning_starts": 10000,
"actor_hiddens": [400, 300],
"critic_hiddens": [400, 300],
"n_step": 1,
@ -52,15 +51,16 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"target_network_update_freq": 0,
"num_workers": 0,
"num_gpus_per_worker": 0,
"worker_side_prioritization": False,
"clip_rewards": False,
"use_state_preprocessor": False,
# Size of the replay buffer (in time steps).
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
},
},
)

View file

@ -23,7 +23,7 @@ class TestApexDDPG(unittest.TestCase):
config = apex_ddpg.APEX_DDPG_DEFAULT_CONFIG.copy()
config["num_workers"] = 2
config["min_sample_timesteps_per_reporting"] = 100
config["learning_starts"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
num_iterations = 1
for _ in framework_iterator(config, with_eager_tracing=True):

View file

@ -1,6 +1,7 @@
import numpy as np
import re
import unittest
from tempfile import TemporaryDirectory
import ray
import ray.rllib.agents.ddpg as ddpg
@ -61,6 +62,23 @@ class TestDDPG(unittest.TestCase):
check(a, 500)
trainer.stop()
def test_ddpg_checkpoint_save_and_restore(self):
"""Test whether a DDPGTrainer can save and load checkpoints."""
config = ddpg.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_envs_per_worker"] = 2
config["replay_buffer_config"]["learning_starts"] = 0
config["exploration_config"]["random_timesteps"] = 100
# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
trainer.train()
with TemporaryDirectory() as temp_dir:
checkpoint = trainer.save(temp_dir)
trainer.restore(checkpoint)
trainer.stop()
def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""
@ -131,7 +149,6 @@ class TestDDPG(unittest.TestCase):
# Run locally.
config.seed = 42
config.num_workers = 0
config.learning_starts = 0
config.twin_q = True
config.use_huber = True
config.huber_threshold = 1.0
@ -139,9 +156,9 @@ class TestDDPG(unittest.TestCase):
# Make this small (seems to introduce errors).
config.l2_reg = 1e-10
config.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"learning_starts": 0,
}
# Use very simple nets.
config.actor_hiddens = [10]

View file

@ -23,7 +23,6 @@ from ray.actor import ActorHandle
from ray.rllib import RolloutWorker
from ray.rllib.agents import Trainer
from ray.rllib.agents.dqn.dqn import (
calculate_rr_weights,
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
DQNTrainer,
)
@ -35,16 +34,10 @@ from ray.rllib.execution.common import (
_get_global_vars,
_get_shared_metrics,
)
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import ReplayActor
from ray.rllib.execution.parallel_requests import (
asynchronous_parallel_requests,
wait_asynchronous_requests,
)
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.annotations import override
@ -59,7 +52,6 @@ from ray.rllib.utils.metrics import (
SYNCH_WORKER_WEIGHTS_TIMER,
TARGET_NET_UPDATE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.typing import (
SampleBatchType,
TrainerConfigDict,
@ -69,7 +61,7 @@ from ray.rllib.utils.typing import (
)
from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.iter import LocalIterator
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
# fmt: off
# __sphinx_doc_begin__
@ -90,29 +82,32 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"no_local_replay_buffer": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
@ -125,7 +120,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
@ -188,26 +182,27 @@ class ApexTrainer(DQNTrainer):
]
num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (
# Create copy here so that we can modify without breaking other logic
replay_actor_config = copy.deepcopy(self.config["replay_buffer_config"])
replay_actor_config["capacity"] = (
self.config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
replay_actor_args = [
num_replay_buffer_shards,
self.config["learning_starts"],
buffer_size,
self.config["train_batch_size"],
self.config["replay_buffer_config"]["prioritized_replay_alpha"],
self.config["replay_buffer_config"]["prioritized_replay_beta"],
self.config["replay_buffer_config"]["prioritized_replay_eps"],
self.config["multiagent"]["replay_mode"],
self.config["replay_buffer_config"].get("replay_sequence_length", 1),
]
ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if self.config["replay_buffer_shards_colocated_with_driver"]:
if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
self.replay_actors = create_colocated_actors(
actor_specs=[ # (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
(
ReplayActor,
None,
replay_actor_config,
num_replay_buffer_shards,
)
],
node=platform.node(), # localhost
)[
@ -216,7 +211,7 @@ class ApexTrainer(DQNTrainer):
# Place replay buffer shards on any node(s).
else:
self.replay_actors = [
ReplayActor.remote(*replay_actor_args)
ReplayActor.remote(*replay_actor_config)
for _ in range(num_replay_buffer_shards)
]
self.learner_thread = LearnerThread(self.workers.local_worker())
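A condensed sketch of the shard creation above, assuming the buffer classes from `ray.rllib.utils.replay_buffers` (shard count and sizes are illustrative; the `(class, args, kwargs, count)` actor spec and the `num_cpus=0` wrapping follow the diff):

import platform

import ray
from ray.rllib.utils.actors import create_colocated_actors
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

ray.init(ignore_reinit_error=True)

num_shards = 4
shard_config = {
    "capacity": 2_000_000 // num_shards,  # each shard holds a slice of the total
    "learning_starts": 50_000,
}

# Each shard is the buffer class wrapped as a zero-CPU Ray actor.
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)

# Co-locate all shards with the learner (driver) node, as in `setup()` above.
replay_actors = create_colocated_actors(
    actor_specs=[(ReplayActor, None, shard_config, num_shards)],
    node=platform.node(),  # localhost
)[0]  # [0] = only one entry in `actor_specs`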
@ -273,151 +268,6 @@ class ApexTrainer(DQNTrainer):
return copy.deepcopy(self.learner_thread.learner_info)
@staticmethod
@override(DQNTrainer)
def execution_plan(
workers: WorkerSet, config: dict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "Apex execution_plan does NOT take any additional parameters"
# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (
config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
replay_actor_args = [
num_replay_buffer_shards,
config["learning_starts"],
buffer_size,
config["train_batch_size"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config["replay_buffer_config"].get("replay_sequence_length", 1),
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if config["replay_buffer_shards_colocated_with_driver"]:
replay_actors = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
],
node=platform.node(), # localhost
)[
0
] # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
replay_actors = [
ReplayActor(*replay_actor_args) for _ in range(num_replay_buffer_shards)
]
# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
learner_thread.start()
# Update experience priorities post learning.
def update_prio_and_stats(item: Tuple[ActorHandle, dict, int, int]) -> None:
actor, prio_dict, env_count, agent_count = item
if config["replay_buffer_config"].get("prioritized_replay_alpha") > 0:
actor.update_priorities.remote(prio_dict)
metrics = _get_shared_metrics()
# Manually update the steps trained counter since the learner
# thread is executing outside the pipeline.
metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
metrics.counters[STEPS_TRAINED_COUNTER] += env_count
metrics.timers["learner_dequeue"] = learner_thread.queue_timer
metrics.timers["learner_grad"] = learner_thread.grad_timer
metrics.timers["learner_overall"] = learner_thread.overall_timer
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in one of our replay buffer
# actors. Update the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
# Only need to update workers if there are remote workers.
if workers.remote_workers():
store_op = store_op.zip_with_source_actor().for_each(
UpdateWorkerWeights(
learner_thread,
workers,
max_weight_sync_delay=(
config["optimizer"]["max_weight_sync_delay"]
),
)
)
# (2) Read experiences from one of the replay buffer actors and send
# to the learner thread via its in-queue.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
replay_op = (
Replay(actors=replay_actors, num_async=4)
.for_each(lambda x: post_fn(x, workers, config))
.zip_with_source_actor()
.for_each(Enqueue(learner_thread.inqueue))
)
# (3) Get priorities back from learner thread and apply them to the
# replay buffer actors.
update_op = (
Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
.for_each(update_prio_and_stats)
.for_each(
UpdateTargetNetwork(
workers, config["target_network_update_freq"], by_steps_trained=True
)
)
)
if config["training_intensity"]:
# Execute (1), (2) with a fixed intensity ratio.
rr_weights = calculate_rr_weights(config) + ["*"]
merged_op = Concurrently(
[store_op, replay_op, update_op],
mode="round_robin",
output_indexes=[2],
round_robin_weights=rr_weights,
)
else:
# Execute (1), (2), (3) asynchronously as fast as possible. Only
# output items from (3) since metrics aren't available before
# then.
merged_op = Concurrently(
[store_op, replay_op, update_op], mode="async", output_indexes=[2]
)
# Add in extra replay and learner metrics to the training result.
def add_apex_metrics(result: dict) -> dict:
replay_stats = ray.get(
replay_actors[0].stats.remote(config["optimizer"].get("debug"))
)
exploration_infos = workers.foreach_policy_to_train(
lambda p, _: p.get_exploration_state()
)
result["info"].update(
{
"exploration_infos": exploration_infos,
"learner_queue": learner_thread.learner_queue_size.stats(),
LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
"replay_shard_0": replay_stats,
}
)
return result
# Only report metrics from the workers with the lowest 1/3 of
# epsilons.
selected_workers = workers.remote_workers()[
-len(workers.remote_workers()) // 3 :
]
return StandardMetricsReporting(
merged_op, workers, config, selected_workers=selected_workers
).for_each(add_apex_metrics)
def get_samples_and_store_to_replay_buffers(self):
# in the case the num_workers = 0
if not self.workers.remote_workers():
@ -425,7 +275,7 @@ class ApexTrainer(DQNTrainer):
local_sampling_worker = self.workers.local_worker()
batch = local_sampling_worker.sample()
actor = random.choice(self.replay_actors)
ray.get(actor.add_batch.remote(batch))
ray.get(actor.add.remote(batch))
batch_statistics = {
local_sampling_worker: [
{
@ -437,7 +287,7 @@ class ApexTrainer(DQNTrainer):
return batch_statistics
def remote_worker_sample_and_store(
worker: RolloutWorker, replay_actors: List[ReplayActor]
worker: RolloutWorker, replay_actors: List[ActorHandle]
):
# This function is run as a remote function on sampling workers,
# and should only be used with the RolloutWorker's apply function ever.
@ -447,7 +297,7 @@ class ApexTrainer(DQNTrainer):
# operation on there.
_batch = worker.sample()
_actor = random.choice(replay_actors)
_actor.add_batch.remote(_batch)
_actor.add.remote(_batch)
_batch_statistics = {
"agent_steps": _batch.agent_steps(),
"env_steps": _batch.env_steps(),
@ -550,7 +400,8 @@ class ApexTrainer(DQNTrainer):
actors=[rand_actor],
ray_wait_timeout_s=0.1,
max_remote_requests_in_flight_per_actor=num_requests_to_launch,
remote_fn=lambda actor: actor.replay(),
remote_args=[[self.config["train_batch_size"]]],
remote_fn=lambda actor, num_items: actor.sample(num_items),
)
for replay_actor, sample_batches in replay_samples_ready.items():
for sample_batch in sample_batches:
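At runtime the interaction with a shard reduces to the two calls used above: `add.remote()` on the producer side and `sample(num_items)` on the consumer side. A small sketch with an illustrative toy batch:

import random

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

ray.init(ignore_reinit_error=True)

ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
replay_actors = [
    ReplayActor.remote(capacity=1000, learning_starts=0) for _ in range(2)
]

# Producer (sampling worker): push a batch into a randomly chosen shard.
batch = SampleBatch({"obs": [0, 1, 2], "actions": [0, 1, 0], "rewards": [1.0, 0.0, 1.0]})
actor = random.choice(replay_actors)
ray.get(actor.add.remote(batch))

# Consumer (learner): pull a train batch of `num_items` timesteps from a shard.
train_batch = ray.get(actor.sample.remote(2))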

View file

@ -46,6 +46,7 @@ from ray.rllib.execution.common import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__)
@ -133,17 +134,26 @@ class DQNConfig(SimpleQConfig):
self.n_step = 1
self.before_learn_on_batch = None
self.training_intensity = None
self.worker_side_prioritization = False
# Changes to SimpleQConfig default
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer. Note that if async_updates is set,
# then each worker will have a replay buffer of this size.
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
# fmt: on
# __sphinx_doc_end__

View file

@ -442,7 +442,9 @@ def postprocess_nstep_and_prio(
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
# Prioritize on the worker side.
if batch.count > 0 and policy.config["worker_side_prioritization"]:
if batch.count > 0 and policy.config["replay_buffer_config"].get(
"worker_side_prioritization", False
):
td_errors = policy.compute_td_error(
batch[SampleBatch.OBS],
batch[SampleBatch.ACTIONS],
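Roughly how worker-side prioritization fits together after the change: the flag now lives in `replay_buffer_config`, and the priorities sent back to the buffer are the absolute TD errors plus a small epsilon (a standalone sketch; `td_errors` stands in for the output of `policy.compute_td_error()`):

import numpy as np


def compute_new_priorities(td_errors, prioritized_replay_eps=1e-6):
    # Same rule the prioritized buffers use when updating priorities.
    return np.abs(td_errors) + prioritized_replay_eps


def maybe_prioritize(policy_config, td_errors):
    # The flag moved from the top level of the config into
    # `replay_buffer_config`.
    if policy_config.get("replay_buffer_config", {}).get(
        "worker_side_prioritization", False
    ):
        return compute_new_priorities(np.asarray(td_errors))
    # Otherwise fall back to uniform weights.
    return np.ones_like(np.asarray(td_errors, dtype=float))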

View file

@ -32,8 +32,10 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# === Replay buffer ===
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer (in sequences, not timesteps).
"capacity": 100000,
"storage_unit": "sequences",

View file

@ -108,11 +108,13 @@ class SimpleQConfig(TrainerConfig):
# __sphinx_doc_begin__
self.target_network_update_freq = 500
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
}
self.store_buffer_in_checkpoints = False
@ -151,7 +153,8 @@ class SimpleQConfig(TrainerConfig):
self.prioritized_replay = DEPRECATED_VALUE
self.learning_starts = DEPRECATED_VALUE
self.replay_batch_size = DEPRECATED_VALUE
self.replay_sequence_length = DEPRECATED_VALUE
# Can not use DEPRECATED_VALUE here because -1 is a common config value
self.replay_sequence_length = None
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
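The reason `replay_sequence_length` is reset to `None` rather than `DEPRECATED_VALUE` above: the sentinel is -1 (a later hunk in this diff notes that -1 equals `DEPRECATED_VALUE`), which collides with a perfectly legal user setting for this key, so a distinct marker is needed. A tiny sketch of the collision:

DEPRECATED_VALUE = -1  # the sentinel value, per the comment later in this diff

# -1 is also a legitimate user choice for `replay_sequence_length`
# (it appears in the allowed-values check in the RNNSAC hunk below),
# so it cannot double as the "user never touched this" marker ...
user_setting = -1
assert user_setting == DEPRECATED_VALUE  # indistinguishable

# ... whereas None is unambiguous:
replay_sequence_length = None
was_explicitly_set = replay_sequence_length is not None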

View file

@ -24,8 +24,9 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
@ -41,8 +42,9 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
@ -78,14 +80,12 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_gpus"] = 0
config["learning_starts"] = 10
config["train_batch_size"] = 10
config["rollout_fragment_length"] = 5
config["replay_buffer_config"] = {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"no_local_replay_buffer": True,
"type": "MultiAgentReplayBuffer",
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,

View file

@ -3,6 +3,7 @@ import unittest
import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,
@ -13,6 +14,28 @@ tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
def check_batch_sizes(train_results):
"""Check if batch sizes are according to what we expect from config."""
info = train_results["info"]
learner_info = info[LEARNER_INFO]
for pid, policy_stats in learner_info.items():
if pid == "batch_count":
continue
# Expect td-errors to be per batch-item.
configured_b = train_results["config"]["train_batch_size"]
actual_b = policy_stats["td_error"].shape[0]
if (configured_b - actual_b) / actual_b > 0.1:
assert (
configured_b
/ (
train_results["config"]["model"]["max_seq_len"]
+ train_results["config"]["replay_buffer_config"]["replay_burn_in"]
)
== actual_b
)
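A worked instance of the check above with illustrative numbers: when replay happens in sequences, the TD errors are per sequence, so the configured per-timestep batch size divides by the padded sequence length (`max_seq_len + replay_burn_in`):

configured_b = 1000   # train_batch_size in timesteps (illustrative)
max_seq_len = 20      # model's maximum sequence length (illustrative)
replay_burn_in = 5    # burn-in prepended to each replayed sequence (illustrative)

# Expected number of sequences (and hence td_error entries) per train batch:
actual_b = configured_b // (max_seq_len + replay_burn_in)
assert actual_b == 40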
class TestR2D2(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
@ -47,6 +70,7 @@ class TestR2D2(unittest.TestCase):
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
check_batch_sizes(results)
print(results)
check_compute_single_action(trainer, include_state=True)

View file

@ -19,8 +19,8 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@ -69,12 +69,14 @@ DEFAULT_CONFIG = with_common_config({
"adv_policy": "maddpg",
# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
},
# Observation compression. Note that compression makes simulation slow in
# MPE.
@ -102,8 +104,6 @@ DEFAULT_CONFIG = with_common_config({
"actor_feature_reg": 0.001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 0.5,
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 100,

View file

@ -21,6 +21,7 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
class QMixConfig(SimpleQConfig):
@ -78,12 +79,15 @@ class QMixConfig(SimpleQConfig):
self.train_batch_size = 32
self.target_network_update_freq = 500
self.replay_buffer_config = {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
self.model = {
"lstm_cell_size": 64,

View file

@ -58,7 +58,8 @@ class RNNSACTrainer(SACTrainer):
)
# Check if user tries to set replay_sequence_length (to anything
# other than the proper value)
if config["replay_buffer_config"]["replay_sequence_length"] not in [
if config["replay_buffer_config"].get("replay_sequence_length", None) not in [
None,
-1,
replay_sequence_length,
]:

View file

@ -84,22 +84,23 @@ DEFAULT_CONFIG = with_common_config({
# === Replay buffer ===
"replay_buffer_config": {
# Enable the new ReplayBuffer API.
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.
@ -155,8 +156,6 @@ DEFAULT_CONFIG = with_common_config({
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
"num_cpus_per_worker": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
@ -167,18 +166,6 @@ DEFAULT_CONFIG = with_common_config({
# Use a Beta-distribution instead of a SquashedGaussian for bounded,
# continuous action spaces (not recommended, for debugging only).
"_use_beta_distribution": False,
# Deprecated.
# The following values have moved because of the new ReplayBuffer API.
"prioritized_replay": DEPRECATED_VALUE,
"prioritized_replay_alpha": DEPRECATED_VALUE,
"prioritized_replay_beta": DEPRECATED_VALUE,
"prioritized_replay_eps": DEPRECATED_VALUE,
"learning_starts": DEPRECATED_VALUE,
"buffer_size": DEPRECATED_VALUE,
"replay_batch_size": DEPRECATED_VALUE,
"replay_sequence_length": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# fmt: on

View file

@ -42,11 +42,12 @@ class TestRNNSAC(unittest.TestCase):
"lstm_use_prev_reward": True,
}
# Test with PR activated.
config["prioritized_replay"] = True
config["burn_in"] = 20
config["zero_init_states"] = True
# Test with MultiAgentPrioritizedReplayBuffer
config["replay_buffer_config"] = {
"type": "MultiAgentPrioritizedReplayBuffer",
"replay_burn_in": 20,
"zero_init_states": True,
}
config["lr"] = 5e-4

View file

@ -523,7 +523,7 @@ class TestSAC(unittest.TestCase):
config = sac.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Run locally.
config["learning_starts"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["rollout_fragment_length"] = 5
config["train_batch_size"] = 5
config["replay_buffer_config"]["capacity"] = 10

View file

@ -77,10 +77,7 @@ class SlateQConfig(TrainerConfig):
self.rmsprop_epsilon = 1e-5
self.grad_clip = None
self.n_step = 1
self.worker_side_prioritization = False
self.replay_buffer_config = {
# Enable the new ReplayBuffer API.
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 100000,
"prioritized_replay_alpha": 0.6,
@ -91,6 +88,8 @@ class SlateQConfig(TrainerConfig):
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# How many steps of the model to sample before learning starts.
"learning_starts": 20000,
}

View file

@ -40,9 +40,6 @@ from ray.rllib.evaluation.metrics import (
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import (
@ -2127,7 +2124,7 @@ class Trainer(Trainable):
@DeveloperAPI
def _create_local_replay_buffer_if_necessary(
self, config: PartialTrainerConfigDict
) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]:
) -> Optional[MultiAgentReplayBuffer]:
"""Create a MultiAgentReplayBuffer instance if necessary.
Args:
@ -2138,7 +2135,7 @@ class Trainer(Trainable):
None, if local replay buffer is not needed.
"""
if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
"no_local_replay_buffer" or config.get("no_local_replay_buffer"), False
"no_local_replay_buffer" or config.get("no_local_replay_buffer")
):
return

View file

@ -226,6 +226,16 @@ class TrainerConfig:
self.timesteps_per_iteration = DEPRECATED_VALUE
self.min_iter_time_s = DEPRECATED_VALUE
self.collect_metrics_timeout = DEPRECATED_VALUE
# The following values have moved because of the new ReplayBuffer API
self.buffer_size = DEPRECATED_VALUE
self.prioritized_replay = DEPRECATED_VALUE
self.learning_starts = DEPRECATED_VALUE
self.replay_batch_size = DEPRECATED_VALUE
# -1 = DEPRECATED_VALUE is a valid value for replay_sequence_length
self.replay_sequence_length = None
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
def to_dict(self) -> TrainerConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.

View file

@ -51,21 +51,12 @@ CQL_DEFAULT_CONFIG = merge_dicts(
"lagrangian_thresh": 5.0,
# Min Q weight multiplier.
"min_q_weight": 5.0,
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
# Replay buffer should be larger or equal the size of the offline
# dataset.
"capacity": int(1e6),
},
# Reporting: As CQL is offline (no sampling steps), we need to limit
# `self.train()` reporting by the number of steps trained (not sampled).
"min_sample_timesteps_per_reporting": 0,
"min_train_timesteps_per_reporting": 100,
# Deprecated keys.
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,
# Use `min_sample_timesteps_per_reporting` and
# `min_train_timesteps_per_reporting` instead.
"timesteps_per_iteration": DEPRECATED_VALUE,

View file

@ -2,7 +2,7 @@ from typing import Type
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)
@ -18,10 +18,10 @@ from ray.rllib.utils.metrics import (
WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
# fmt: off
# __sphinx_doc_begin__
@ -66,14 +66,22 @@ DEFAULT_CONFIG = with_common_config({
# Number of (independent) timesteps pushed through the loss
# each SGD round.
"train_batch_size": 2000,
# Size of the replay buffer in (single and independent) timesteps.
# The buffer gets filled by reading from the input files line-by-line
# and adding all timesteps on one line at once. We then sample
# uniformly from the buffer (`train_batch_size` samples) for
# each training step.
"replay_buffer_size": 10000,
# Number of steps to read before learning starts.
"learning_starts": 0,
"replay_buffer_config": {
"type": "MultiAgentPrioritizedReplayBuffer",
# Size of the replay buffer in (single and independent) timesteps.
# The buffer gets filled by reading from the input files line-by-line
# and adding all timesteps on one line at once. We then sample
# uniformly from the buffer (`train_batch_size` samples) for
# each training step.
"capacity": 10000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
# Number of steps to read before learning starts.
"learning_starts": 0,
"replay_sequence_length": 1
},
# A coeff to encourage higher action distribution entropy for exploration.
"bc_logstd_coeff": 0.0,
@ -96,6 +104,8 @@ class MARWILTrainer(Trainer):
# Call super's validation method.
super().validate_config(config)
validate_buffer_config(config)
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")
@ -116,19 +126,6 @@ class MARWILTrainer(Trainer):
else:
return MARWILTFPolicy
@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)
# `training_iteration` implementation: Setup buffer in `setup`, not
# in `execution_plan` (deprecated).
if self.config["_disable_execution_plan_api"] is True:
self.local_replay_buffer = MultiAgentReplayBuffer(
learning_starts=self.config["learning_starts"],
capacity=self.config["replay_buffer_size"],
replay_batch_size=self.config["train_batch_size"],
replay_sequence_length=1,
)
@override(Trainer)
def training_iteration(self) -> ResultDict:
# Collect SampleBatches from sample workers.
@ -137,10 +134,10 @@ class MARWILTrainer(Trainer):
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
# Add batch to replay buffer.
self.local_replay_buffer.add_batch(batch)
self.local_replay_buffer.add(batch)
# Pull batch from replay buffer and train on it.
train_batch = self.local_replay_buffer.replay()
train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"])
# Train.
if self.config["simple_optimizer"]:
train_results = train_one_step(self, train_batch)

View file

@ -116,7 +116,7 @@ if __name__ == "__main__":
assert r["model"]["foo"] == 42, result
if args.run == "DQN":
extra_config = {"learning_starts": 0}
extra_config = {"replay_buffer_config": {"learning_starts": 0}}
else:
extra_config = {}

View file

@ -22,15 +22,17 @@ if __name__ == "__main__":
"num_gpus": 1,
"num_workers": 2,
"num_envs_per_worker": 8,
"learning_starts": 1000,
"buffer_size": int(1e5),
"replay_buffer_config": {
"learning_starts": 1000,
"capacity": int(1e5),
"prioritized_replay_alpha": 0.5,
},
"compress_observations": True,
"rollout_fragment_length": 20,
"train_batch_size": 512,
"gamma": 0.99,
"n_step": 3,
"lr": 0.0001,
"prioritized_replay_alpha": 0.5,
"target_network_update_freq": 50000,
"min_sample_timesteps_per_reporting": 25000,
# Method specific.

View file

@ -38,9 +38,11 @@ if __name__ == "__main__":
config["bc_iters"] = 0
config["clip_actions"] = False
config["normalize_actions"] = True
config["learning_starts"] = 256
config["replay_buffer_config"]["learning_starts"] = 256
config["rollout_fragment_length"] = 1
config["prioritized_replay"] = False
# Test without prioritized replay
config["replay_buffer_config"]["type"] = "MultiAgentReplayBuffer"
config["replay_buffer_config"]["capacity"] = int(1e6)
config["tau"] = 0.005
config["target_entropy"] = "auto"
config["Q_model"] = {

View file

@ -126,7 +126,9 @@ def main():
"num_gpus": args.num_gpus,
"num_workers": args.num_workers,
"env_config": env_config,
"learning_starts": args.learning_starts,
"replay_buffer_config": {
"learning_starts": args.learning_starts,
},
}
# Perform a test run on the env with a random agent to see, what

View file

@ -180,7 +180,7 @@ if __name__ == "__main__":
# Example of using DQN (supports off-policy actions).
config.update(
{
"learning_starts": 100,
"replay_buffer_config": {"learning_starts": 100},
"min_sample_timesteps_per_reporting": 200,
"n_step": 3,
"rollout_fragment_length": 4,

View file

@ -115,7 +115,7 @@ if __name__ == "__main__":
"env_config": {
"actions_are_logits": True,
},
"learning_starts": 100,
"replay_buffer_config": {"learning_starts": 100},
"multiagent": {
"policies": {
"pol1": PolicySpec(

View file

@ -21,7 +21,9 @@ from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
)
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.annotations import override
@ -82,7 +84,7 @@ class MyTrainer(Trainer):
super().setup(config)
# Create local replay buffer.
self.local_replay_buffer = MultiAgentReplayBuffer(
num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
num_shards=1, learning_starts=1000, capacity=50000
)
@override(Trainer)
@ -103,14 +105,14 @@ class MyTrainer(Trainer):
self._counters[NUM_AGENT_STEPS_SAMPLED] += ma_batch.agent_steps()
ppo_batch = ma_batch.policy_batches.pop("ppo_policy")
# Add collected batches (only for DQN policy) to replay buffer.
self.local_replay_buffer.add_batch(ma_batch)
self.local_replay_buffer.add(ma_batch)
ppo_batches.append(ppo_batch)
num_env_steps += ppo_batch.count
# DQN sub-flow.
dqn_train_results = {}
dqn_train_batch = self.local_replay_buffer.replay()
dqn_train_batch = self.local_replay_buffer.sample(num_items=64)
if dqn_train_batch is not None:
dqn_train_results = train_one_step(self, dqn_train_batch, ["dqn_policy"])
self._counters["agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
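A self-contained sketch of the same local-buffer flow with the new API, `add()` to store and `sample(num_items)` to draw a train batch (the toy batches are illustrative):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

buffer = MultiAgentReplayBuffer(num_shards=1, learning_starts=0, capacity=50000)

# Store experience; SampleBatch input is handled as single-agent data.
for step in range(100):
    buffer.add(SampleBatch({"obs": [step], "actions": [0], "rewards": [1.0]}))

# Draw a training batch of 64 items (the size is chosen per call now).
train_batch = buffer.sample(num_items=64)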

View file

@ -7,11 +7,7 @@ from ray.rllib.execution.metric_ops import (
OncePerTimestepsElapsed,
)
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.buffers.replay_buffer import (
ReplayBuffer,
PrioritizedReplayBuffer,
)
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.replay_ops import (
StoreToReplayBuffer,
Replay,
@ -50,14 +46,11 @@ __all__ = [
"Enqueue",
"LearnerThread",
"MixInReplay",
"MultiAgentReplayBuffer",
"MultiGPULearnerThread",
"OncePerTimeInterval",
"OncePerTimestepsElapsed",
"ParallelRollouts",
"PrioritizedReplayBuffer",
"Replay",
"ReplayBuffer",
"SelectExperiences",
"SimpleReplayBuffer",
"StandardMetricsReporting",
@ -66,4 +59,5 @@ __all__ = [
"TrainOneStep",
"MultiGPUTrainOneStep",
"UpdateTargetNetwork",
"MinibatchBuffer",
]

View file

@ -1,5 +1,9 @@
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.deprecation import deprecation_warning
__all__ = [
"MinibatchBuffer",
]
deprecation_warning(
old="ray.rllib.execution.buffers",
new="ray.rllib.utils.replay_buffers",
help="RLlib's ReplayBuffer API has changed. Apart from the replay buffers moving, "
"some have altered behaviour. Please refer to the docs for more information.",
error=False,
)
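The practical consequence of the move the warning describes is an import-path change: the buffer modules under `ray.rllib.execution.buffers` are removed by this commit, and the buffer classes now live under `ray.rllib.utils.replay_buffers` (only imports that appear elsewhere in this diff are shown):

# Previously (modules removed by this commit):
#   from ray.rllib.execution.buffers.replay_buffer import (
#       ReplayBuffer,
#       PrioritizedReplayBuffer,
#   )
#   from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
#       MultiAgentReplayBuffer,
#   )

# Now:
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer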

View file

@ -1,303 +0,0 @@
import collections
import platform
from typing import Any, Dict, Optional, Callable
import numpy as np
import ray
from ray.rllib import SampleBatch
from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType, T
from ray.util.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
class MultiAgentReplayBuffer(ParallelIteratorWorker):
"""A replay buffer shard storing data for all policies (in multiagent setup).
Ray actors are single-threaded, so for scalability, multiple replay actors
may be created to increase parallelism."""
def __init__(
self,
num_shards: int = 1,
learning_starts: int = 1000,
capacity: int = 10000,
replay_batch_size: int = 1,
prioritized_replay_alpha: float = 0.6,
prioritized_replay_beta: float = 0.4,
prioritized_replay_eps: float = 1e-6,
replay_mode: str = "independent",
replay_sequence_length: int = 1,
replay_burn_in: int = 0,
replay_zero_init_states: bool = True,
buffer_size=DEPRECATED_VALUE,
):
"""Initializes a MultiAgentReplayBuffer instance.
Args:
num_shards: The number of buffer shards that exist in total
(including this one).
learning_starts: Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
return None).
capacity: The capacity of the buffer. Note that when
`replay_sequence_length` > 1, this is the number of sequences
(not single timesteps) stored.
replay_batch_size: The batch size to be sampled (in timesteps).
Note that if `replay_sequence_length` > 1,
`self.replay_batch_size` will be set to the number of
sequences sampled (B).
prioritized_replay_alpha: Alpha parameter for a prioritized
replay buffer. Use 0.0 for no prioritization.
prioritized_replay_beta: Beta parameter for a prioritized
replay buffer.
prioritized_replay_eps: Epsilon parameter for a prioritized
replay buffer.
replay_mode: One of "independent" or "lockstep". Determines
whether, in the multiagent case, sampling is done across all
agents/policies equally.
replay_sequence_length: The sequence length (T) of a single
sample. If > 1, we will sample B x T from this buffer.
replay_burn_in: The burn-in length in case
`replay_sequence_length` > 0. This is the number of timesteps
each sequence overlaps with the previous one to generate a
better internal state (=state after the burn-in), instead of
starting from 0.0 each RNN rollout.
replay_zero_init_states: Whether the initial states in the
buffer (if replay_sequence_length > 0) are always 0.0 or
should be updated with the previous train_batch state outputs.
"""
# Deprecated args.
if buffer_size != DEPRECATED_VALUE:
deprecation_warning(
"ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False
)
capacity = buffer_size
self.replay_starts = learning_starts // num_shards
self.capacity = capacity // num_shards
self.replay_batch_size = replay_batch_size
self.prioritized_replay_beta = prioritized_replay_beta
self.prioritized_replay_eps = prioritized_replay_eps
self.replay_mode = replay_mode
self.replay_sequence_length = replay_sequence_length
self.replay_burn_in = replay_burn_in
self.replay_zero_init_states = replay_zero_init_states
if replay_sequence_length > 1:
self.replay_batch_size = int(
max(1, replay_batch_size // replay_sequence_length)
)
logger.info(
"Since replay_sequence_length={} and replay_batch_size={}, "
"we will replay {} sequences at a time.".format(
replay_sequence_length, replay_batch_size, self.replay_batch_size
)
)
if replay_mode not in ["lockstep", "independent"]:
raise ValueError("Unsupported replay mode: {}".format(replay_mode))
def gen_replay():
while True:
yield self.replay()
ParallelIteratorWorker.__init__(self, gen_replay, False)
def new_buffer():
if prioritized_replay_alpha == 0.0:
return ReplayBuffer(self.capacity)
else:
return PrioritizedReplayBuffer(
self.capacity, alpha=prioritized_replay_alpha
)
self.replay_buffers = collections.defaultdict(new_buffer)
# Metrics.
self.add_batch_timer = TimerStat()
self.replay_timer = TimerStat()
self.update_priorities_timer = TimerStat()
self.num_added = 0
# Make externally accessible for testing.
global _local_replay_buffer
_local_replay_buffer = self
# If set, return this instead of the usual data for testing.
self._fake_batch = None
@staticmethod
def get_instance_for_testing():
"""Return a MultiAgentReplayBuffer instance that has been previously
instantiated.
Returns:
_local_replay_buffer: The most recently instantiated
MultiAgentReplayBuffer.
"""
global _local_replay_buffer
return _local_replay_buffer
def get_host(self) -> str:
"""Returns the computer's network name.
Returns:
The computer's network name, or an empty string if the network
name could not be determined.
"""
return platform.node()
def add_batch(self, batch: SampleBatchType) -> None:
"""Adds a batch to the appropriate policy's replay buffer.
Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
it is not a MultiAgentBatch.
Args:
batch (SampleBatchType): The batch to be added.
"""
# Make a copy so the replay buffer doesn't pin plasma memory.
batch = batch.copy()
# Handle everything as if multi-agent.
batch = batch.as_multi_agent()
with self.add_batch_timer:
# Lockstep mode: Store under _ALL_POLICIES key (we will always
# only sample from all policies at the same time).
if self.replay_mode == "lockstep":
# Note that prioritization is not supported in this mode.
for s in batch.timeslices(self.replay_sequence_length):
self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
else:
for policy_id, sample_batch in batch.policy_batches.items():
if self.replay_sequence_length == 1:
timeslices = sample_batch.timeslices(1)
else:
timeslices = timeslice_along_seq_lens_with_overlap(
sample_batch=sample_batch,
zero_pad_max_seq_len=self.replay_sequence_length,
pre_overlap=self.replay_burn_in,
zero_init_states=self.replay_zero_init_states,
)
for time_slice in timeslices:
# If SampleBatch has prio-replay weights, average
# over these to use as a weight for the entire
# sequence.
if "weights" in time_slice and len(time_slice["weights"]):
weight = np.mean(time_slice["weights"])
else:
weight = None
self.replay_buffers[policy_id].add(time_slice, weight=weight)
self.num_added += batch.count
# TODO: This entire class will be removed soon. Leave this as a shim in case
# new `training_iteration` methods call the new replay buffer API's `sample()`
# method on this old buffer class here.
def sample(self, num_items=None):
return self.replay()
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
"""If this buffer was given a fake batch, return it, otherwise return
a MultiAgentBatch with samples.
"""
if self._fake_batch:
if not isinstance(self._fake_batch, MultiAgentBatch):
self._fake_batch = SampleBatch(self._fake_batch).as_multi_agent()
return self._fake_batch
if self.num_added < self.replay_starts:
return None
with self.replay_timer:
# Lockstep mode: Sample from all policies at the same time an
# equal amount of steps.
if self.replay_mode == "lockstep":
assert (
policy_id is None
), "`policy_id` specifier not allowed in `locksetp` mode!"
return self.replay_buffers[_ALL_POLICIES].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta
)
elif policy_id is not None:
return self.replay_buffers[policy_id].sample(
self.replay_batch_size, beta=self.prioritized_replay_beta
)
else:
samples = {}
for policy_id, replay_buffer in self.replay_buffers.items():
samples[policy_id] = replay_buffer.sample(
self.replay_batch_size, beta=self.prioritized_replay_beta
)
return MultiAgentBatch(samples, self.replay_batch_size)
def update_priorities(self, prio_dict: Dict) -> None:
"""Updates the priorities of underlying replay buffers.
Computes new priorities from td_errors and prioritized_replay_eps.
These priorities are used to update underlying replay buffers per
policy_id.
Args:
prio_dict (Dict): A dictionary containing td_errors for
batches saved in underlying replay buffers.
"""
with self.update_priorities_timer:
for policy_id, (batch_indexes, td_errors) in prio_dict.items():
new_priorities = np.abs(td_errors) + self.prioritized_replay_eps
self.replay_buffers[policy_id].update_priorities(
batch_indexes, new_priorities
)
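# A hedged sketch of the `prio_dict` layout this method expects; the policy ID,
# indices, and TD errors below are made up, and `buffer` stands for an already
# constructed MultiAgentReplayBuffer instance.
def _example_update_priorities(buffer):
    example_prio_dict = {
        "default_policy": (
            np.array([3, 7, 11]),  # "batch_indexes" from a previous replay() call
            np.array([0.2, -0.5, 0.05]),  # TD errors computed by the learner
        ),
    }
    buffer.update_priorities(example_prio_dict)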
def stats(self, debug: bool = False) -> Dict:
"""Returns the stats of this buffer and all underlying buffers.
Args:
debug (bool): If True, stats of underlying replay buffers will
be fetched with debug=True.
Returns:
stat: Dictionary of buffer stats.
"""
stat = {
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
"update_priorities_time_ms": round(
1000 * self.update_priorities_timer.mean, 3
),
}
for policy_id, replay_buffer in self.replay_buffers.items():
stat.update(
{"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
)
return stat
def get_state(self) -> Dict[str, Any]:
state = {"num_added": self.num_added, "replay_buffers": {}}
for policy_id, replay_buffer in self.replay_buffers.items():
state["replay_buffers"][policy_id] = replay_buffer.get_state()
return state
def set_state(self, state: Dict[str, Any]) -> None:
self.num_added = state["num_added"]
buffer_states = state["replay_buffers"]
for policy_id in buffer_states.keys():
self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
@DeveloperAPI
def apply(
self,
func: Callable[["MultiAgentReplayBuffer", Optional[Any], Optional[Any]], T],
*_args,
**kwargs,
) -> T:
"""Calls the given function with this MultiAgentReplayBuffer instance."""
return func(self, *_args, **kwargs)
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
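# Hedged usage sketch for the ReplayActor defined above. The constructor kwargs
# mirror those used in the execution tests further down in this diff; the exact
# signature of this legacy shard may differ, so treat this as illustrative only.
def _example_replay_actor_usage():
    actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    # Toy one-timestep batch; field names are illustrative.
    actor.add_batch.remote(SampleBatch({"obs": [0], "action": [0], "reward": [0.0]}))
    # Returns None until at least `learning_starts` timesteps have been added.
    return ray.get(actor.replay.remote())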

View file

@ -1,419 +0,0 @@
import logging
import numpy as np
import random
from typing import Any, Dict, List, Optional
# Import ray before psutil will make sure we use psutil's bundled version
import ray # noqa F401
import psutil # noqa E402
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"
logger = logging.getLogger(__name__)
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
"""Warn if the configured replay buffer capacity is too large."""
if log_once("replay_capacity"):
item_size = item.size_bytes()
psutil_mem = psutil.virtual_memory()
total_gb = psutil_mem.total / 1e9
mem_size = num_items * item_size / 1e9
msg = (
"Estimated max memory usage for replay buffer is {} GB "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, num_items, item.count, item_size, total_gb
)
)
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
logger.warning(msg)
else:
logger.info(msg)
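# Back-of-the-envelope version of the estimate above (all numbers illustrative):
example_capacity = 1_000_000  # timesteps the buffer may hold
example_item_count = 32  # timesteps per stored batch
example_item_size = 40_000  # bytes per stored batch (made up)
example_mem_gb = (example_capacity / example_item_count) * example_item_size / 1e9
# -> ~1.25 GB worst case; the check above warns once this exceeds 20% of system
#    memory and raises if it exceeds the total available memory.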
@Deprecated(new="warn_replay_capacity", error=False)
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
return warn_replay_capacity(item=item, num_items=num_items)
@DeveloperAPI
class ReplayBuffer:
@DeveloperAPI
def __init__(self, capacity: int = 10000, size: Optional[int] = DEPRECATED_VALUE):
"""Initializes a ReplayBuffer instance.
Args:
capacity: Max number of timesteps to store in the FIFO
buffer. After reaching this number, older samples will be
dropped to make space for new ones.
"""
# Deprecated args.
if size != DEPRECATED_VALUE:
deprecation_warning(
"ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False
)
capacity = size
# The actual storage (list of SampleBatches).
self._storage = []
self.capacity = capacity
# The next index to override in the buffer.
self._next_idx = 0
self._hit_count = np.zeros(self.capacity)
# Whether we have already hit our capacity (and have therefore
# started to evict older samples).
self._eviction_started = False
# Number of (single) timesteps that have been added to the buffer
# over its lifetime. Note that each added item (batch) may contain
# more than one timestep.
self._num_timesteps_added = 0
self._num_timesteps_added_wrap = 0
# Number of (single) timesteps that have been sampled from the buffer
# over its lifetime.
self._num_timesteps_sampled = 0
self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
self._est_size_bytes = 0
def __len__(self) -> int:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)
@DeveloperAPI
def add(self, item: SampleBatchType, weight: float) -> None:
"""Add a batch of experiences.
Args:
item: SampleBatch to add to this buffer's storage.
weight: The weight of the added sample used in subsequent
sampling steps. Only relevant if this ReplayBuffer is
a PrioritizedReplayBuffer.
"""
assert item.count > 0, item
warn_replay_capacity(item=item, num_items=self.capacity / item.count)
# Update our timesteps counts.
self._num_timesteps_added += item.count
self._num_timesteps_added_wrap += item.count
if self._next_idx >= len(self._storage):
self._storage.append(item)
self._est_size_bytes += item.size_bytes()
else:
self._storage[self._next_idx] = item
# Wrap around storage as a circular buffer once we hit capacity.
if self._num_timesteps_added_wrap >= self.capacity:
self._eviction_started = True
self._num_timesteps_added_wrap = 0
self._next_idx = 0
else:
self._next_idx += 1
# Eviction of older samples has already started (buffer is "full").
if self._eviction_started:
self._evicted_hit_stats.push(self._hit_count[self._next_idx])
self._hit_count[self._next_idx] = 0
@DeveloperAPI
def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
"""Sample a batch of size `num_items` from this buffer.
If less than `num_items` records are in this buffer, some samples in
the results may be repeated to fulfil the batch size (`num_items`)
request.
Args:
num_items: Number of items to sample from this buffer.
beta: The prioritized replay beta value. Only relevant if this
ReplayBuffer is a PrioritizedReplayBuffer.
Returns:
Concatenated batch of items.
"""
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None
idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
sample = self._encode_sample(idxes)
# Update our timesteps counters.
self._num_timesteps_sampled += len(sample)
return sample
@DeveloperAPI
def stats(self, debug: bool = False) -> dict:
"""Returns the stats of this buffer.
Args:
debug: If True, adds sample eviction statistics to the returned
stats dict.
Returns:
A dictionary of stats about this buffer.
"""
data = {
"added_count": self._num_timesteps_added,
"added_count_wrapped": self._num_timesteps_added_wrap,
"eviction_started": self._eviction_started,
"sampled_count": self._num_timesteps_sampled,
"est_size_bytes": self._est_size_bytes,
"num_entries": len(self._storage),
}
if debug:
data.update(self._evicted_hit_stats.stats())
return data
@DeveloperAPI
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
Returns:
The serializable local state.
"""
state = {"_storage": self._storage, "_next_idx": self._next_idx}
state.update(self.stats(debug=False))
return state
@DeveloperAPI
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.
Args:
state: The new state to set this buffer. Can be
obtained by calling `self.get_state()`.
"""
# The actual storage.
self._storage = state["_storage"]
self._next_idx = state["_next_idx"]
# Stats and counts.
self._num_timesteps_added = state["added_count"]
self._num_timesteps_added_wrap = state["added_count_wrapped"]
self._eviction_started = state["eviction_started"]
self._num_timesteps_sampled = state["sampled_count"]
self._est_size_bytes = state["est_size_bytes"]
def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
out.decompress_if_needed()
return out
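# A minimal usage sketch for the (now removed) ReplayBuffer above, loosely
# mirroring the buffer unit tests further down in this diff. The SampleBatch
# fields are illustrative only.
def _example_replay_buffer_usage():
    buf = ReplayBuffer(capacity=100)
    for _ in range(200):
        # One-timestep toy batches; after 200 adds only the last 100 remain,
        # because storage wraps around once capacity is reached.
        buf.add(SampleBatch({"obs": [np.zeros(4)], "reward": [1.0]}), weight=None)
    assert len(buf) == 100
    # Uniform sampling; the result is one concatenated SampleBatch of 32 timesteps.
    sample = buf.sample(32)
    assert sample.count == 32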
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
@DeveloperAPI
def __init__(
self,
capacity: int = 10000,
alpha: float = 1.0,
size: Optional[int] = DEPRECATED_VALUE,
):
"""Initializes a PrioritizedReplayBuffer instance.
Args:
capacity: Max number of timesteps to store in the FIFO
buffer. After reaching this number, older samples will be
dropped to make space for new ones.
alpha: How much prioritization is used
(0.0=no prioritization, 1.0=full prioritization).
"""
super(PrioritizedReplayBuffer, self).__init__(capacity, size)
assert alpha > 0
self._alpha = alpha
it_capacity = 1
while it_capacity < self.capacity:
it_capacity *= 2
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
self._prio_change_stats = WindowStat("reprio", 1000)
@DeveloperAPI
@override(ReplayBuffer)
def add(self, item: SampleBatchType, weight: float) -> None:
"""Add a batch of experiences.
Args:
item: SampleBatch to add to this buffer's storage.
weight: The weight of the added sample used in subsequent sampling
steps.
"""
idx = self._next_idx
super(PrioritizedReplayBuffer, self).add(item, weight)
if weight is None:
weight = self._max_priority
self._it_sum[idx] = weight ** self._alpha
self._it_min[idx] = weight ** self._alpha
def _sample_proportional(self, num_items: int) -> List[int]:
res = []
for _ in range(num_items):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage))
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
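# To illustrate the prefix-sum lookup above with concrete numbers (same
# SumSegmentTree API as exercised in the segment-tree tests further down):
def _example_proportional_sampling():
    tree = SumSegmentTree(4)  # capacity rounded up to a power of two
    tree[2] = 1.0  # priority of item 2
    tree[3] = 3.0  # priority of item 3
    # A uniform "mass" in [0, total priority) maps back to an index: masses
    # below ~1.0 land on item 2, larger masses land on item 3, so item 3 is
    # drawn roughly three times as often as item 2.
    mass = random.random() * tree.sum()
    return tree.find_prefixsum_idx(mass)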
@DeveloperAPI
@override(ReplayBuffer)
def sample(self, num_items: int, beta: float) -> SampleBatchType:
"""Sample `num_items` items from this buffer, including prio. weights.
If less than `num_items` records are in this buffer, some samples in
the results may be repeated to fulfil the batch size (`num_items`)
request.
Args:
num_items: Number of items to sample from this buffer.
beta: To what degree to use importance weights
(0 - no corrections, 1 - full correction).
Returns:
Concatenated batch of items including "weights" and
"batch_indexes" fields denoting the importance-sampling weight
and the original buffer index of each sampled transition.
"""
# If we don't have any samples yet in this buffer, return None.
if len(self) == 0:
return None
assert beta >= 0.0
idxes = self._sample_proportional(num_items)
weights = []
batch_indexes = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self)) ** (-beta)
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self)) ** (-beta)
count = self._storage[idx].count
# If zero-padded, count will not be the actual batch size of the
# data.
if (
isinstance(self._storage[idx], SampleBatch)
and self._storage[idx].zero_padded
):
actual_size = self._storage[idx].max_seq_len
else:
actual_size = count
weights.extend([weight / max_weight] * actual_size)
batch_indexes.extend([idx] * actual_size)
self._num_timesteps_sampled += count
batch = self._encode_sample(idxes)
# Note: prioritization is not supported in lockstep replay mode.
if isinstance(batch, SampleBatch):
batch["weights"] = np.array(weights)
batch["batch_indexes"] = np.array(batch_indexes)
return batch
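# The weight normalization above, spelled out on a toy priority vector
# (standalone arithmetic, no buffer involved; alpha/beta values are made up):
def _example_importance_weights(alpha=0.6, beta=0.4):
    priorities = np.array([0.1, 1.0, 2.0])
    probs = priorities ** alpha / np.sum(priorities ** alpha)
    n = len(priorities)
    max_weight = (probs.min() * n) ** (-beta)  # weight of the rarest item
    weights = (probs * n) ** (-beta) / max_weight  # normalized IS weights in (0, 1]
    assert np.isclose(weights.max(), 1.0)  # the rarest item gets weight 1.0
    return weights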
@DeveloperAPI
def update_priorities(self, idxes: List[int], priorities: List[float]) -> None:
"""Update priorities of sampled transitions.
Sets priority of transition at index idxes[i] in buffer
to priorities[i].
Args:
idxes: List of indices of sampled transitions.
priorities: List of updated priorities corresponding to the
transitions at the sampled indices `idxes`.
"""
# Making sure we don't pass in e.g. a torch tensor.
assert isinstance(
idxes, (list, np.ndarray)
), "ERROR: `idxes` is not a list or np.ndarray, but {}!".format(
type(idxes).__name__
)
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
delta = priority ** self._alpha - self._it_sum[idx]
self._prio_change_stats.push(delta)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
self._max_priority = max(self._max_priority, priority)
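# A short, hedged sketch of the usual sample -> learn -> reprioritize cycle
# (the TD errors and the eps constant are made up for illustration):
def _example_priority_update():
    buf = PrioritizedReplayBuffer(capacity=8, alpha=0.6)
    for _ in range(4):
        buf.add(SampleBatch({"obs": [np.zeros(4)], "reward": [1.0]}), weight=None)
    batch = buf.sample(2, beta=0.4)  # carries "weights" and "batch_indexes"
    td_errors = np.array([0.5, 0.05])  # hypothetical TD errors from the learner
    buf.update_priorities(batch["batch_indexes"], np.abs(td_errors) + 1e-6)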
@DeveloperAPI
@override(ReplayBuffer)
def stats(self, debug: bool = False) -> Dict:
"""Returns the stats of this buffer.
Args:
debug: If true, adds sample eviction statistics to the
returned stats dict.
Returns:
A dictionary of stats about this buffer.
"""
parent = ReplayBuffer.stats(self, debug)
if debug:
parent.update(self._prio_change_stats.stats())
return parent
@DeveloperAPI
@override(ReplayBuffer)
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
Returns:
The serializable local state.
"""
# Get parent state.
state = super().get_state()
# Add prio weights.
state.update(
{
"sum_segment_tree": self._it_sum.get_state(),
"min_segment_tree": self._it_min.get_state(),
"max_priority": self._max_priority,
}
)
return state
@DeveloperAPI
@override(ReplayBuffer)
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.
Args:
state: The new state to set this buffer. Can be obtained by
calling `self.get_state()`.
"""
super().set_state(state)
self._it_sum.set_state(state["sum_segment_tree"])
self._it_min.set_state(state["min_segment_tree"])
self._max_priority = state["max_priority"]
# Visible for testing.
_local_replay_buffer = None

View file

@ -4,7 +4,7 @@ import threading
from typing import Dict, Optional
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.timer import TimerStat

View file

@ -3,7 +3,7 @@ from six.moves import queue
import threading
from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning

View file

@ -4,8 +4,10 @@ import random
from ray.actor import ActorHandle
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
)
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType
@ -21,7 +23,7 @@ class StoreToReplayBuffer:
The batch that was stored is returned.
Examples:
>>> from ray.rllib.execution.buffers import multi_agent_replay_buffer
>>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
>>> from ray.rllib.execution.replay_ops import StoreToReplayBuffer
>>> from ray.rllib.execution import ParallelRollouts
>>> actors = [ # doctest: +SKIP
@ -59,10 +61,10 @@ class StoreToReplayBuffer:
def __call__(self, batch: SampleBatchType):
if self.local_actor is not None:
self.local_actor.add_batch(batch)
self.local_actor.add(batch)
else:
actor = random.choice(self.replay_actors)
actor.add_batch.remote(batch)
actor.add.remote(batch)
return batch
@ -86,7 +88,7 @@ def Replay(
per actor.
Examples:
>>> from ray.rllib.execution.buffers import multi_agent_replay_buffer
>>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
>>> actors = [ # doctest: +SKIP
... multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
>>> replay_op = Replay(actors=actors) # doctest: +SKIP

View file

@ -1,147 +0,0 @@
import numpy as np
import unittest
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
"""Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer."""
capacity = 10
def _generate_data(self):
return SampleBatch(
{
"obs": [np.random.random((4,))],
"action": [np.random.choice([0, 1])],
"reward": [np.random.rand()],
"new_obs": [np.random.random((4,))],
"done": [np.random.choice([False, True])],
}
)
def test_mixin_sampling(self):
# 50% replay ratio.
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.5)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) >= 1)
# If we insert and replay n times, expect roughly return batches of
# len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and 1 old sample
# on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 2.0)
# 33% replay ratio.
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.333)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) >= 1)
# If we insert-2x and replay n times, expect roughly return batches of
# len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and 1 old sample
# on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)
# If we insert-1x and replay n times, expect roughly return batches of
# len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and 0.5 old
# samples on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)
# 90% replay ratio.
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.9)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect at least 2 samples to be returned (new one plus at least one
# replay sample).
sample = buffer.replay()
self.assertTrue(len(sample) >= 2)
# If we insert and replay n times, expect roughly return batches of
# len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
# samples on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)
# 0% replay ratio -> Only new samples.
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.0)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect exactly 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# Expect exactly 0 sample to be returned (nothing new to be returned;
# no replay allowed (replay_ratio=0.0)).
sample = buffer.replay()
self.assertTrue(sample is None)
# If we insert and replay n times, expect roughly return batches of
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
# on average in each returned value).
results = []
for _ in range(100):
buffer.add_batch(batch)
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.0)
# 100% replay ratio -> Only replayed (old) samples.
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=1.0)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.replay()
self.assertTrue(sample is None)
# Add a new batch.
batch = self._generate_data()
buffer.add_batch(batch)
# Expect exactly 1 sample to be returned (the new batch).
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# Another replay -> Expect exactly 1 sample to be returned.
sample = buffer.replay()
self.assertTrue(len(sample) == 1)
# If we replay n times, expect roughly return batches of
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
# on average in each returned value).
results = []
for _ in range(100):
sample = buffer.replay()
results.append(len(sample))
self.assertAlmostEqual(np.mean(results), 1.0)
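# The expected lengths asserted above follow from a simple relation: replayed
# samples make up `replay_ratio` of each returned batch on average, so the
# expected total length is new_samples / (1 - replay_ratio). A quick check:
def _expected_returned_len(new_samples, replay_ratio):
    assert 0.0 <= replay_ratio < 1.0
    return new_samples / (1.0 - replay_ratio)
# _expected_returned_len(1, 0.5) == 2.0    (1 new + 1 replayed)
# _expected_returned_len(2, 1 / 3) ~= 3.0  (2 new + 1 replayed)
# _expected_returned_len(1, 0.9) ~= 10.0   (1 new + 9 replayed)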
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,308 +0,0 @@
from collections import Counter
import numpy as np
import unittest
from ray.rllib.execution.buffers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.test_utils import check
class TestPrioritizedReplayBuffer(unittest.TestCase):
"""
Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer.
"""
capacity = 10
alpha = 1.0
beta = 1.0
max_priority = 1.0
def _generate_data(self):
return SampleBatch(
{
"obs_t": [np.random.random((4,))],
"action": [np.random.choice([0, 1])],
"reward": [np.random.rand()],
"obs_tp1": [np.random.random((4,))],
"done": [np.random.choice([False, True])],
}
)
def test_sequence_size(self):
# Seq-len=1.
memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
for _ in range(200):
memory.add(self._generate_data(), weight=None)
assert len(memory._storage) == 100, len(memory._storage)
assert memory.stats()["added_count"] == 200, memory.stats()
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
new_memory.set_state(state)
assert len(new_memory._storage) == 100, len(new_memory._storage)
assert new_memory.stats()["added_count"] == 200, new_memory.stats()
# Seq-len=5.
memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
for _ in range(40):
memory.add(
SampleBatch.concat_samples([self._generate_data() for _ in range(5)]),
weight=None,
)
assert len(memory._storage) == 20, len(memory._storage)
assert memory.stats()["added_count"] == 200, memory.stats()
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
new_memory.set_state(state)
assert len(new_memory._storage) == 20, len(new_memory._storage)
assert new_memory.stats()["added_count"] == 200, new_memory.stats()
def test_add(self):
memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)
# Assert indices 0 before insert.
self.assertEqual(len(memory), 0)
self.assertEqual(memory._next_idx, 0)
# Insert single record.
data = self._generate_data()
memory.add(data, weight=0.5)
self.assertTrue(len(memory) == 1)
self.assertTrue(memory._next_idx == 1)
# Insert single record.
data = self._generate_data()
memory.add(data, weight=0.1)
self.assertTrue(len(memory) == 2)
self.assertTrue(memory._next_idx == 0)
# Insert over capacity.
data = self._generate_data()
memory.add(data, weight=1.0)
self.assertTrue(len(memory) == 2)
self.assertTrue(memory._next_idx == 1)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)
new_memory.set_state(state)
self.assertTrue(len(new_memory) == 2)
self.assertTrue(new_memory._next_idx == 1)
def test_update_priorities(self):
memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
# Insert n samples.
num_records = 5
for i in range(num_records):
data = self._generate_data()
memory.add(data, weight=1.0)
self.assertTrue(len(memory) == i + 1)
self.assertTrue(memory._next_idx == i + 1)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
self.assertTrue(len(new_memory) == num_records)
self.assertTrue(new_memory._next_idx == num_records)
# Fetch records, their indices and weights.
batch = memory.sample(3, beta=self.beta)
weights = batch["weights"]
indices = batch["batch_indexes"]
check(weights, np.ones(shape=(3,)))
self.assertEqual(3, len(indices))
self.assertTrue(len(memory) == num_records)
self.assertTrue(memory._next_idx == num_records)
# Update weight of indices 0, 2, 3, 4 to very small.
memory.update_priorities(
np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01])
)
# Expect to sample almost only index 1
# (which still has a weight of 1.0).
for _ in range(10):
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(970 < np.sum(indices) < 1100)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(970 < np.sum(indices) < 1100)
# Update weight of indices 0 and 1 to >> 0.01.
# Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s).
for _ in range(10):
rand = np.random.random() + 0.2
memory.update_priorities(np.array([0, 1]), np.array([rand, rand]))
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
# Expect biased to higher values due to some 2s, 3s, and 4s.
self.assertTrue(400 < np.sum(indices) < 800)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(400 < np.sum(indices) < 800)
# Update weights to be 1:2.
# Expect to sample double as often index 1 over index 0
# plus very few times indices 2, 3, or 4.
for _ in range(10):
rand = np.random.random() + 0.2
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 2]))
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
# print(np.sum(indices))
self.assertTrue(600 < np.sum(indices) < 850)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(600 < np.sum(indices) < 850)
# Update weights to be 1:4.
# Expect to sample quadruple as often index 1 over index 0
# plus very few times indices 2, 3, or 4.
for _ in range(10):
rand = np.random.random() + 0.2
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 4]))
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(750 < np.sum(indices) < 950)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(750 < np.sum(indices) < 950)
# Update weights to be 1:9.
# Expect to sample 9 times as often index 1 over index 0.
# plus very few times indices 2, 3, or 4.
for _ in range(10):
rand = np.random.random() + 0.2
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 9]))
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(850 < np.sum(indices) < 1100)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
self.assertTrue(850 < np.sum(indices) < 1100)
# Insert n more samples.
num_records = 5
for i in range(num_records):
data = self._generate_data()
memory.add(data, weight=1.0)
self.assertTrue(len(memory) == i + 6)
self.assertTrue(memory._next_idx == (i + 6) % self.capacity)
# Update all 10 priorities to strongly increasing values and sample >100-sized batches.
memory.update_priorities(
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
np.array([0.001, 0.1, 2.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0]),
)
counts = Counter()
for _ in range(10):
batch = memory.sample(np.random.randint(100, 600), beta=self.beta)
indices = batch["batch_indexes"]
for i in indices:
counts[i] += 1
print(counts)
# Expect an approximately correct distribution of indices.
self.assertTrue(
counts[9]
>= counts[8]
>= counts[7]
>= counts[6]
>= counts[5]
>= counts[4]
>= counts[3]
>= counts[2]
>= counts[1]
>= counts[0]
)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
new_memory.set_state(state)
counts = Counter()
for _ in range(10):
batch = new_memory.sample(np.random.randint(100, 600), beta=self.beta)
indices = batch["batch_indexes"]
for i in indices:
counts[i] += 1
print(counts)
self.assertTrue(
counts[9]
>= counts[8]
>= counts[7]
>= counts[6]
>= counts[5]
>= counts[4]
>= counts[3]
>= counts[2]
>= counts[1]
>= counts[0]
)
def test_alpha_parameter(self):
# Test sampling from a PR with a very small alpha (should behave just
# like a regular ReplayBuffer).
memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
# Insert n samples.
num_records = 5
for i in range(num_records):
data = self._generate_data()
memory.add(data, weight=np.random.rand())
self.assertTrue(len(memory) == i + 1)
self.assertTrue(memory._next_idx == i + 1)
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
new_memory.set_state(state)
self.assertTrue(len(new_memory) == num_records)
self.assertTrue(new_memory._next_idx == num_records)
# Fetch records, their indices and weights.
batch = memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
counts = Counter()
for i in indices:
counts[i] += 1
print(counts)
# Expect an approximately uniform distribution of indices.
self.assertTrue(any(100 < i < 300 for i in counts.values()))
# Test get_state/set_state.
state = memory.get_state()
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
new_memory.set_state(state)
batch = new_memory.sample(1000, beta=self.beta)
indices = batch["batch_indexes"]
counts = Counter()
for i in indices:
counts[i] += 1
self.assertTrue(any(100 < i < 300 for i in counts.values()))
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -1,102 +0,0 @@
import numpy as np
import unittest
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
class TestSegmentTree(unittest.TestCase):
def test_tree_set(self):
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert np.isclose(tree.sum(), 4.0)
assert np.isclose(tree.sum(0, 2), 0.0)
assert np.isclose(tree.sum(0, 3), 1.0)
assert np.isclose(tree.sum(2, 3), 1.0)
assert np.isclose(tree.sum(2, -1), 1.0)
assert np.isclose(tree.sum(2, 4), 4.0)
assert np.isclose(tree.sum(2), 4.0)
def test_tree_set_overlap(self):
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[2] = 3.0
assert np.isclose(tree.sum(), 3.0)
assert np.isclose(tree.sum(2, 3), 3.0)
assert np.isclose(tree.sum(2, -1), 3.0)
assert np.isclose(tree.sum(2, 4), 3.0)
assert np.isclose(tree.sum(2), 3.0)
assert np.isclose(tree.sum(1, 2), 0.0)
def test_prefixsum_idx(self):
tree = SumSegmentTree(4)
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.0) == 2
assert tree.find_prefixsum_idx(0.5) == 2
assert tree.find_prefixsum_idx(0.99) == 2
assert tree.find_prefixsum_idx(1.01) == 3
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(4.00) == 3
def test_prefixsum_idx2(self):
tree = SumSegmentTree(4)
tree[0] = 0.5
tree[1] = 1.0
tree[2] = 1.0
tree[3] = 3.0
assert tree.find_prefixsum_idx(0.00) == 0
assert tree.find_prefixsum_idx(0.55) == 1
assert tree.find_prefixsum_idx(0.99) == 1
assert tree.find_prefixsum_idx(1.51) == 2
assert tree.find_prefixsum_idx(3.00) == 3
assert tree.find_prefixsum_idx(5.50) == 3
def test_max_interval_tree(self):
tree = MinSegmentTree(4)
tree[0] = 1.0
tree[2] = 0.5
tree[3] = 3.0
assert np.isclose(tree.min(), 0.5)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.5)
assert np.isclose(tree.min(0, -1), 0.5)
assert np.isclose(tree.min(2, 4), 0.5)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 0.7
assert np.isclose(tree.min(), 0.7)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 0.7)
assert np.isclose(tree.min(0, -1), 0.7)
assert np.isclose(tree.min(2, 4), 0.7)
assert np.isclose(tree.min(3, 4), 3.0)
tree[2] = 4.0
assert np.isclose(tree.min(), 1.0)
assert np.isclose(tree.min(0, 2), 1.0)
assert np.isclose(tree.min(0, 3), 1.0)
assert np.isclose(tree.min(0, -1), 1.0)
assert np.isclose(tree.min(2, 4), 3.0)
assert np.isclose(tree.min(2, 3), 4.0)
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -42,10 +42,15 @@ class TestEagerSupportPG(unittest.TestCase):
ray.shutdown()
def test_simple_q(self):
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
check_support(
"SimpleQ",
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
)
def test_dqn(self):
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
check_support(
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
)
def test_ddpg(self):
check_support("DDPG", {"num_workers": 0})
@ -84,10 +89,15 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
ray.shutdown()
def test_simple_q(self):
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
check_support(
"SimpleQ",
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
)
def test_dqn(self):
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
check_support(
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
)
def test_ddpg(self):
check_support("DDPG", {"num_workers": 0})
@ -103,7 +113,7 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
"APEX",
{
"num_workers": 2,
"learning_starts": 0,
"replay_buffer_config": {"learning_starts": 0},
"num_gpus": 0,
"min_time_s_per_reporting": 1,
"min_sample_timesteps_per_reporting": 100,
@ -114,7 +124,9 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
)
def test_sac(self):
check_support("SAC", {"num_workers": 0, "learning_starts": 0})
check_support(
"SAC", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
)
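# For reference, the same migration expressed as plain config dicts (only keys
# that already appear in this diff are shown; the values are illustrative):
OLD_STYLE_DQN_CONFIG = {
    "learning_starts": 20000,
    "buffer_size": 1000000,
    "prioritized_replay_alpha": 0.5,
}
NEW_STYLE_DQN_CONFIG = {
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 1000000,
        "learning_starts": 20000,
        "prioritized_replay_alpha": 0.5,
    },
}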
if __name__ == "__main__":

View file

@ -23,9 +23,8 @@ from ray.rllib.execution.train_ops import (
ComputeGradients,
AverageGradients,
)
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayActor,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.util.iter import LocalIterator, from_range
@ -216,40 +215,41 @@ class TestExecution(unittest.TestCase):
prioritized_replay_beta=0.4,
prioritized_replay_eps=0.0001,
)
assert buf.replay() is None
assert len(buf.sample(100)) == 0
workers = make_workers(0)
a = ParallelRollouts(workers, mode="bulk_sync")
b = a.for_each(StoreToReplayBuffer(local_buffer=buf))
next(b)
assert buf.replay() is None # learning hasn't started yet
assert len(buf.sample(100)) == 0 # learning hasn't started yet
next(b)
assert buf.replay().count == 100
assert buf.sample(100).count == 100
replay_op = Replay(local_buffer=buf)
assert next(replay_op).count == 100
def test_store_to_replay_actor(self):
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
actor = ReplayActor.remote(
num_shards=1,
learning_starts=200,
buffer_size=1000,
capacity=1000,
replay_batch_size=100,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=0.0001,
)
assert ray.get(actor.replay.remote()) is None
assert len(ray.get(actor.sample.remote(100))) == 0
workers = make_workers(0)
a = ParallelRollouts(workers, mode="bulk_sync")
b = a.for_each(StoreToReplayBuffer(actors=[actor]))
next(b)
assert ray.get(actor.replay.remote()) is None # learning hasn't started
assert len(ray.get(actor.sample.remote(100))) == 0 # learning hasn't started
next(b)
assert ray.get(actor.replay.remote()).count == 100
assert ray.get(actor.sample.remote(100)).count == 100
replay_op = Replay(actors=[actor])
assert next(replay_op).count == 100

View file

@ -95,9 +95,11 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"num_workers": 2,
"min_sample_timesteps_per_reporting": 100,
"num_gpus": 0,
"replay_buffer_config": {"capacity": 1000},
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
"min_time_s_per_reporting": 1,
"learning_starts": 10,
"target_network_update_freq": 100,
"optimizer": {
"num_replay_buffer_shards": 1,
@ -111,10 +113,12 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
{
"num_workers": 2,
"min_sample_timesteps_per_reporting": 100,
"replay_buffer_config": {"capacity": 1000},
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
"num_gpus": 0,
"min_time_s_per_reporting": 1,
"learning_starts": 10,
"target_network_update_freq": 100,
"use_state_preprocessor": True,
},
@ -125,9 +129,11 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"DDPG",
{
"min_sample_timesteps_per_reporting": 1,
"replay_buffer_config": {"capacity": 1000},
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 500,
},
"use_state_preprocessor": True,
"learning_starts": 500,
},
)
@ -136,7 +142,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"DQN",
{
"min_sample_timesteps_per_reporting": 1,
"replay_buffer_config": {"capacity": 1000},
"replay_buffer_config": {
"capacity": 1000,
},
},
)
@ -145,7 +153,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"SAC",
{
"num_workers": 0,
"replay_buffer_config": {"capacity": 1000},
"replay_buffer_config": {
"capacity": 1000,
},
"normalize_actions": False,
},
)

View file

@ -182,18 +182,27 @@ class TestSupportedSpacesOffPolicy(unittest.TestCase):
{
"exploration_config": {"ou_base_scale": 100.0},
"min_sample_timesteps_per_reporting": 1,
"buffer_size": 1000,
"replay_buffer_config": {
"capacity": 1000,
},
"use_state_preprocessor": True,
},
check_bounds=True,
)
def test_dqn(self):
config = {"min_sample_timesteps_per_reporting": 1, "buffer_size": 1000}
config = {
"min_sample_timesteps_per_reporting": 1,
"replay_buffer_config": {
"capacity": 1000,
},
}
check_support("DQN", config, tfe=True)
def test_sac(self):
check_support("SAC", {"buffer_size": 1000}, check_bounds=True)
check_support(
"SAC", {"replay_buffer_config": {"capacity": 1000}}, check_bounds=True
)
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):

View file

@ -86,11 +86,13 @@ apex:
lr: .0001
adam_epsilon: .00015
hiddens: [512]
buffer_size: 1000000
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
prioritized_replay_alpha: 0.5
capacity: 1000000
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8
@ -125,19 +127,19 @@ atari-basic-dqn:
dueling: false
num_atoms: 1
noisy: false
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
n_step: 1
target_network_update_freq: 8000
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.2
min_sample_timesteps_per_reporting: 10000

View file

@ -26,11 +26,12 @@ halfcheetah_bc:
no_done_at_end: false
n_step: 1
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 10
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

View file

@ -28,11 +28,12 @@ halfcheetah_cql:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 256
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

View file

@ -26,11 +26,12 @@ hopper_bc:
no_done_at_end: false
n_step: 1
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 10
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

View file

@ -26,11 +26,12 @@ hopper_cql:
no_done_at_end: false
n_step: 1
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 10
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

View file

@ -22,6 +22,7 @@ pendulum-cql:
twin_q: true
train_batch_size: 2000
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 0
bc_iters: 100

View file

@ -31,10 +31,12 @@ halfcheetah-ddpg:
# === Replay buffer ===
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false
clip_rewards: False
# === Optimization ===
@ -50,7 +52,6 @@ halfcheetah-ddpg:
# === Parallelism ===
num_workers: 0
num_gpus_per_worker: 0
worker_side_prioritization: false
# === Evaluation ===
evaluation_interval: 5

View file

@ -23,10 +23,12 @@ ddpg-halfcheetahbulletenv-v0:
target_network_update_freq: 0
tau: 0.001
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 15000
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false
clip_rewards: false
actor_lr: 0.001
critic_lr: 0.001
@ -39,4 +41,3 @@ ddpg-halfcheetahbulletenv-v0:
num_workers: 0
num_gpus: 1
num_gpus_per_worker: 0
worker_side_prioritization: false

View file

@ -26,19 +26,20 @@ ddpg-hopperbulletenv-v0:
target_network_update_freq: 0
tau: 0.001
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
learning_starts: 500
clip_rewards: False
actor_lr: 0.001
critic_lr: 0.001
use_huber: False
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 48
num_workers: 0
num_gpus_per_worker: 0
worker_side_prioritization: False
num_gpus_per_worker: 0

View file

@ -16,7 +16,8 @@ invertedpendulum-td3:
critic_hiddens: [32, 32]
# === Exploration ===
learning_starts: 1000
replay_buffer_config:
learning_starts: 1000
exploration_config:
random_timesteps: 1000

View file

@ -9,4 +9,5 @@ memory-leak-test-ddpg:
env_config:
config:
static_samples: true
buffer_size: 500 # use small buffer to catch memory leaks
replay_buffer_config:
capacity: 500 # use small buffer to catch memory leaks

View file

@ -32,10 +32,12 @@ mountaincarcontinuous-ddpg:
# === Replay buffer ===
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 50000
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: False
clip_rewards: False
# === Optimization ===
@ -51,7 +53,6 @@ mountaincarcontinuous-ddpg:
# === Parallelism ===
num_workers: 0
num_gpus_per_worker: 0
worker_side_prioritization: False
# === Evaluation ===
evaluation_interval: 5

View file

@ -18,10 +18,11 @@ mujoco-td3:
# Works for both torch and tf.
framework: tf
# === Exploration ===
learning_starts: 10000
exploration_config:
random_timesteps: 10000
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10000
# === Evaluation ===
evaluation_interval: 10
evaluation_num_episodes: 10

View file

@ -17,13 +17,14 @@ pendulum-ddpg-fake-gpus:
final_scale: 0.02
min_sample_timesteps_per_reporting: 600
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: false
learning_starts: 500
clip_rewards: false
use_huber: true
learning_starts: 500
train_batch_size: 64
num_workers: 0
worker_side_prioritization: false
actor_lr: 0.0001
critic_lr: 0.0001

View file

@ -34,7 +34,9 @@ pendulum-ddpg:
# === Replay buffer ===
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: False
clip_rewards: False
# === Optimization ===
@ -49,4 +51,3 @@ pendulum-ddpg:
# === Parallelism ===
num_workers: 0
worker_side_prioritization: False

View file

@ -9,7 +9,10 @@ pendulum-td3-fake-gpus:
framework: tf
actor_hiddens: [64, 64]
critic_hiddens: [64, 64]
learning_starts: 5000
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 5000
exploration_config:
random_timesteps: 5000
evaluation_interval: 10

View file

@ -12,6 +12,8 @@ pendulum-td3:
actor_hiddens: [64, 64]
critic_hiddens: [64, 64]
# === Exploration ===
learning_starts: 5000
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 5000
exploration_config:
random_timesteps: 5000

View file

@ -15,11 +15,12 @@ apex-breakoutnoframeskip-v4:
lr: .0001
adam_epsilon: .00015
hiddens: [512]
buffer_size: 1000000
replay_buffer_config:
capacity: 1000000
prioritized_replay_alpha: 0.5
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

View file

@ -11,19 +11,19 @@ atari-dist-dqn:
dueling: false
num_atoms: 51
noisy: false
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
capacity: 1000000
learning_starts: 20000
n_step: 1
target_network_update_freq: 8000
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.2
min_sample_timesteps_per_reporting: 10000

View file

@ -15,19 +15,19 @@ atari-basic-dqn:
dueling: false
num_atoms: 1
noisy: false
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
n_step: 1
target_network_update_freq: 8000
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.2
min_sample_timesteps_per_reporting: 10000

View file

@ -15,19 +15,19 @@ dueling-ddqn:
dueling: true
num_atoms: 1
noisy: false
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
n_step: 1
target_network_update_freq: 8000
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.2
min_sample_timesteps_per_reporting: 10000

View file

@ -17,13 +17,15 @@ cartpole-apex-dqn-training-itr:
# Make this work with only 5 CPUs and 0 GPUs:
num_workers: 3
optimizer:
num_replay_buffer_shards: 2
num_replay_buffer_shards: 2
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 20000
learning_starts: 1000
num_gpus: 0
min_time_s_per_reporting: 5
target_network_update_freq: 500
learning_starts: 1000
min_sample_timesteps_per_reporting: 1000
buffer_size: 20000
training_intensity: 4

View file

@ -9,4 +9,5 @@ memory-leak-test-dqn:
env_config:
config:
static_samples: true
buffer_size: 500 # use small buffer to catch memory leaks
replay_buffer_config:
capacity: 500 # use small buffer to catch memory leaks

View file

@ -15,6 +15,8 @@ pong-apex:
num_envs_per_worker: 8
lr: .00005
train_batch_size: 64
buffer_size: 1000000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 1000000
gamma: 0.99
training_intensity: 16

View file

@ -11,8 +11,10 @@ pong-deterministic-dqn:
num_gpus: 1
gamma: 0.99
lr: .0001
learning_starts: 10000
buffer_size: 50000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 50000
learning_starts: 10000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:

View file

@ -9,16 +9,17 @@ pong-deterministic-rainbow:
gamma: 0.99
lr: .0001
hiddens: [512]
learning_starts: 10000
buffer_size: 50000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 2
final_epsilon: 0.0
target_network_update_freq: 500
prioritized_replay: True
prioritized_replay_alpha: 0.5
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
prioritized_replay_alpha: 0.5
learning_starts: 10000
capacity: 50000
n_step: 3
gpu: True
model:

View file

@ -10,6 +10,7 @@ stateless-cartpole-r2d2:
num_workers: 0
# R2D2 settings.
replay_buffer_config:
type: MultiAgentReplayBuffer
replay_burn_in: 20
zero_init_states: true
#dueling: false

View file

@ -10,6 +10,7 @@ stateless-cartpole-r2d2:
num_workers: 0
# R2D2 settings.
replay_buffer_config:
type: MultiAgentReplayBuffer
replay_burn_in: 20
zero_init_states: true
#dueling: false

View file

@ -31,23 +31,19 @@ atari-sac-tf-and-torch:
n_step: 1
rollout_fragment_length: 1
replay_buffer_config:
_enable_replay_buffer_api: False
type: MultiAgentReplayBuffer
type: MultiAgentPrioritizedReplayBuffer
capacity: 1000000
# How many steps of the model to sample before learning starts.
learning_starts: 1500
learning_starts: 100000
# If True prioritized replay buffer will be used.
prioritized_replay: True
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 1e-6
train_batch_size: 64
min_sample_timesteps_per_reporting: 4
# Paper uses 20k random timesteps, which is not exactly the same, but
# seems to work nevertheless. We use 100k here for the longer Atari
# runs (DQN style: filling up the buffer a bit before learning).
learning_starts: 100000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -12,9 +12,10 @@ cartpole-sac:
horizon: 200
soft_horizon: true
n_step: 3
prioritized_replay: true
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
initial_alpha: 0.2
learning_starts: 256
clip_actions: false
min_sample_timesteps_per_reporting: 1000
optimization:

View file

@ -19,11 +19,12 @@ halfcheetah-pybullet-sac:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 10000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -20,11 +20,12 @@ halfcheetah_sac:
no_done_at_end: true
n_step: 1
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 10000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -9,4 +9,5 @@ memory-leak-test-sac:
env_config:
config:
static_samples: true
buffer_size: 500 # use small buffer to catch memory leaks
replay_buffer_config:
capacity: 500 # use small buffer to catch memory leaks

View file

@ -27,12 +27,13 @@ mspacman-sac-tf:
no_done_at_end: False
n_step: 1
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 64
min_sample_timesteps_per_reporting: 4
# Paper uses 20k random timesteps, which is not exactly the same, but
# seems to work nevertheless.
learning_starts: 20000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 20000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -21,11 +21,12 @@ pendulum-sac-fake-gpus:
no_done_at_end: true
n_step: 1
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 256
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
num_workers: 0
metrics_smoothing_episodes: 5

View file

@ -23,11 +23,12 @@ pendulum-sac:
no_done_at_end: true
n_step: 1
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 256
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -30,11 +30,12 @@ transformed-actions-pendulum-sac-dummy-torch:
no_done_at_end: true
n_step: 1
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 256
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

View file

@ -9,7 +9,7 @@ from ray.rllib.policy.sample_batch import (
SampleBatch,
MultiAgentBatch,
)
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
MultiAgentPrioritizedReplayBuffer,
)
@ -21,13 +21,14 @@ from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
ReplayMode,
)
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import _ALL_POLICIES
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
from ray.util.debug import log_once
from ray.util.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@ExperimentalAPI
@DeveloperAPI
class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
"""This buffer adds replayed samples to a stream of new experiences.
@ -168,7 +169,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
self.last_added_batches = collections.defaultdict(list)
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentPrioritizedReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch to the appropriate policy's replay buffer.
@ -244,7 +245,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
self._num_added += batch.count
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentReplayBuffer)
def sample(
self, num_items: int, policy_id: PolicyID = DEFAULT_POLICY_ID, **kwargs
@ -353,7 +354,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
return MultiAgentBatch.concat_samples(samples)
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentPrioritizedReplayBuffer)
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
@ -368,7 +369,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
parent.update(data)
return parent
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentPrioritizedReplayBuffer)
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.

View file

@ -2,9 +2,8 @@ from typing import Dict
import logging
import numpy as np
import ray
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer,
ReplayMode,
@ -16,12 +15,15 @@ from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
from ray.rllib.utils.typing import PolicyID, SampleBatchType
from ray.rllib.utils.timer import TimerStat
from ray.util.debug import log_once
from ray.util.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@ExperimentalAPI
class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
@DeveloperAPI
class MultiAgentPrioritizedReplayBuffer(
MultiAgentReplayBuffer, PrioritizedReplayBuffer
):
"""A prioritized replay buffer shard for multiagent setups.
This buffer is meant to be run in parallel to distribute experiences
@ -141,7 +143,7 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
self.prioritized_replay_eps = prioritized_replay_eps
self.update_priorities_timer = TimerStat()
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentReplayBuffer)
def _add_to_underlying_buffer(
self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
@ -206,7 +208,8 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
else:
self.replay_buffers[policy_id].add(batch, **kwargs)
@ExperimentalAPI
@DeveloperAPI
@override(PrioritizedReplayBuffer)
def update_priorities(self, prio_dict: Dict) -> None:
"""Updates the priorities of underlying replay buffers.
@ -225,7 +228,7 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
batch_indexes, new_priorities
)
@ExperimentalAPI
@DeveloperAPI
@override(MultiAgentReplayBuffer)
def stats(self, debug: bool = False) -> Dict:
"""Returns the stats of this buffer and all underlying buffers.
@ -249,6 +252,3 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
{"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
)
return stat
ReplayActor = ray.remote(num_cpus=0)(MultiAgentPrioritizedReplayBuffer)
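Since the module-level `ReplayActor` alias is removed here, the following is a hedged sketch of how a remote prioritized shard could still be created by wrapping the class directly; the package-level import and constructor arguments are assumptions:

import ray
from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedReplayBuffer

ray.init(ignore_reinit_error=True)

# Wrap the buffer class in an actor, mirroring the removed module-level alias.
ReplayActor = ray.remote(num_cpus=0)(MultiAgentPrioritizedReplayBuffer)
shard = ReplayActor.remote(capacity=50_000)

# Remote calls mirror the local API, e.g. querying per-policy statistics.
print(ray.get(shard.stats.remote(debug=False)))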

View file

@ -3,11 +3,10 @@ import collections
from typing import Any, Dict, Optional
from enum import Enum
import ray
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.timer import TimerStat
from ray.rllib.utils.typing import PolicyID, SampleBatchType
@ -15,17 +14,18 @@ from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
from ray.rllib.utils.from_config import from_config
from ray.util.debug import log_once
from ray.rllib.utils.deprecation import Deprecated
from ray.util.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@ExperimentalAPI
@DeveloperAPI
class ReplayMode(Enum):
LOCKSTEP = "lockstep"
INDEPENDENT = "independent"
@ExperimentalAPI
@DeveloperAPI
def merge_dicts_with_warning(args_on_init, args_on_call):
"""Merge argument dicts, overwriting args_on_call with warning.
@ -50,7 +50,7 @@ def merge_dicts_with_warning(args_on_init, args_on_call):
return {**args_on_init, **args_on_call}
@ExperimentalAPI
@DeveloperAPI
class MultiAgentReplayBuffer(ReplayBuffer):
"""A replay buffer shard for multiagent setups.
@ -183,7 +183,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
"""Returns the number of items currently stored in this buffer."""
return sum(len(buffer._storage) for buffer in self.replay_buffers.values())
@ExperimentalAPI
@DeveloperAPI
@Deprecated(old="replay", new="sample", error=False)
def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
"""Deprecated in favor of new ReplayBuffer API."""
@ -191,7 +191,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
num_items = self.replay_batch_size
return self.sample(num_items, **kwargs)
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch to the appropriate policy's replay buffer.
@ -229,7 +229,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
self._num_added += batch.count
@ExperimentalAPI
@DeveloperAPI
def _add_to_underlying_buffer(
self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
) -> None:
@ -266,7 +266,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
else:
self.replay_buffers[policy_id].add(batch, **kwargs)
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def sample(
self, num_items: int, policy_id: Optional[PolicyID] = None, **kwargs
@ -310,7 +310,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
samples[policy_id] = replay_buffer.sample(num_items, **kwargs)
return MultiAgentBatch(samples, sum(s.count for s in samples.values()))
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def stats(self, debug: bool = False) -> Dict:
"""Returns the stats of this buffer and all underlying buffers.
@ -332,7 +332,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
)
return stat
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
@ -345,7 +345,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
state["replay_buffers"][policy_id] = replay_buffer.get_state()
return state
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.
@ -358,6 +358,3 @@ class MultiAgentReplayBuffer(ReplayBuffer):
buffer_states = state["replay_buffers"]
for policy_id in buffer_states.keys():
self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
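A hedged sketch of independent per-policy storage with the new-API multi-agent buffer; the package-level import and the SampleBatch field names are assumptions chosen so the default timestep slicing has the columns it typically expects:

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer

def dummy_batch(n: int) -> SampleBatch:
    return SampleBatch({
        SampleBatch.OBS: np.zeros((n, 3), dtype=np.float32),
        SampleBatch.ACTIONS: np.zeros((n, 1), dtype=np.float32),
        SampleBatch.REWARDS: np.zeros(n, dtype=np.float32),
        SampleBatch.DONES: np.array([False] * (n - 1) + [True]),
        SampleBatch.EPS_ID: np.zeros(n, dtype=np.int64),
    })

buffer = MultiAgentReplayBuffer(capacity=1000)

# A MultiAgentBatch routes each policy's SampleBatch into its own sub-buffer.
buffer.add(MultiAgentBatch({"pol1": dummy_batch(4), "pol2": dummy_batch(4)}, env_steps=4))

# Without a policy_id, sampling returns a MultiAgentBatch over all policies.
replayed = buffer.sample(num_items=2)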

View file

@ -8,13 +8,14 @@ import psutil # noqa E402
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, ExperimentalAPI
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.typing import SampleBatchType
from ray.util.annotations import DeveloperAPI
@ExperimentalAPI
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
"""This buffer implements Prioritized Experience Replay
@ -23,7 +24,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
the full paper.
"""
@ExperimentalAPI
def __init__(
self,
capacity: int = 10000,
@ -58,7 +58,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_priority = 1.0
self._prio_change_stats = WindowStat("reprio", 1000)
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
"""Add a batch of experiences to self._storage with weight.
@ -90,7 +90,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
res.append(idx)
return res
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def sample(
self, num_items: int, beta: float, **kwargs
@ -157,7 +157,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
return batch
@ExperimentalAPI
@DeveloperAPI
def update_priorities(self, idxes: List[int], priorities: List[float]) -> None:
"""Update priorities of items at given indices.
@ -186,7 +186,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_priority = max(self._max_priority, priority)
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def stats(self, debug: bool = False) -> Dict:
"""Returns the stats of this buffer.
@ -203,7 +203,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
parent.update(self._prio_change_stats.stats())
return parent
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
@ -223,7 +223,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
)
return state
@ExperimentalAPI
@DeveloperAPI
@override(ReplayBuffer)
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.

View file

@ -1,6 +1,6 @@
import logging
import platform
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Callable
import numpy as np
import random
@ -12,11 +12,11 @@ import psutil # noqa E402
from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.rllib.utils.typing import SampleBatchType
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.typing import SampleBatchType, T
from ray.util.annotations import DeveloperAPI
from ray.util.iter import ParallelIteratorWorker
# Constant that represents all policies in lockstep replay mode.
_ALL_POLICIES = "__all__"
@ -24,7 +24,7 @@ _ALL_POLICIES = "__all__"
logger = logging.getLogger(__name__)
@ExperimentalAPI
@DeveloperAPI
class StorageUnit(Enum):
TIMESTEPS = "timesteps"
SEQUENCES = "sequences"
@ -32,8 +32,31 @@ class StorageUnit(Enum):
FRAGMENTS = "fragments"
@ExperimentalAPI
class ReplayBuffer:
@DeveloperAPI
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
"""Warn if the configured replay buffer capacity is too large."""
if log_once("replay_capacity"):
item_size = item.size_bytes()
psutil_mem = psutil.virtual_memory()
total_gb = psutil_mem.total / 1e9
mem_size = num_items * item_size / 1e9
msg = (
"Estimated max memory usage for replay buffer is {} GB "
"({} batches of size {}, {} bytes each), "
"available system memory is {} GB".format(
mem_size, num_items, item.count, item_size, total_gb
)
)
if mem_size > total_gb:
raise ValueError(msg)
elif mem_size > 0.2 * total_gb:
logger.warning(msg)
else:
logger.info(msg)
@DeveloperAPI
class ReplayBuffer(ParallelIteratorWorker):
def __init__(
self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs
):
@ -96,11 +119,17 @@ class ReplayBuffer:
self.batch_size = None
def gen_replay():
while True:
yield self.replay()
ParallelIteratorWorker.__init__(self, gen_replay, False)
def __len__(self) -> int:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)
@ExperimentalAPI
@DeveloperAPI
def add(self, batch: SampleBatchType, **kwargs) -> None:
"""Adds a batch of experiences to this buffer.
@ -155,7 +184,7 @@ class ReplayBuffer:
elif self._storage_unit == StorageUnit.FRAGMENTS:
self._add_single_batch(batch, **kwargs)
@ExperimentalAPI
@DeveloperAPI
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
"""Add a SampleBatch of experiences to self._storage.
@ -189,7 +218,7 @@ class ReplayBuffer:
else:
self._next_idx += 1
@ExperimentalAPI
@DeveloperAPI
def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
"""Samples `num_items` items from this buffer.
@ -220,7 +249,7 @@ class ReplayBuffer:
self._num_timesteps_sampled += sample.count
return sample
@ExperimentalAPI
@DeveloperAPI
def stats(self, debug: bool = False) -> dict:
"""Returns the stats of this buffer.
@ -243,7 +272,7 @@ class ReplayBuffer:
data.update(self._evicted_hit_stats.stats())
return data
@ExperimentalAPI
@DeveloperAPI
def get_state(self) -> Dict[str, Any]:
"""Returns all local state.
@ -254,7 +283,7 @@ class ReplayBuffer:
state.update(self.stats(debug=False))
return state
@ExperimentalAPI
@DeveloperAPI
def set_state(self, state: Dict[str, Any]) -> None:
"""Restores all local state to the provided `state`.
@ -272,6 +301,7 @@ class ReplayBuffer:
self._num_timesteps_sampled = state["sampled_count"]
self._est_size_bytes = state["est_size_bytes"]
@DeveloperAPI
def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
"""Fetches concatenated samples at given indeces from the storage."""
samples = []
@ -288,6 +318,7 @@ class ReplayBuffer:
out.decompress_if_needed()
return out
@DeveloperAPI
def get_host(self) -> str:
"""Returns the computer's network name.
@ -297,10 +328,35 @@ class ReplayBuffer:
"""
return platform.node()
@DeveloperAPI
def apply(
self,
func: Callable[["ReplayBuffer", Optional[Any], Optional[Any]], T],
*_args,
**kwargs,
) -> T:
"""Calls the given function with this ReplayBuffer instance.
This is useful if we want to apply a function to a set of remote actors.
Args:
func: A callable that accepts the replay buffer itself, args and kwargs
*_args: Any args to pass to func
**kwargs: Any kwargs to pass to func
Returns:
Return value of the function call
"""
return func(self, *_args, **kwargs)
@Deprecated(old="ReplayBuffer.add_batch()", new="RepayBuffer.add()", error=False)
def add_batch(self, *args, **kwargs):
return self.add(*args, **kwargs)
@Deprecated(old="RepayBuffer.replay()", new="RepayBuffer.sample()", error=False)
def replay(self, *args, **kwargs):
return self.sample(*args, **kwargs)
@Deprecated(
old="RepayBuffer.replay(num_items)",
new="RepayBuffer.sample(" "num_items)",
error=False,
)
def replay(self, num_items):
return self.sample(num_items)
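A hedged sketch of the new `apply()` helper in action, both locally and on a remote shard; the package-level import is an assumption:

import ray
from ray.rllib.utils.replay_buffers import ReplayBuffer

local_buffer = ReplayBuffer(capacity=100)
# `apply()` simply calls func(self, *args, **kwargs) on the buffer itself.
print(local_buffer.apply(lambda buf: len(buf)))  # -> 0

# The same pattern dispatches work to a remote shard via apply.remote().
ray.init(ignore_reinit_error=True)
RemoteBuffer = ray.remote(num_cpus=0)(ReplayBuffer)
shard = RemoteBuffer.remote(capacity=100)
print(ray.get(shard.apply.remote(lambda buf: buf.get_host())))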

View file

@ -6,8 +6,10 @@ import ray # noqa F401
import psutil # noqa E402
from ray.rllib.utils.annotations import ExperimentalAPI, override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.replay_buffers.replay_buffer import (
ReplayBuffer,
warn_replay_capacity,
)
from ray.rllib.utils.typing import SampleBatchType

View file

@ -4,8 +4,10 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.replay_buffers.utils import warn_replay_buffer_capacity
from ray.rllib.utils.typing import SampleBatchType
from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class SimpleReplayBuffer(ReplayBuffer):
"""Simple replay buffer that operates over entire batches."""
@ -15,6 +17,7 @@ class SimpleReplayBuffer(ReplayBuffer):
self.replay_batches = []
self.replay_index = 0
@DeveloperAPI
@override(ReplayBuffer)
def add(self, batch: SampleBatchType, **kwargs) -> None:
warn_replay_buffer_capacity(item=batch, capacity=self.capacity)
@ -26,10 +29,12 @@ class SimpleReplayBuffer(ReplayBuffer):
self.replay_index += 1
self.replay_index %= self.capacity
@DeveloperAPI
@override(ReplayBuffer)
def sample(self, num_items: int, **kwargs) -> SampleBatchType:
return random.choice(self.replay_batches)
@DeveloperAPI
@override(ReplayBuffer)
def __len__(self):
return len(self.replay_batches)

View file

@ -2,10 +2,6 @@ import logging
import psutil
from typing import Optional, Any
from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import deprecation_warning
from ray.rllib.utils.annotations import ExperimentalAPI
@ -47,10 +43,7 @@ def update_priorities_in_replay_buffer(
utility.
"""
# Only update priorities if buffer supports them.
if (
type(replay_buffer) is LegacyMultiAgentReplayBuffer
and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0
) or isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
if isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
# Go through training results for the different policies (maybe multi-agent).
prio_dict = {}
for policy_id, info in train_results.items():
@ -145,46 +138,57 @@ def validate_buffer_config(config: dict):
if config.get("replay_buffer_config", None) is None:
config["replay_buffer_config"] = {}
prioritized_replay = config.get("prioritized_replay")
prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
if prioritized_replay != DEPRECATED_VALUE:
deprecation_warning(
old="config['prioritized_replay']",
help="Replay prioritization specified at new location config["
"'replay_buffer_config']["
"'prioritized_replay'] will be overwritten.",
error=False,
old="config['prioritized_replay'] or config['replay_buffer_config']["
"'prioritized_replay']",
help="Replay prioritization specified by config key. RLlib's new replay "
"buffer API requires setting `config["
"'replay_buffer_config']['type']`, e.g. `config["
"'replay_buffer_config']['type'] = "
"'MultiAgentPrioritizedReplayBuffer'` to change the default "
"behaviour.",
error=True,
)
config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay
capacity = config.get("buffer_size", DEPRECATED_VALUE)
if capacity == DEPRECATED_VALUE:
capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
if capacity != DEPRECATED_VALUE:
deprecation_warning(
old="config['buffer_size']",
help="Buffer size specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
error=False,
old="config['buffer_size'] or config['replay_buffer_config']["
"'buffer_size']",
new="config['replay_buffer_config']['capacity']",
error=True,
)
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
if replay_burn_in != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
deprecation_warning(
old="config['burn_in']",
help="config['replay_buffer_config']['replay_burn_in']",
)
config["replay_buffer_config"]["capacity"] = capacity
# Deprecation of old-style replay buffer args
# Warn before checking whether we need a local buffer, so that algorithms
# without a local buffer also get warned
deprecated_replay_buffer_keys = [
keys_with_deprecated_positions = [
"prioritized_replay_alpha",
"prioritized_replay_beta",
"prioritized_replay_eps",
"no_local_replay_buffer",
"replay_batch_size",
"replay_zero_init_states",
"learning_starts",
"replay_buffer_shards_colocated_with_driver",
]
for k in deprecated_replay_buffer_keys:
for k in keys_with_deprecated_positions:
if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="config[{}]".format(k),
help="config['replay_buffer_config'][{}] should be used "
"for Q-Learning algorithms. Ignore this warning if "
"you are not using a Q-Learning algorithm and still "
"provide {}."
"".format(k, k),
old="config['{}']".format(k),
help="config['replay_buffer_config']['{}']" "".format(k),
error=False,
)
# Copy values over to new location in config to support new
@ -192,143 +196,61 @@ def validate_buffer_config(config: dict):
if config.get("replay_buffer_config") is not None:
config["replay_buffer_config"][k] = config[k]
# Old Ape-X configs may contain no_local_replay_buffer
no_local_replay_buffer = config.get("no_local_replay_buffer", False)
if no_local_replay_buffer:
replay_mode = config["multiagent"].get("replay_mode", DEPRECATED_VALUE)
if replay_mode != DEPRECATED_VALUE:
deprecation_warning(
old="config['no_local_replay_buffer']",
help="no_local_replay_buffer specified at new location config["
"'replay_buffer_config']["
"'capacity'] will be overwritten.",
old="config['multiagent']['replay_mode']",
help="config['replay_buffer_config']['replay_mode']",
error=False,
)
config["replay_buffer_config"][
"no_local_replay_buffer"
] = no_local_replay_buffer
config["replay_buffer_config"]["replay_mode"] = replay_mode
# TODO (Artur):
if config["replay_buffer_config"].get("no_local_replay_buffer", False):
return
# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
replay_sequence_length = config.get("replay_sequence_length", None)
if replay_sequence_length is not None:
config["replay_buffer_config"][
"replay_sequence_length"
] = replay_sequence_length
deprecation_warning(
old="config['replay_sequence_length']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_sequence_length'] will be overwritten.",
error=False,
)
replay_buffer_config = config["replay_buffer_config"]
assert (
"type" in replay_buffer_config
), "Can not instantiate ReplayBuffer from config without 'type' key."
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
if replay_burn_in != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
deprecation_warning(
old="config['burn_in']",
help="Burn in specified at new location config["
"'replay_buffer_config']["
"'replay_burn_in'] will be overwritten.",
)
# Check if old replay buffer should be instantiated
buffer_type = config["replay_buffer_config"]["type"]
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Prepend old-style buffers' path
assert buffer_type == "MultiAgentReplayBuffer", (
"Without "
"ReplayBuffer "
"API, only "
"MultiAgentReplayBuffer "
"is supported!"
)
# Create valid full [module].[class] string for from_config
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
else:
assert buffer_type in [
"ray.rllib.execution.MultiAgentReplayBuffer",
Legacy_MultiAgentReplayBuffer,
], (
"Without ReplayBuffer API, only " "MultiAgentReplayBuffer is supported!"
)
config["replay_buffer_config"]["type"] = buffer_type
# Remove from config, so it's not passed into the buffer c'tor
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
# We need to deprecate the old-style location of the following
# buffer arguments and make users put them into the
# "replay_buffer_config" field of their config.
replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
if replay_batch_size != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
deprecation_warning(
old="config['replay_batch_size']",
help="Replay batch size specified at new "
"location config['replay_buffer_config']["
"'replay_batch_size'] will be overwritten.",
error=False,
)
replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
if replay_mode != DEPRECATED_VALUE:
config["replay_buffer_config"]["replay_mode"] = replay_mode
deprecation_warning(
old="config['multiagent']['replay_mode']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_mode'] will be overwritten.",
error=False,
)
# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
replay_sequence_length = config.get("replay_sequence_length", None)
if replay_sequence_length is not None:
config["replay_buffer_config"][
"replay_sequence_length"
] = replay_sequence_length
deprecation_warning(
old="config['replay_sequence_length']",
help="Replay sequence length specified at new "
"location config['replay_buffer_config']["
"'replay_sequence_length'] will be overwritten.",
error=False,
)
replay_zero_init_states = config.get(
"replay_zero_init_states", DEPRECATED_VALUE
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
config["replay_buffer_config"]["type"] = (
"ray.rllib.utils.replay_buffers." + buffer_type
)
if replay_zero_init_states != DEPRECATED_VALUE:
config["replay_buffer_config"][
"replay_zero_init_states"
] = replay_zero_init_states
deprecation_warning(
old="config['replay_zero_init_states']",
help="Replay zero init states specified at new location "
"config["
"'replay_buffer_config']["
"'replay_zero_init_states'] will be overwritten.",
error=False,
)
# TODO (Artur): Move this logic into config objects
if config["replay_buffer_config"].get("prioritized_replay", False):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
# This triggers non-prioritization in old-style replay buffer
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
else:
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
# Create valid full [module].[class] string for from_config
config["replay_buffer_config"]["type"] = (
"ray.rllib.utils.replay_buffers." + buffer_type
)
test_buffer = from_config(buffer_type, config["replay_buffer_config"])
if hasattr(test_buffer, "update_priorities"):
is_prioritized_buffer = True
else:
is_prioritized_buffer = False
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
# Fall back to train batch size if no replay batch size was provided
logger.info(
"No value for key `replay_batch_size` in replay_buffer_config. "
"config['replay_buffer_config']['replay_batch_size'] will be "
"automatically set to config['train_batch_size']"
)
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
if is_prioritized_buffer:
# Instantiate a dummy buffer to fail early on misconfiguration and find out about
# inferred buffer class
dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])
config["replay_buffer_config"]["type"] = type(dummy_buffer)
if hasattr(dummy_buffer, "update_priorities"):
if config["multiagent"]["replay_mode"] == "lockstep":
raise ValueError(
"Prioritized replay is not supported when replay_mode=lockstep."
@ -339,20 +261,12 @@ def validate_buffer_config(config: dict):
"replay_sequence_length > 1."
)
else:
if config.get("worker_side_prioritization"):
if config["replay_buffer_config"].get("worker_side_prioritization"):
raise ValueError(
"Worker side prioritization is not supported when "
"prioritized_replay=False."
)
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
# Fall back to train batch size if no replay batch size was provided
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
# Pop prioritized replay because it's not a valid parameter for older
# replay buffers
config["replay_buffer_config"].pop("prioritized_replay", None)
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
"""Warn if the configured replay buffer capacity is too large for machine's memory.

View file

@ -561,22 +561,6 @@ def check_train_results(train_results):
for pid, policy_stats in learner_info.items():
if pid == "batch_count":
continue
# Expect td-errors to be per batch-item.
if "td_error" in policy_stats:
configured_b = train_results["config"]["train_batch_size"]
actual_b = policy_stats["td_error"].shape[0]
# R2D2 case.
if (configured_b - actual_b) / actual_b > 0.1:
assert (
configured_b
/ (
train_results["config"]["model"]["max_seq_len"]
+ train_results["config"]["replay_buffer_config"][
"replay_burn_in"
]
)
== actual_b
)
# Make sure each policy has the LEARNER_STATS_KEY under it.
assert LEARNER_STATS_KEY in policy_stats