Mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)

[RLlib] Replay Buffer API and Ape-X. (#24506)

This commit is contained in:
parent c74886a55e
commit fb2915d26a

100 changed files with 653 additions and 1999 deletions
@@ -43,8 +43,10 @@ run_experiments(
"config": {
"num_workers": 3,
"num_gpus": 0,
"buffer_size": 10000,
"learning_starts": 0,
"replay_buffer_config": {
"capacity": 10000,
"learning_starts": 0,
},
"rollout_fragment_length": 1,
"train_batch_size": 1,
"min_iter_time_s": 10,
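The hunk above shows the migration pattern this commit repeats across all tuned examples and configs: flat replay settings (`buffer_size`, `learning_starts`, prioritized-replay keys) move into a nested `replay_buffer_config` dict. A minimal sketch of that migration; the keys are taken from the hunks in this diff, the values are illustrative only:

    # Old-style (pre-#24506) flat keys:
    old_config = {
        "buffer_size": 10000,        # becomes replay_buffer_config["capacity"]
        "learning_starts": 0,        # moves into replay_buffer_config
        "prioritized_replay": True,  # now expressed via the buffer type
    }

    # New-style nested replay buffer config:
    new_config = {
        "replay_buffer_config": {
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 10000,
            "learning_starts": 0,
        },
    }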
@@ -16,11 +16,12 @@ apex-breakoutnoframeskip-v4:
lr: .0001
adam_epsilon: .00015
hiddens: [512]
buffer_size: 1000000
replay_buffer_config:
capacity: 1000000
prioritized_replay_alpha: 0.5
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 1
num_workers: 8
num_envs_per_worker: 8

@@ -24,11 +24,12 @@ cql-halfcheetahbulletenv-v0:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
learning_starts: 256
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

@@ -24,18 +24,19 @@ ddpg-hopperbulletenv-v0:
min_sample_timesteps_per_reporting: 1000
target_network_update_freq: 0
tau: 0.001
buffer_size: 10000
prioritized_replay: True
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
replay_buffer_config:
capacity: 10000
type: MultiAgentPrioritizedReplayBuffer
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
learning_starts: 500
clip_rewards: false
actor_lr: 0.001
critic_lr: 0.001
use_huber: true
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 48
num_gpus: 1

@@ -18,13 +18,14 @@ dqn-breakoutnoframeskip-v4:
lr: .0000625
adam_epsilon: .00015
hiddens: [512]
learning_starts: 20000
buffer_size: 1000000
replay_buffer_config:
capacity: 1000000
learning_starts: 20000
prioritized_replay_alpha: 0.5
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
num_gpus: 0.5
min_sample_timesteps_per_reporting: 10000

@@ -21,11 +21,12 @@ sac-halfcheetahbulletenv-v0:
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
learning_starts: 10000
replay_buffer_config:
learning_starts: 10000
type: MultiAgentPrioritizedReplayBuffer
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -9,6 +9,7 @@ td3-halfcheetahbulletenv-v0:
time_total_s: 7200
config:
num_gpus: 1
learning_starts: 10000
replay_buffer_config:
learning_starts: 10000
exploration_config:
random_timesteps: 10000

@@ -171,9 +171,9 @@ sac-repeat-after-me-env:
repeat_delay: 0
num_gpus: 2
num_workers: 0
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
initial_alpha: 0.001
prioritized_replay: true

# Double batch size (2 GPUs).
train_batch_size: 512

@@ -191,11 +191,11 @@ sac-repeat-after-me-env-continuous:
config:
continuous: true
repeat_delay: 0
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
num_gpus: 2
num_workers: 0
initial_alpha: 0.001
prioritized_replay: true

# Double batch size (2 GPUs).
train_batch_size: 512
rllib/BUILD (27 changed lines)
@@ -1030,7 +1030,7 @@ py_test(
"--env", "Pendulum-v1",
"--run", "APEX_DDPG",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"replay_buffer_config\": {\"learning_starts\": 100}, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--ray-num-cpus", "4",
]
)

@@ -1087,7 +1087,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)

@@ -1099,7 +1099,7 @@ py_test(
"--env", "PongDeterministic-v4",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"replay_buffer_config\": {\"capacity\": 10000, \"learning_starts\": 10000}, \"rollout_fragment_length\": 4, \"target_network_update_freq\": 1000, \"gamma\": 0.99}'"
]
)

@@ -1385,27 +1385,6 @@ py_test(
srcs = ["evaluation/tests/test_episode.py"]
)

# --------------------------------------------------------------------
# Optimizers and Memories
# rllib/execution/
#
# Tag: execution
# --------------------------------------------------------------------

py_test(
name = "test_segment_tree",
tags = ["team:ml", "execution"],
size = "small",
srcs = ["execution/tests/test_segment_tree.py"]
)

py_test(
name = "test_prioritized_replay_buffer",
tags = ["team:ml", "execution"],
size = "small",
srcs = ["execution/tests/test_prioritized_replay_buffer.py"]
)

# --------------------------------------------------------------------
# Models and Distributions
# rllib/models/
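The BUILD hunks above drop the old `rllib/execution` replay buffer test targets and rewrite test configs to use `replay_buffer_config`. For orientation, here is a hedged sketch (not part of this commit) of the new-API buffer that the rest of the diff migrates to; the import path and the add()/sample() method names are taken from other hunks in this commit, but treat the exact constructor and sample() signatures as assumptions that may vary by Ray version:

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.replay_buffers import PrioritizedReplayBuffer  # new location (assumed re-export)

    # Assumed signature: capacity and alpha keyword args.
    buffer = PrioritizedReplayBuffer(capacity=1000, alpha=0.6)
    buffer.add(SampleBatch({"obs": [1], "rewards": [0.0]}))   # new API: add() (was add_batch())
    batch = buffer.sample(num_items=1, beta=0.4)              # new API: sample() (was replay())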
@@ -4,6 +4,9 @@ from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.rllib.utils.typing import ResultDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
DDPGConfig().to_dict(),  # see also the options in ddpg.py, which are also supported

@@ -17,42 +20,50 @@ APEX_DDPG_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"n_step": 3,
"num_gpus": 0,
"num_workers": 32,
"buffer_size": 2000000,
# TODO(jungong) : update once Apex supports replay_buffer_config.
"no_local_replay_buffer": True,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"learning_starts": 50000,
"replay_buffer_config": {
"capacity": 2000000,
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.
"target_network_update_freq": 500000,
"min_sample_timesteps_per_reporting": 25000,
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
},
_allow_unknown_configs=True,
)


class ApexDDPGTrainer(DDPGTrainer):
class ApexDDPGTrainer(DDPGTrainer, ApexTrainer):
@classmethod
@override(DDPGTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return APEX_DDPG_DEFAULT_CONFIG

@override(DDPGTrainer)
def setup(self, config: PartialTrainerConfigDict):
return ApexTrainer.setup(self, config)

@override(DDPGTrainer)
def training_iteration(self) -> ResultDict:
"""Use APEX-DQN's training iteration function."""
return ApexTrainer.training_iteration(self)

@staticmethod
@override(DDPGTrainer)
def execution_plan(
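The hunk above changes ApexDDPGTrainer to inherit from both DDPGTrainer and ApexTrainer and to forward setup() and training_iteration() explicitly to the Ape-X implementation. A condensed, self-contained sketch of that delegation pattern; the class names below are hypothetical stand-ins, not RLlib classes:

    class TrainerRoot:
        def setup(self, config: dict) -> None:
            self.config = config

    class AlgoTrainer(TrainerRoot):
        def training_iteration(self) -> dict:
            return {"source": "algo"}

    class DistributedTrainer(TrainerRoot):
        def setup(self, config: dict) -> None:
            super().setup(config)
            self.num_shards = config.get("num_shards", 1)

        def training_iteration(self) -> dict:
            return {"source": "distributed"}

    class CombinedTrainer(AlgoTrainer, DistributedTrainer):
        # Forward explicitly (as ApexDDPGTrainer does above) so it is obvious
        # which parent implementation runs, independent of the MRO.
        def setup(self, config: dict) -> None:
            return DistributedTrainer.setup(self, config)

        def training_iteration(self) -> dict:
            return DistributedTrainer.training_iteration(self)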
@@ -7,6 +7,7 @@ from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.deprecation import Deprecated

logger = logging.getLogger(__name__)

@@ -99,9 +100,11 @@ class DDPGConfig(SimpleQConfig):

# Common DDPG buffer parameters.
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.

@@ -110,6 +113,8 @@ class DDPGConfig(SimpleQConfig):
"prioritized_replay_eps": 1e-6,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}

# .training()
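With the DDPGConfig change above, buffer settings now live in the config object's `replay_buffer_config` attribute. A hedged sketch of adjusting those fields before building a trainer; the attribute and keys come from the hunk above, the module import path is how `DDPGConfig` is defined in this diff, and the values are illustrative:

    from ray.rllib.agents.ddpg.ddpg import DDPGConfig  # module where the hunk above defines DDPGConfig

    config = DDPGConfig()
    config.replay_buffer_config["capacity"] = 100000
    config.replay_buffer_config["learning_starts"] = 1000
    config.replay_buffer_config["worker_side_prioritization"] = True
    config_dict = config.to_dict()  # legacy dict form, as used by merge_trainer_configs() elsewhere in this diff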
@@ -37,7 +37,6 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
},
# other changes & things we want to keep fixed:
# larger actor learning rate, no l2 regularisation, no Huber loss, etc.
"learning_starts": 10000,
"actor_hiddens": [400, 300],
"critic_hiddens": [400, 300],
"n_step": 1,

@@ -52,15 +51,16 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"target_network_update_freq": 0,
"num_workers": 0,
"num_gpus_per_worker": 0,
"worker_side_prioritization": False,
"clip_rewards": False,
"use_state_preprocessor": False,
# Size of the replay buffer (in time steps).
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
},
},
)
@@ -23,7 +23,7 @@ class TestApexDDPG(unittest.TestCase):
config = apex_ddpg.APEX_DDPG_DEFAULT_CONFIG.copy()
config["num_workers"] = 2
config["min_sample_timesteps_per_reporting"] = 100
config["learning_starts"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
num_iterations = 1
for _ in framework_iterator(config, with_eager_tracing=True):
@@ -1,6 +1,7 @@
import numpy as np
import re
import unittest
from tempfile import TemporaryDirectory

import ray
import ray.rllib.agents.ddpg as ddpg

@@ -61,6 +62,23 @@ class TestDDPG(unittest.TestCase):
check(a, 500)
trainer.stop()

def test_ddpg_checkpoint_save_and_restore(self):
"""Test whether a DDPGTrainer can save and load checkpoints."""
config = ddpg.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_envs_per_worker"] = 2
config["replay_buffer_config"]["learning_starts"] = 0
config["exploration_config"]["random_timesteps"] = 100

# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
trainer = ddpg.DDPGTrainer(config=config, env="Pendulum-v1")
trainer.train()
with TemporaryDirectory() as temp_dir:
checkpoint = trainer.save(temp_dir)
trainer.restore(checkpoint)
trainer.stop()

def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""

@@ -131,7 +149,6 @@ class TestDDPG(unittest.TestCase):
# Run locally.
config.seed = 42
config.num_workers = 0
config.learning_starts = 0
config.twin_q = True
config.use_huber = True
config.huber_threshold = 1.0

@@ -139,9 +156,9 @@ class TestDDPG(unittest.TestCase):
# Make this small (seems to introduce errors).
config.l2_reg = 1e-10
config.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"learning_starts": 0,
}
# Use very simple nets.
config.actor_hiddens = [10]
|
|||
from ray.rllib import RolloutWorker
|
||||
from ray.rllib.agents import Trainer
|
||||
from ray.rllib.agents.dqn.dqn import (
|
||||
calculate_rr_weights,
|
||||
DEFAULT_CONFIG as DQN_DEFAULT_CONFIG,
|
||||
DQNTrainer,
|
||||
)
|
||||
|
@ -35,16 +34,10 @@ from ray.rllib.execution.common import (
|
|||
_get_global_vars,
|
||||
_get_shared_metrics,
|
||||
)
|
||||
from ray.rllib.execution.concurrency_ops import Concurrently, Dequeue, Enqueue
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.buffers.multi_agent_replay_buffer import ReplayActor
|
||||
from ray.rllib.execution.parallel_requests import (
|
||||
asynchronous_parallel_requests,
|
||||
wait_asynchronous_requests,
|
||||
)
|
||||
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.train_ops import UpdateTargetNetwork
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.actors import create_colocated_actors
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -59,7 +52,6 @@ from ray.rllib.utils.metrics import (
|
|||
SYNCH_WORKER_WEIGHTS_TIMER,
|
||||
TARGET_NET_UPDATE_TIMER,
|
||||
)
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.typing import (
|
||||
SampleBatchType,
|
||||
TrainerConfigDict,
|
||||
|
@ -69,7 +61,7 @@ from ray.rllib.utils.typing import (
|
|||
)
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
from ray.util.iter import LocalIterator
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
|
||||
# fmt: off
|
||||
# __sphinx_doc_begin__
|
||||
|
@@ -90,29 +82,32 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# TODO(jungong) : add proper replay_buffer_config after
# DistributedReplayBuffer type is supported.
"replay_buffer_config": {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"no_local_replay_buffer": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
"replay_batch_size": 32,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
},
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,

"learning_starts": 50000,
"train_batch_size": 512,
"rollout_fragment_length": 50,
# Update the target network every `target_network_update_freq` sample timesteps.

@@ -125,7 +120,6 @@ APEX_DEFAULT_CONFIG = merge_dicts(
# executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"worker_side_prioritization": True,
"min_time_s_per_reporting": 30,
# This will set the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay

@@ -188,26 +182,27 @@ class ApexTrainer(DQNTrainer):
]

num_replay_buffer_shards = self.config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (

# Create copy here so that we can modify without breaking other logic
replay_actor_config = copy.deepcopy(self.config["replay_buffer_config"])

replay_actor_config["capacity"] = (
self.config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
replay_actor_args = [
num_replay_buffer_shards,
self.config["learning_starts"],
buffer_size,
self.config["train_batch_size"],
self.config["replay_buffer_config"]["prioritized_replay_alpha"],
self.config["replay_buffer_config"]["prioritized_replay_beta"],
self.config["replay_buffer_config"]["prioritized_replay_eps"],
self.config["multiagent"]["replay_mode"],
self.config["replay_buffer_config"].get("replay_sequence_length", 1),
]

ReplayActor = ray.remote(num_cpus=0)(replay_actor_config["type"])

# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if self.config["replay_buffer_shards_colocated_with_driver"]:
if replay_actor_config["replay_buffer_shards_colocated_with_driver"]:
self.replay_actors = create_colocated_actors(
actor_specs=[  # (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
(
ReplayActor,
None,
replay_actor_config,
num_replay_buffer_shards,
)
],
node=platform.node(),  # localhost
)[

@@ -216,7 +211,7 @@ class ApexTrainer(DQNTrainer):
# Place replay buffer shards on any node(s).
else:
self.replay_actors = [
ReplayActor.remote(*replay_actor_args)
ReplayActor.remote(*replay_actor_config)
for _ in range(num_replay_buffer_shards)
]
self.learner_thread = LearnerThread(self.workers.local_worker())
@@ -273,151 +268,6 @@ class ApexTrainer(DQNTrainer):

return copy.deepcopy(self.learner_thread.learner_info)

@staticmethod
@override(DQNTrainer)
def execution_plan(
workers: WorkerSet, config: dict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "Apex execution_plan does NOT take any additional parameters"

# Create a number of replay buffer actors.
num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
buffer_size = (
config["replay_buffer_config"]["capacity"] // num_replay_buffer_shards
)
replay_actor_args = [
num_replay_buffer_shards,
config["learning_starts"],
buffer_size,
config["train_batch_size"],
config["replay_buffer_config"]["prioritized_replay_alpha"],
config["replay_buffer_config"]["prioritized_replay_beta"],
config["replay_buffer_config"]["prioritized_replay_eps"],
config["multiagent"]["replay_mode"],
config["replay_buffer_config"].get("replay_sequence_length", 1),
]
# Place all replay buffer shards on the same node as the learner
# (driver process that runs this execution plan).
if config["replay_buffer_shards_colocated_with_driver"]:
replay_actors = create_colocated_actors(
actor_specs=[
# (class, args, kwargs={}, count)
(ReplayActor, replay_actor_args, {}, num_replay_buffer_shards)
],
node=platform.node(),  # localhost
)[
0
]  # [0]=only one item in `actor_specs`.
# Place replay buffer shards on any node(s).
else:
replay_actors = [
ReplayActor(*replay_actor_args) for _ in range(num_replay_buffer_shards)
]

# Start the learner thread.
learner_thread = LearnerThread(workers.local_worker())
learner_thread.start()

# Update experience priorities post learning.
def update_prio_and_stats(item: Tuple[ActorHandle, dict, int, int]) -> None:
actor, prio_dict, env_count, agent_count = item
if config["replay_buffer_config"].get("prioritized_replay_alpha") > 0:
actor.update_priorities.remote(prio_dict)
metrics = _get_shared_metrics()
# Manually update the steps trained counter since the learner
# thread is executing outside the pipeline.
metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
metrics.counters[STEPS_TRAINED_COUNTER] += env_count
metrics.timers["learner_dequeue"] = learner_thread.queue_timer
metrics.timers["learner_grad"] = learner_thread.grad_timer
metrics.timers["learner_overall"] = learner_thread.overall_timer

# We execute the following steps concurrently:
# (1) Generate rollouts and store them in one of our replay buffer
# actors. Update the weights of the worker that generated the batch.
rollouts = ParallelRollouts(workers, mode="async", num_async=2)
store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
# Only need to update workers if there are remote workers.
if workers.remote_workers():
store_op = store_op.zip_with_source_actor().for_each(
UpdateWorkerWeights(
learner_thread,
workers,
max_weight_sync_delay=(
config["optimizer"]["max_weight_sync_delay"]
),
)
)

# (2) Read experiences from one of the replay buffer actors and send
# to the learner thread via its in-queue.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
replay_op = (
Replay(actors=replay_actors, num_async=4)
.for_each(lambda x: post_fn(x, workers, config))
.zip_with_source_actor()
.for_each(Enqueue(learner_thread.inqueue))
)

# (3) Get priorities back from learner thread and apply them to the
# replay buffer actors.
update_op = (
Dequeue(learner_thread.outqueue, check=learner_thread.is_alive)
.for_each(update_prio_and_stats)
.for_each(
UpdateTargetNetwork(
workers, config["target_network_update_freq"], by_steps_trained=True
)
)
)

if config["training_intensity"]:
# Execute (1), (2) with a fixed intensity ratio.
rr_weights = calculate_rr_weights(config) + ["*"]
merged_op = Concurrently(
[store_op, replay_op, update_op],
mode="round_robin",
output_indexes=[2],
round_robin_weights=rr_weights,
)
else:
# Execute (1), (2), (3) asynchronously as fast as possible. Only
# output items from (3) since metrics aren't available before
# then.
merged_op = Concurrently(
[store_op, replay_op, update_op], mode="async", output_indexes=[2]
)

# Add in extra replay and learner metrics to the training result.
def add_apex_metrics(result: dict) -> dict:
replay_stats = ray.get(
replay_actors[0].stats.remote(config["optimizer"].get("debug"))
)
exploration_infos = workers.foreach_policy_to_train(
lambda p, _: p.get_exploration_state()
)
result["info"].update(
{
"exploration_infos": exploration_infos,
"learner_queue": learner_thread.learner_queue_size.stats(),
LEARNER_INFO: copy.deepcopy(learner_thread.learner_info),
"replay_shard_0": replay_stats,
}
)
return result

# Only report metrics from the workers with the lowest 1/3 of
# epsilons.
selected_workers = workers.remote_workers()[
-len(workers.remote_workers()) // 3 :
]

return StandardMetricsReporting(
merged_op, workers, config, selected_workers=selected_workers
).for_each(add_apex_metrics)

def get_samples_and_store_to_replay_buffers(self):
# in the case the num_workers = 0
if not self.workers.remote_workers():

@@ -425,7 +275,7 @@ class ApexTrainer(DQNTrainer):
local_sampling_worker = self.workers.local_worker()
batch = local_sampling_worker.sample()
actor = random.choice(self.replay_actors)
ray.get(actor.add_batch.remote(batch))
ray.get(actor.add.remote(batch))
batch_statistics = {
local_sampling_worker: [
{

@@ -437,7 +287,7 @@ class ApexTrainer(DQNTrainer):
return batch_statistics

def remote_worker_sample_and_store(
worker: RolloutWorker, replay_actors: List[ReplayActor]
worker: RolloutWorker, replay_actors: List[ActorHandle]
):
# This function is run as a remote function on sampling workers,
# and should only be used with the RolloutWorker's apply function ever.

@@ -447,7 +297,7 @@ class ApexTrainer(DQNTrainer):
# operation on there.
_batch = worker.sample()
_actor = random.choice(replay_actors)
_actor.add_batch.remote(_batch)
_actor.add.remote(_batch)
_batch_statistics = {
"agent_steps": _batch.agent_steps(),
"env_steps": _batch.env_steps(),

@@ -550,7 +400,8 @@ class ApexTrainer(DQNTrainer):
actors=[rand_actor],
ray_wait_timeout_s=0.1,
max_remote_requests_in_flight_per_actor=num_requests_to_launch,
remote_fn=lambda actor: actor.replay(),
remote_args=[[self.config["train_batch_size"]]],
remote_fn=lambda actor, num_items: actor.sample(num_items),
)
for replay_actor, sample_batches in replay_samples_ready.items():
for sample_batch in sample_batches:
|
|||
LAST_TARGET_UPDATE_TS,
|
||||
NUM_TARGET_UPDATES,
|
||||
)
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -133,17 +134,26 @@ class DQNConfig(SimpleQConfig):
|
|||
self.n_step = 1
|
||||
self.before_learn_on_batch = None
|
||||
self.training_intensity = None
|
||||
self.worker_side_prioritization = False
|
||||
|
||||
# Changes to SimpleQConfig default
|
||||
self.replay_buffer_config = {
|
||||
"_enable_replay_buffer_api": True,
|
||||
"type": "MultiAgentPrioritizedReplayBuffer",
|
||||
# Specify prioritized replay by supplying a buffer type that supports
|
||||
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
|
||||
"prioritized_replay": DEPRECATED_VALUE,
|
||||
# Size of the replay buffer. Note that if async_updates is set,
|
||||
# then each worker will have a replay buffer of this size.
|
||||
"capacity": 50000,
|
||||
"prioritized_replay_alpha": 0.6,
|
||||
# Beta parameter for sampling from prioritized replay buffer.
|
||||
"prioritized_replay_beta": 0.4,
|
||||
# Epsilon to add to the TD errors when updating priorities.
|
||||
"prioritized_replay_eps": 1e-6,
|
||||
# The number of continuous environment steps to replay at once. This may
|
||||
# be set to greater than 1 to support recurrent models.
|
||||
"replay_sequence_length": 1,
|
||||
# Whether to compute priorities on workers.
|
||||
"worker_side_prioritization": False,
|
||||
}
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
|
|
@@ -442,7 +442,9 @@ def postprocess_nstep_and_prio(
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])

# Prioritize on the worker side.
if batch.count > 0 and policy.config["worker_side_prioritization"]:
if batch.count > 0 and policy.config["replay_buffer_config"].get(
"worker_side_prioritization", False
):
td_errors = policy.compute_td_error(
batch[SampleBatch.OBS],
batch[SampleBatch.ACTIONS],
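Note that the new condition above reads `worker_side_prioritization` from the nested `replay_buffer_config` with a `.get(..., False)` fallback, so configs that never set the key keep the old default behavior. A tiny illustration of that lookup pattern on a plain dict (values made up):

    config = {"replay_buffer_config": {"capacity": 50000}}
    worker_side_prio = config["replay_buffer_config"].get("worker_side_prioritization", False)
    assert worker_side_prio is False  # missing key falls back to the default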
@@ -32,8 +32,10 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(

# === Replay buffer ===
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer (in sequences, not timesteps).
"capacity": 100000,
"storage_unit": "sequences",
@@ -108,11 +108,13 @@ class SimpleQConfig(TrainerConfig):
# __sphinx_doc_begin__
self.target_network_update_freq = 500
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"replay_batch_size": 32,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
}
self.store_buffer_in_checkpoints = False

@@ -151,7 +153,8 @@ class SimpleQConfig(TrainerConfig):
self.prioritized_replay = DEPRECATED_VALUE
self.learning_starts = DEPRECATED_VALUE
self.replay_batch_size = DEPRECATED_VALUE
self.replay_sequence_length = DEPRECATED_VALUE
# Can not use DEPRECATED_VALUE here because -1 is a common config value
self.replay_sequence_length = None
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
@@ -24,8 +24,9 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1

@@ -41,8 +42,9 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["replay_buffer_config"] = {
"learning_starts": 1000,
}
config["min_sample_timesteps_per_reporting"] = 100
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1

@@ -78,14 +80,12 @@ class TestApexDQN(unittest.TestCase):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["num_gpus"] = 0
config["learning_starts"] = 10
config["train_batch_size"] = 10
config["rollout_fragment_length"] = 5
config["replay_buffer_config"] = {
# For now we don't use the new ReplayBuffer API here
"_enable_replay_buffer_api": False,
"no_local_replay_buffer": True,
"type": "MultiAgentReplayBuffer",
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"replay_batch_size": 10,
"prioritized_replay_alpha": 0.6,
@@ -3,6 +3,7 @@ import unittest
import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import (
check_compute_single_action,
check_train_results,

@@ -13,6 +14,28 @@ tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def check_batch_sizes(train_results):
"""Check if batch sizes are according to what we expect from config."""
info = train_results["info"]
learner_info = info[LEARNER_INFO]

for pid, policy_stats in learner_info.items():
if pid == "batch_count":
continue
# Expect td-errors to be per batch-item.
configured_b = train_results["config"]["train_batch_size"]
actual_b = policy_stats["td_error"].shape[0]
if (configured_b - actual_b) / actual_b > 0.1:
assert (
configured_b
/ (
train_results["config"]["model"]["max_seq_len"]
+ train_results["config"]["replay_buffer_config"]["replay_burn_in"]
)
== actual_b
)


class TestR2D2(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:

@@ -47,6 +70,7 @@ class TestR2D2(unittest.TestCase):
for i in range(num_iterations):
results = trainer.train()
check_train_results(results)
check_batch_sizes(results)
print(results)

check_compute_single_action(trainer, include_state=True)
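The `check_batch_sizes()` helper added above tolerates a deviation from `train_batch_size` only when the actual batch matches the sequence arithmetic of an R2D2-style buffer, roughly `train_batch_size / (max_seq_len + replay_burn_in)` sequences per learner batch. A quick numeric illustration with made-up values:

    # Made-up numbers illustrating the tolerance check in check_batch_sizes():
    configured_b = 1000          # train_batch_size from the config
    max_seq_len = 20             # model["max_seq_len"]
    replay_burn_in = 20          # replay_buffer_config["replay_burn_in"]

    actual_b = configured_b / (max_seq_len + replay_burn_in)   # 25 sequences per batch
    deviation = (configured_b - actual_b) / actual_b           # 39.0, well above the 0.1 threshold
    assert deviation > 0.1 and configured_b / (max_seq_len + replay_burn_in) == actual_b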
@@ -19,8 +19,8 @@ from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -69,12 +69,14 @@ DEFAULT_CONFIG = with_common_config({
"adv_policy": "maddpg",

# === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then
# each worker will have a replay buffer of this size.
"buffer_size": DEPRECATED_VALUE,
"replay_buffer_config": {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
},
# Observation compression. Note that compression makes simulation slow in
# MPE.

@@ -102,8 +104,6 @@ DEFAULT_CONFIG = with_common_config({
"actor_feature_reg": 0.001,
# If not None, clip gradients during optimization at this value
"grad_norm_clipping": 0.5,
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 100,
@@ -21,6 +21,7 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.rllib.utils.deprecation import DEPRECATED_VALUE


class QMixConfig(SimpleQConfig):

@@ -78,12 +79,15 @@ class QMixConfig(SimpleQConfig):
self.train_batch_size = 32
self.target_network_update_freq = 500
self.replay_buffer_config = {
# Use the new ReplayBuffer API here
"_enable_replay_buffer_api": True,
"type": "SimpleReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
"learning_starts": 1000,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
self.model = {
"lstm_cell_size": 64,
@@ -58,7 +58,8 @@ class RNNSACTrainer(SACTrainer):
)
# Check if user tries to set replay_sequence_length (to anything
# other than the proper value)
if config["replay_buffer_config"]["replay_sequence_length"] not in [
if config["replay_buffer_config"].get("replay_sequence_length", None) not in [
None,
-1,
replay_sequence_length,
]:
@@ -84,22 +84,23 @@ DEFAULT_CONFIG = with_common_config({

# === Replay buffer ===
"replay_buffer_config": {
# Enable the new ReplayBuffer API.
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
},
# Set this to True, if you want the contents of your buffer(s) to be
# stored in any saved checkpoints as well.

@@ -155,8 +156,6 @@ DEFAULT_CONFIG = with_common_config({
"num_gpus_per_worker": 0,
# Whether to allocate CPUs for workers (if > 0).
"num_cpus_per_worker": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

@@ -167,18 +166,6 @@ DEFAULT_CONFIG = with_common_config({
# Use a Beta-distribution instead of a SquashedGaussian for bounded,
# continuous action spaces (not recommended, for debugging only).
"_use_beta_distribution": False,

# Deprecated.
# The following values have moved because of the new ReplayBuffer API.
"prioritized_replay": DEPRECATED_VALUE,
"prioritized_replay_alpha": DEPRECATED_VALUE,
"prioritized_replay_beta": DEPRECATED_VALUE,
"prioritized_replay_eps": DEPRECATED_VALUE,
"learning_starts": DEPRECATED_VALUE,
"buffer_size": DEPRECATED_VALUE,
"replay_batch_size": DEPRECATED_VALUE,
"replay_sequence_length": DEPRECATED_VALUE,

})
# __sphinx_doc_end__
# fmt: on
@@ -42,11 +42,12 @@ class TestRNNSAC(unittest.TestCase):
"lstm_use_prev_reward": True,
}

# Test with PR activated.
config["prioritized_replay"] = True

config["burn_in"] = 20
config["zero_init_states"] = True
# Test with MultiAgentPrioritizedReplayBuffer
config["replay_buffer_config"] = {
"type": "MultiAgentPrioritizedReplayBuffer",
"replay_burn_in": 20,
"zero_init_states": True,
}

config["lr"] = 5e-4
@@ -523,7 +523,7 @@ class TestSAC(unittest.TestCase):

config = sac.DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.
config["learning_starts"] = 0
config["replay_buffer_config"]["learning_starts"] = 0
config["rollout_fragment_length"] = 5
config["train_batch_size"] = 5
config["replay_buffer_config"]["capacity"] = 10
@@ -77,10 +77,7 @@ class SlateQConfig(TrainerConfig):
self.rmsprop_epsilon = 1e-5
self.grad_clip = None
self.n_step = 1
self.worker_side_prioritization = False
self.replay_buffer_config = {
# Enable the new ReplayBuffer API.
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 100000,
"prioritized_replay_alpha": 0.6,

@@ -91,6 +88,8 @@ class SlateQConfig(TrainerConfig):
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# How many steps of the model to sample before learning starts.
"learning_starts": 20000,
}
@@ -40,9 +40,6 @@ from ray.rllib.evaluation.metrics import (
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import (

@@ -2127,7 +2124,7 @@ class Trainer(Trainable):
@DeveloperAPI
def _create_local_replay_buffer_if_necessary(
self, config: PartialTrainerConfigDict
) -> Optional[Union[MultiAgentReplayBuffer, Legacy_MultiAgentReplayBuffer]]:
) -> Optional[MultiAgentReplayBuffer]:
"""Create a MultiAgentReplayBuffer instance if necessary.

Args:

@@ -2138,7 +2135,7 @@ class Trainer(Trainable):
None, if local replay buffer is not needed.
"""
if not config.get("replay_buffer_config") or config["replay_buffer_config"].get(
"no_local_replay_buffer" or config.get("no_local_replay_buffer"), False
"no_local_replay_buffer" or config.get("no_local_replay_buffer")
):
return
@@ -226,6 +226,16 @@ class TrainerConfig:
self.timesteps_per_iteration = DEPRECATED_VALUE
self.min_iter_time_s = DEPRECATED_VALUE
self.collect_metrics_timeout = DEPRECATED_VALUE
# The following values have moved because of the new ReplayBuffer API
self.buffer_size = DEPRECATED_VALUE
self.prioritized_replay = DEPRECATED_VALUE
self.learning_starts = DEPRECATED_VALUE
self.replay_batch_size = DEPRECATED_VALUE
# -1 = DEPRECATED_VALUE is a valid value for replay_sequence_length
self.replay_sequence_length = None
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE

def to_dict(self) -> TrainerConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.
@@ -51,21 +51,12 @@ CQL_DEFAULT_CONFIG = merge_dicts(
"lagrangian_thresh": 5.0,
# Min Q weight multiplier.
"min_q_weight": 5.0,
"replay_buffer_config": {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
# Replay buffer should be larger or equal the size of the offline
# dataset.
"capacity": int(1e6),
},
# Reporting: As CQL is offline (no sampling steps), we need to limit
# `self.train()` reporting by the number of steps trained (not sampled).
"min_sample_timesteps_per_reporting": 0,
"min_train_timesteps_per_reporting": 100,

# Deprecated keys.
# Use `replay_buffer_config.capacity` instead.
"buffer_size": DEPRECATED_VALUE,
# Use `min_sample_timesteps_per_reporting` and
# `min_train_timesteps_per_reporting` instead.
"timesteps_per_iteration": DEPRECATED_VALUE,
@@ -2,7 +2,7 @@ from typing import Type

from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.algorithms.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
synchronous_parallel_sample,
)

@@ -18,10 +18,10 @@ from ray.rllib.utils.metrics import (
WORKER_UPDATE_TIMER,
)
from ray.rllib.utils.typing import (
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
from ray.rllib.utils.deprecation import DEPRECATED_VALUE

# fmt: off
# __sphinx_doc_begin__

@@ -66,14 +66,22 @@ DEFAULT_CONFIG = with_common_config({
# Number of (independent) timesteps pushed through the loss
# each SGD round.
"train_batch_size": 2000,
# Size of the replay buffer in (single and independent) timesteps.
# The buffer gets filled by reading from the input files line-by-line
# and adding all timesteps on one line at once. We then sample
# uniformly from the buffer (`train_batch_size` samples) for
# each training step.
"replay_buffer_size": 10000,
# Number of steps to read before learning starts.
"learning_starts": 0,

"replay_buffer_config": {
"type": "MultiAgentPrioritizedReplayBuffer",
# Size of the replay buffer in (single and independent) timesteps.
# The buffer gets filled by reading from the input files line-by-line
# and adding all timesteps on one line at once. We then sample
# uniformly from the buffer (`train_batch_size` samples) for
# each training step.
"capacity": 10000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
# Number of steps to read before learning starts.
"learning_starts": 0,
"replay_sequence_length": 1
},

# A coeff to encourage higher action distribution entropy for exploration.
"bc_logstd_coeff": 0.0,

@@ -96,6 +104,8 @@ class MARWILTrainer(Trainer):
# Call super's validation method.
super().validate_config(config)

validate_buffer_config(config)

if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

@@ -116,19 +126,6 @@ class MARWILTrainer(Trainer):
else:
return MARWILTFPolicy

@override(Trainer)
def setup(self, config: PartialTrainerConfigDict):
super().setup(config)
# `training_iteration` implementation: Setup buffer in `setup`, not
# in `execution_plan` (deprecated).
if self.config["_disable_execution_plan_api"] is True:
self.local_replay_buffer = MultiAgentReplayBuffer(
learning_starts=self.config["learning_starts"],
capacity=self.config["replay_buffer_size"],
replay_batch_size=self.config["train_batch_size"],
replay_sequence_length=1,
)

@override(Trainer)
def training_iteration(self) -> ResultDict:
# Collect SampleBatches from sample workers.

@@ -137,10 +134,10 @@ class MARWILTrainer(Trainer):
self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
# Add batch to replay buffer.
self.local_replay_buffer.add_batch(batch)
self.local_replay_buffer.add(batch)

# Pull batch from replay buffer and train on it.
train_batch = self.local_replay_buffer.replay()
train_batch = self.local_replay_buffer.sample(self.config["train_batch_size"])
# Train.
if self.config["simple_optimizer"]:
train_results = train_one_step(self, train_batch)
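The MARWIL hunks above show the `training_iteration()` pattern this PR applies throughout: add freshly sampled batches with `add()` and pull training data with `sample(num_items)` in place of the old `add_batch()` / `replay()` pair. A hedged, framework-free sketch of that loop; `sample_env` and `train_one_step` below are placeholders, not RLlib functions:

    def training_iteration(buffer, sample_env, train_one_step, train_batch_size):
        batch = sample_env()                           # collect new experience
        buffer.add(batch)                              # new API: add() (was add_batch())
        train_batch = buffer.sample(train_batch_size)  # new API: sample() (was replay())
        if train_batch is None:
            return {}                                  # not enough data yet (learning_starts)
        return train_one_step(train_batch)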
@ -116,7 +116,7 @@ if __name__ == "__main__":
|
|||
assert r["model"]["foo"] == 42, result
|
||||
|
||||
if args.run == "DQN":
|
||||
extra_config = {"learning_starts": 0}
|
||||
extra_config = {"replay_buffer_config": {"learning_starts": 0}}
|
||||
else:
|
||||
extra_config = {}
|
||||
|
||||
|
|
|
@ -22,15 +22,17 @@ if __name__ == "__main__":
|
|||
"num_gpus": 1,
|
||||
"num_workers": 2,
|
||||
"num_envs_per_worker": 8,
|
||||
"learning_starts": 1000,
|
||||
"buffer_size": int(1e5),
|
||||
"replay_buffer_config": {
|
||||
"learning_starts": 1000,
|
||||
"capacity": int(1e5),
|
||||
"prioritized_replay_alpha": 0.5,
|
||||
},
|
||||
"compress_observations": True,
|
||||
"rollout_fragment_length": 20,
|
||||
"train_batch_size": 512,
|
||||
"gamma": 0.99,
|
||||
"n_step": 3,
|
||||
"lr": 0.0001,
|
||||
"prioritized_replay_alpha": 0.5,
|
||||
"target_network_update_freq": 50000,
|
||||
"min_sample_timesteps_per_reporting": 25000,
|
||||
# Method specific.
|
||||
|
|
|
@ -38,9 +38,11 @@ if __name__ == "__main__":
|
|||
config["bc_iters"] = 0
|
||||
config["clip_actions"] = False
|
||||
config["normalize_actions"] = True
|
||||
config["learning_starts"] = 256
|
||||
config["replay_buffer_config"]["learning_starts"] = 256
|
||||
config["rollout_fragment_length"] = 1
|
||||
config["prioritized_replay"] = False
|
||||
# Test without prioritized replay
|
||||
config["replay_buffer_config"]["type"] = "MultiAgentReplayBuffer"
|
||||
config["replay_buffer_config"]["capacity"] = int(1e6)
|
||||
config["tau"] = 0.005
|
||||
config["target_entropy"] = "auto"
|
||||
config["Q_model"] = {
|
||||
|
|
|
@@ -126,7 +126,9 @@ def main():
"num_gpus": args.num_gpus,
"num_workers": args.num_workers,
"env_config": env_config,
"learning_starts": args.learning_starts,
"replay_buffer_config": {
"learning_starts": args.learning_starts,
},
}

# Perform a test run on the env with a random agent to see, what
@ -180,7 +180,7 @@ if __name__ == "__main__":
|
|||
# Example of using DQN (supports off-policy actions).
|
||||
config.update(
|
||||
{
|
||||
"learning_starts": 100,
|
||||
"replay_buffer_config": {"learning_starts": 100},
|
||||
"min_sample_timesteps_per_reporting": 200,
|
||||
"n_step": 3,
|
||||
"rollout_fragment_length": 4,
|
||||
|
|
|
@ -115,7 +115,7 @@ if __name__ == "__main__":
|
|||
"env_config": {
|
||||
"actions_are_logits": True,
|
||||
},
|
||||
"learning_starts": 100,
|
||||
"replay_buffer_config": {"learning_starts": 100},
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol1": PolicySpec(
|
||||
|
|
|
@ -21,7 +21,9 @@ from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
|
|||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
|
||||
from ray.rllib.execution.train_ops import train_one_step
|
||||
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
|
||||
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
|
||||
MultiAgentReplayBuffer,
|
||||
)
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
@ -82,7 +84,7 @@ class MyTrainer(Trainer):
|
|||
super().setup(config)
|
||||
# Create local replay buffer.
|
||||
self.local_replay_buffer = MultiAgentReplayBuffer(
|
||||
num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
|
||||
num_shards=1, learning_starts=1000, capacity=50000
|
||||
)
|
||||
|
||||
@override(Trainer)
|
||||
|
@@ -103,14 +105,14 @@ class MyTrainer(Trainer):
            self._counters[NUM_AGENT_STEPS_SAMPLED] += ma_batch.agent_steps()
            ppo_batch = ma_batch.policy_batches.pop("ppo_policy")
            # Add collected batches (only for DQN policy) to replay buffer.
            self.local_replay_buffer.add_batch(ma_batch)
            self.local_replay_buffer.add(ma_batch)

            ppo_batches.append(ppo_batch)
            num_env_steps += ppo_batch.count

        # DQN sub-flow.
        dqn_train_results = {}
        dqn_train_batch = self.local_replay_buffer.replay()
        dqn_train_batch = self.local_replay_buffer.sample(num_items=64)
        if dqn_train_batch is not None:
            dqn_train_results = train_one_step(self, dqn_train_batch, ["dqn_policy"])
            self._counters["agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
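The hunk above swaps the legacy `add_batch()` / `replay()` calls for the new `add()` / `sample(num_items=...)` API. A small sketch of the new call pattern, assuming the new `MultiAgentReplayBuffer` under `ray.rllib.utils.replay_buffers` accepts the constructor arguments shown in the hunk; the batch contents are placeholders:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

# learning_starts=0 so sample() returns data right away (assumption based on
# the docstrings of the old buffer further below in this diff).
buffer = MultiAgentReplayBuffer(capacity=50000, learning_starts=0)
batch = SampleBatch({"obs": [0.0], "actions": [0], "rewards": [1.0]})
# New API: add() replaces add_batch().
buffer.add(batch)
# New API: sample(num_items=...) replaces replay().
train_batch = buffer.sample(num_items=64)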
@@ -7,11 +7,7 @@ from ray.rllib.execution.metric_ops import (
    OncePerTimestepsElapsed,
)
from ray.rllib.execution.multi_gpu_learner_thread import MultiGPULearnerThread
from ray.rllib.execution.buffers.replay_buffer import (
    ReplayBuffer,
    PrioritizedReplayBuffer,
)
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.replay_ops import (
    StoreToReplayBuffer,
    Replay,
@@ -50,14 +46,11 @@ __all__ = [
    "Enqueue",
    "LearnerThread",
    "MixInReplay",
    "MultiAgentReplayBuffer",
    "MultiGPULearnerThread",
    "OncePerTimeInterval",
    "OncePerTimestepsElapsed",
    "ParallelRollouts",
    "PrioritizedReplayBuffer",
    "Replay",
    "ReplayBuffer",
    "SelectExperiences",
    "SimpleReplayBuffer",
    "StandardMetricsReporting",
@@ -66,4 +59,5 @@ __all__ = [
    "TrainOneStep",
    "MultiGPUTrainOneStep",
    "UpdateTargetNetwork",
    "MinibatchBuffer",
]
@@ -1,5 +1,9 @@
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.deprecation import deprecation_warning

__all__ = [
    "MinibatchBuffer",
]
deprecation_warning(
    old="ray.rllib.execution.buffers",
    new="ray.rllib.utils.replay_buffers",
    help="RLlib's ReplayBuffer API has changed. Apart from the replay buffers moving, "
    "some have altered behaviour. Please refer to the docs for more information.",
    error=False,
)
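The shim above only re-exports `MinibatchBuffer` and emits a deprecation warning; the buffers themselves now live under `ray.rllib.utils.replay_buffers`. A hedged sketch of the path change (the old import is kept as a comment for comparison only):

# Deprecated location; importing from here now triggers the warning above.
# from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer

# New location introduced by this commit.
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)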
@ -1,303 +0,0 @@
|
|||
import collections
|
||||
import platform
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from ray.rllib import SampleBatch
|
||||
from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
|
||||
from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
|
||||
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils import deprecation_warning
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.typing import PolicyID, SampleBatchType, T
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
|
||||
|
||||
class MultiAgentReplayBuffer(ParallelIteratorWorker):
|
||||
"""A replay buffer shard storing data for all policies (in multiagent setup).
|
||||
|
||||
Ray actors are single-threaded, so for scalability, multiple replay actors
|
||||
may be created to increase parallelism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_shards: int = 1,
|
||||
learning_starts: int = 1000,
|
||||
capacity: int = 10000,
|
||||
replay_batch_size: int = 1,
|
||||
prioritized_replay_alpha: float = 0.6,
|
||||
prioritized_replay_beta: float = 0.4,
|
||||
prioritized_replay_eps: float = 1e-6,
|
||||
replay_mode: str = "independent",
|
||||
replay_sequence_length: int = 1,
|
||||
replay_burn_in: int = 0,
|
||||
replay_zero_init_states: bool = True,
|
||||
buffer_size=DEPRECATED_VALUE,
|
||||
):
|
||||
"""Initializes a MultiAgentReplayBuffer instance.
|
||||
|
||||
Args:
|
||||
num_shards: The number of buffer shards that exist in total
|
||||
(including this one).
|
||||
learning_starts: Number of timesteps after which a call to
|
||||
`replay()` will yield samples (before that, `replay()` will
|
||||
return None).
|
||||
capacity: The capacity of the buffer. Note that when
|
||||
`replay_sequence_length` > 1, this is the number of sequences
|
||||
(not single timesteps) stored.
|
||||
replay_batch_size: The batch size to be sampled (in timesteps).
|
||||
Note that if `replay_sequence_length` > 1,
|
||||
`self.replay_batch_size` will be set to the number of
|
||||
sequences sampled (B).
|
||||
prioritized_replay_alpha: Alpha parameter for a prioritized
|
||||
replay buffer. Use 0.0 for no prioritization.
|
||||
prioritized_replay_beta: Beta parameter for a prioritized
|
||||
replay buffer.
|
||||
prioritized_replay_eps: Epsilon parameter for a prioritized
|
||||
replay buffer.
|
||||
            replay_mode: One of "independent" or "lockstep". Determines
                whether, in the multiagent case, sampling is done across
                all agents/policies equally.
|
||||
replay_sequence_length: The sequence length (T) of a single
|
||||
sample. If > 1, we will sample B x T from this buffer.
|
||||
replay_burn_in: The burn-in length in case
|
||||
`replay_sequence_length` > 0. This is the number of timesteps
|
||||
each sequence overlaps with the previous one to generate a
|
||||
better internal state (=state after the burn-in), instead of
|
||||
starting from 0.0 each RNN rollout.
|
||||
            replay_zero_init_states: Whether the initial states in the
                buffer (if replay_sequence_length > 0) are always 0.0 or
                should be updated with the previous train_batch state outputs.
|
||||
"""
|
||||
# Deprecated args.
|
||||
if buffer_size != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
"ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False
|
||||
)
|
||||
capacity = buffer_size
|
||||
|
||||
self.replay_starts = learning_starts // num_shards
|
||||
self.capacity = capacity // num_shards
|
||||
self.replay_batch_size = replay_batch_size
|
||||
self.prioritized_replay_beta = prioritized_replay_beta
|
||||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.replay_mode = replay_mode
|
||||
self.replay_sequence_length = replay_sequence_length
|
||||
self.replay_burn_in = replay_burn_in
|
||||
self.replay_zero_init_states = replay_zero_init_states
|
||||
|
||||
if replay_sequence_length > 1:
|
||||
self.replay_batch_size = int(
|
||||
max(1, replay_batch_size // replay_sequence_length)
|
||||
)
|
||||
logger.info(
|
||||
"Since replay_sequence_length={} and replay_batch_size={}, "
|
||||
"we will replay {} sequences at a time.".format(
|
||||
replay_sequence_length, replay_batch_size, self.replay_batch_size
|
||||
)
|
||||
)
|
||||
|
||||
if replay_mode not in ["lockstep", "independent"]:
|
||||
raise ValueError("Unsupported replay mode: {}".format(replay_mode))
|
||||
|
||||
def gen_replay():
|
||||
while True:
|
||||
yield self.replay()
|
||||
|
||||
ParallelIteratorWorker.__init__(self, gen_replay, False)
|
||||
|
||||
def new_buffer():
|
||||
if prioritized_replay_alpha == 0.0:
|
||||
return ReplayBuffer(self.capacity)
|
||||
else:
|
||||
return PrioritizedReplayBuffer(
|
||||
self.capacity, alpha=prioritized_replay_alpha
|
||||
)
|
||||
|
||||
self.replay_buffers = collections.defaultdict(new_buffer)
|
||||
|
||||
# Metrics.
|
||||
self.add_batch_timer = TimerStat()
|
||||
self.replay_timer = TimerStat()
|
||||
self.update_priorities_timer = TimerStat()
|
||||
self.num_added = 0
|
||||
|
||||
# Make externally accessible for testing.
|
||||
global _local_replay_buffer
|
||||
_local_replay_buffer = self
|
||||
# If set, return this instead of the usual data for testing.
|
||||
self._fake_batch = None
|
||||
|
||||
@staticmethod
|
||||
def get_instance_for_testing():
|
||||
"""Return a MultiAgentReplayBuffer instance that has been previously
|
||||
instantiated.
|
||||
|
||||
Returns:
|
||||
_local_replay_buffer: The lastly instantiated
|
||||
MultiAgentReplayBuffer.
|
||||
|
||||
"""
|
||||
global _local_replay_buffer
|
||||
return _local_replay_buffer
|
||||
|
||||
def get_host(self) -> str:
|
||||
"""Returns the computer's network name.
|
||||
|
||||
Returns:
|
||||
The computer's networks name or an empty string, if the network
|
||||
name could not be determined.
|
||||
"""
|
||||
return platform.node()
|
||||
|
||||
def add_batch(self, batch: SampleBatchType) -> None:
|
||||
"""Adds a batch to the appropriate policy's replay buffer.
|
||||
|
||||
Turns the batch into a MultiAgentBatch of the DEFAULT_POLICY_ID if
|
||||
it is not a MultiAgentBatch.
|
||||
|
||||
Args:
|
||||
batch (SampleBatchType): The batch to be added.
|
||||
"""
|
||||
# Make a copy so the replay buffer doesn't pin plasma memory.
|
||||
batch = batch.copy()
|
||||
# Handle everything as if multi-agent.
|
||||
batch = batch.as_multi_agent()
|
||||
|
||||
with self.add_batch_timer:
|
||||
# Lockstep mode: Store under _ALL_POLICIES key (we will always
|
||||
# only sample from all policies at the same time).
|
||||
if self.replay_mode == "lockstep":
|
||||
# Note that prioritization is not supported in this mode.
|
||||
for s in batch.timeslices(self.replay_sequence_length):
|
||||
self.replay_buffers[_ALL_POLICIES].add(s, weight=None)
|
||||
else:
|
||||
for policy_id, sample_batch in batch.policy_batches.items():
|
||||
if self.replay_sequence_length == 1:
|
||||
timeslices = sample_batch.timeslices(1)
|
||||
else:
|
||||
timeslices = timeslice_along_seq_lens_with_overlap(
|
||||
sample_batch=sample_batch,
|
||||
zero_pad_max_seq_len=self.replay_sequence_length,
|
||||
pre_overlap=self.replay_burn_in,
|
||||
zero_init_states=self.replay_zero_init_states,
|
||||
)
|
||||
for time_slice in timeslices:
|
||||
# If SampleBatch has prio-replay weights, average
|
||||
# over these to use as a weight for the entire
|
||||
# sequence.
|
||||
if "weights" in time_slice and len(time_slice["weights"]):
|
||||
weight = np.mean(time_slice["weights"])
|
||||
else:
|
||||
weight = None
|
||||
self.replay_buffers[policy_id].add(time_slice, weight=weight)
|
||||
self.num_added += batch.count
|
||||
|
||||
# TODO: This entire class will be removed soon. Leave this as a shim in case
|
||||
# new `training_iteration` methods call the new replay buffer API's `sample()`
|
||||
# method on this old buffer class here.
|
||||
def sample(self, num_items=None):
|
||||
return self.replay()
|
||||
|
||||
def replay(self, policy_id: Optional[PolicyID] = None) -> SampleBatchType:
|
||||
"""If this buffer was given a fake batch, return it, otherwise return
|
||||
a MultiAgentBatch with samples.
|
||||
"""
|
||||
if self._fake_batch:
|
||||
if not isinstance(self._fake_batch, MultiAgentBatch):
|
||||
self._fake_batch = SampleBatch(self._fake_batch).as_multi_agent()
|
||||
return self._fake_batch
|
||||
|
||||
if self.num_added < self.replay_starts:
|
||||
return None
|
||||
with self.replay_timer:
|
||||
            # Lockstep mode: Sample an equal number of steps from all
            # policies at the same time.
|
||||
if self.replay_mode == "lockstep":
|
||||
                assert (
                    policy_id is None
                ), "`policy_id` specifier not allowed in `lockstep` mode!"
|
||||
return self.replay_buffers[_ALL_POLICIES].sample(
|
||||
self.replay_batch_size, beta=self.prioritized_replay_beta
|
||||
)
|
||||
elif policy_id is not None:
|
||||
return self.replay_buffers[policy_id].sample(
|
||||
self.replay_batch_size, beta=self.prioritized_replay_beta
|
||||
)
|
||||
else:
|
||||
samples = {}
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
samples[policy_id] = replay_buffer.sample(
|
||||
self.replay_batch_size, beta=self.prioritized_replay_beta
|
||||
)
|
||||
return MultiAgentBatch(samples, self.replay_batch_size)
|
||||
|
||||
def update_priorities(self, prio_dict: Dict) -> None:
|
||||
"""Updates the priorities of underlying replay buffers.
|
||||
|
||||
Computes new priorities from td_errors and prioritized_replay_eps.
|
||||
These priorities are used to update underlying replay buffers per
|
||||
policy_id.
|
||||
|
||||
Args:
|
||||
prio_dict (Dict): A dictionary containing td_errors for
|
||||
batches saved in underlying replay buffers.
|
||||
"""
|
||||
with self.update_priorities_timer:
|
||||
for policy_id, (batch_indexes, td_errors) in prio_dict.items():
|
||||
new_priorities = np.abs(td_errors) + self.prioritized_replay_eps
|
||||
self.replay_buffers[policy_id].update_priorities(
|
||||
batch_indexes, new_priorities
|
||||
)
|
||||
|
||||
def stats(self, debug: bool = False) -> Dict:
|
||||
"""Returns the stats of this buffer and all underlying buffers.
|
||||
|
||||
Args:
|
||||
debug (bool): If True, stats of underlying replay buffers will
|
||||
be fetched with debug=True.
|
||||
|
||||
Returns:
|
||||
stat: Dictionary of buffer stats.
|
||||
"""
|
||||
stat = {
|
||||
"add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
|
||||
"replay_time_ms": round(1000 * self.replay_timer.mean, 3),
|
||||
"update_priorities_time_ms": round(
|
||||
1000 * self.update_priorities_timer.mean, 3
|
||||
),
|
||||
}
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
stat.update(
|
||||
{"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
|
||||
)
|
||||
return stat
|
||||
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
state = {"num_added": self.num_added, "replay_buffers": {}}
|
||||
for policy_id, replay_buffer in self.replay_buffers.items():
|
||||
state["replay_buffers"][policy_id] = replay_buffer.get_state()
|
||||
return state
|
||||
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
self.num_added = state["num_added"]
|
||||
buffer_states = state["replay_buffers"]
|
||||
for policy_id in buffer_states.keys():
|
||||
self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
|
||||
|
||||
@DeveloperAPI
|
||||
def apply(
|
||||
self,
|
||||
func: Callable[["MultiAgentReplayBuffer", Optional[Any], Optional[Any]], T],
|
||||
*_args,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Calls the given function with this MultiAgentReplayBuffer instance."""
|
||||
return func(self, *_args, **kwargs)
|
||||
|
||||
|
||||
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
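The deleted module above ends by wrapping the buffer class as a zero-CPU Ray actor. A minimal sketch of how such a replay shard is driven remotely, following the pattern in the `test_store_to_replay_actor` hunk further below (the argument values are illustrative):

import ray
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

ray.init()
# Wrap the buffer class as a zero-CPU actor, mirroring the line above.
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
actor = ReplayActor.remote(num_shards=1, learning_starts=200, capacity=1000)
# Remote calls mirror the local API: add() to store, sample() to draw.
# actor.add.remote(some_sample_batch)
# train_batch = ray.get(actor.sample.remote(100))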
|
|
@ -1,419 +0,0 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Import ray before psutil will make sure we use psutil's bundled version
|
||||
import ray # noqa F401
|
||||
import psutil # noqa E402
|
||||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, override
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.deprecation import (
|
||||
Deprecated,
|
||||
DEPRECATED_VALUE,
|
||||
deprecation_warning,
|
||||
)
|
||||
from ray.rllib.utils.metrics.window_stat import WindowStat
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
# Constant that represents all policies in lockstep replay mode.
|
||||
_ALL_POLICIES = "__all__"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
|
||||
"""Warn if the configured replay buffer capacity is too large."""
|
||||
if log_once("replay_capacity"):
|
||||
item_size = item.size_bytes()
|
||||
psutil_mem = psutil.virtual_memory()
|
||||
total_gb = psutil_mem.total / 1e9
|
||||
mem_size = num_items * item_size / 1e9
|
||||
msg = (
|
||||
"Estimated max memory usage for replay buffer is {} GB "
|
||||
"({} batches of size {}, {} bytes each), "
|
||||
"available system memory is {} GB".format(
|
||||
mem_size, num_items, item.count, item_size, total_gb
|
||||
)
|
||||
)
|
||||
if mem_size > total_gb:
|
||||
raise ValueError(msg)
|
||||
elif mem_size > 0.2 * total_gb:
|
||||
logger.warning(msg)
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
@Deprecated(new="warn_replay_capacity", error=False)
|
||||
def warn_replay_buffer_size(*, item: SampleBatchType, num_items: int) -> None:
|
||||
return warn_replay_capacity(item=item, num_items=num_items)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ReplayBuffer:
|
||||
@DeveloperAPI
|
||||
def __init__(self, capacity: int = 10000, size: Optional[int] = DEPRECATED_VALUE):
|
||||
"""Initializes a ReplayBuffer instance.
|
||||
|
||||
Args:
|
||||
capacity: Max number of timesteps to store in the FIFO
|
||||
buffer. After reaching this number, older samples will be
|
||||
dropped to make space for new ones.
|
||||
"""
|
||||
# Deprecated args.
|
||||
if size != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
"ReplayBuffer(size)", "ReplayBuffer(capacity)", error=False
|
||||
)
|
||||
capacity = size
|
||||
|
||||
# The actual storage (list of SampleBatches).
|
||||
self._storage = []
|
||||
|
||||
self.capacity = capacity
|
||||
        # The next index to overwrite in the buffer.
|
||||
self._next_idx = 0
|
||||
self._hit_count = np.zeros(self.capacity)
|
||||
|
||||
# Whether we have already hit our capacity (and have therefore
|
||||
# started to evict older samples).
|
||||
self._eviction_started = False
|
||||
|
||||
# Number of (single) timesteps that have been added to the buffer
|
||||
# over its lifetime. Note that each added item (batch) may contain
|
||||
# more than one timestep.
|
||||
self._num_timesteps_added = 0
|
||||
self._num_timesteps_added_wrap = 0
|
||||
|
||||
# Number of (single) timesteps that have been sampled from the buffer
|
||||
# over its lifetime.
|
||||
self._num_timesteps_sampled = 0
|
||||
|
||||
self._evicted_hit_stats = WindowStat("evicted_hit", 1000)
|
||||
self._est_size_bytes = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of items currently stored in this buffer."""
|
||||
return len(self._storage)
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, item: SampleBatchType, weight: float) -> None:
|
||||
"""Add a batch of experiences.
|
||||
|
||||
Args:
|
||||
item: SampleBatch to add to this buffer's storage.
|
||||
weight: The weight of the added sample used in subsequent
|
||||
sampling steps. Only relevant if this ReplayBuffer is
|
||||
a PrioritizedReplayBuffer.
|
||||
"""
|
||||
assert item.count > 0, item
|
||||
warn_replay_capacity(item=item, num_items=self.capacity / item.count)
|
||||
|
||||
# Update our timesteps counts.
|
||||
self._num_timesteps_added += item.count
|
||||
self._num_timesteps_added_wrap += item.count
|
||||
|
||||
if self._next_idx >= len(self._storage):
|
||||
self._storage.append(item)
|
||||
self._est_size_bytes += item.size_bytes()
|
||||
else:
|
||||
self._storage[self._next_idx] = item
|
||||
|
||||
# Wrap around storage as a circular buffer once we hit capacity.
|
||||
if self._num_timesteps_added_wrap >= self.capacity:
|
||||
self._eviction_started = True
|
||||
self._num_timesteps_added_wrap = 0
|
||||
self._next_idx = 0
|
||||
else:
|
||||
self._next_idx += 1
|
||||
|
||||
# Eviction of older samples has already started (buffer is "full").
|
||||
if self._eviction_started:
|
||||
self._evicted_hit_stats.push(self._hit_count[self._next_idx])
|
||||
self._hit_count[self._next_idx] = 0
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self, num_items: int, beta: float = 0.0) -> SampleBatchType:
|
||||
"""Sample a batch of size `num_items` from this buffer.
|
||||
|
||||
If less than `num_items` records are in this buffer, some samples in
|
||||
the results may be repeated to fulfil the batch size (`num_items`)
|
||||
request.
|
||||
|
||||
Args:
|
||||
num_items: Number of items to sample from this buffer.
|
||||
beta: The prioritized replay beta value. Only relevant if this
|
||||
ReplayBuffer is a PrioritizedReplayBuffer.
|
||||
|
||||
Returns:
|
||||
Concatenated batch of items.
|
||||
"""
|
||||
# If we don't have any samples yet in this buffer, return None.
|
||||
if len(self) == 0:
|
||||
return None
|
||||
|
||||
idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
|
||||
sample = self._encode_sample(idxes)
|
||||
# Update our timesteps counters.
|
||||
self._num_timesteps_sampled += len(sample)
|
||||
return sample
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self, debug: bool = False) -> dict:
|
||||
"""Returns the stats of this buffer.
|
||||
|
||||
Args:
|
||||
debug: If True, adds sample eviction statistics to the returned
|
||||
stats dict.
|
||||
|
||||
Returns:
|
||||
A dictionary of stats about this buffer.
|
||||
"""
|
||||
data = {
|
||||
"added_count": self._num_timesteps_added,
|
||||
"added_count_wrapped": self._num_timesteps_added_wrap,
|
||||
"eviction_started": self._eviction_started,
|
||||
"sampled_count": self._num_timesteps_sampled,
|
||||
"est_size_bytes": self._est_size_bytes,
|
||||
"num_entries": len(self._storage),
|
||||
}
|
||||
if debug:
|
||||
data.update(self._evicted_hit_stats.stats())
|
||||
return data
|
||||
|
||||
@DeveloperAPI
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
||||
Returns:
|
||||
The serializable local state.
|
||||
"""
|
||||
state = {"_storage": self._storage, "_next_idx": self._next_idx}
|
||||
state.update(self.stats(debug=False))
|
||||
return state
|
||||
|
||||
@DeveloperAPI
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
||||
Args:
|
||||
state: The new state to set this buffer. Can be
|
||||
obtained by calling `self.get_state()`.
|
||||
"""
|
||||
# The actual storage.
|
||||
self._storage = state["_storage"]
|
||||
self._next_idx = state["_next_idx"]
|
||||
# Stats and counts.
|
||||
self._num_timesteps_added = state["added_count"]
|
||||
self._num_timesteps_added_wrap = state["added_count_wrapped"]
|
||||
self._eviction_started = state["eviction_started"]
|
||||
self._num_timesteps_sampled = state["sampled_count"]
|
||||
self._est_size_bytes = state["est_size_bytes"]
|
||||
|
||||
def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
|
||||
out = SampleBatch.concat_samples([self._storage[i] for i in idxes])
|
||||
out.decompress_if_needed()
|
||||
return out
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
@DeveloperAPI
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int = 10000,
|
||||
alpha: float = 1.0,
|
||||
size: Optional[int] = DEPRECATED_VALUE,
|
||||
):
|
||||
"""Initializes a PrioritizedReplayBuffer instance.
|
||||
|
||||
Args:
|
||||
capacity: Max number of timesteps to store in the FIFO
|
||||
buffer. After reaching this number, older samples will be
|
||||
dropped to make space for new ones.
|
||||
alpha: How much prioritization is used
|
||||
(0.0=no prioritization, 1.0=full prioritization).
|
||||
"""
|
||||
super(PrioritizedReplayBuffer, self).__init__(capacity, size)
|
||||
assert alpha > 0
|
||||
self._alpha = alpha
|
||||
|
||||
it_capacity = 1
|
||||
while it_capacity < self.capacity:
|
||||
it_capacity *= 2
|
||||
|
||||
self._it_sum = SumSegmentTree(it_capacity)
|
||||
self._it_min = MinSegmentTree(it_capacity)
|
||||
self._max_priority = 1.0
|
||||
self._prio_change_stats = WindowStat("reprio", 1000)
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def add(self, item: SampleBatchType, weight: float) -> None:
|
||||
"""Add a batch of experiences.
|
||||
|
||||
Args:
|
||||
item: SampleBatch to add to this buffer's storage.
|
||||
weight: The weight of the added sample used in subsequent sampling
|
||||
steps.
|
||||
"""
|
||||
idx = self._next_idx
|
||||
super(PrioritizedReplayBuffer, self).add(item, weight)
|
||||
if weight is None:
|
||||
weight = self._max_priority
|
||||
self._it_sum[idx] = weight ** self._alpha
|
||||
self._it_min[idx] = weight ** self._alpha
|
||||
|
||||
def _sample_proportional(self, num_items: int) -> List[int]:
|
||||
res = []
|
||||
for _ in range(num_items):
|
||||
# TODO(szymon): should we ensure no repeats?
|
||||
mass = random.random() * self._it_sum.sum(0, len(self._storage))
|
||||
idx = self._it_sum.find_prefixsum_idx(mass)
|
||||
res.append(idx)
|
||||
return res
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def sample(self, num_items: int, beta: float) -> SampleBatchType:
|
||||
"""Sample `num_items` items from this buffer, including prio. weights.
|
||||
|
||||
If less than `num_items` records are in this buffer, some samples in
|
||||
the results may be repeated to fulfil the batch size (`num_items`)
|
||||
request.
|
||||
|
||||
Args:
|
||||
num_items: Number of items to sample from this buffer.
|
||||
beta: To what degree to use importance weights
|
||||
(0 - no corrections, 1 - full correction).
|
||||
|
||||
Returns:
|
||||
            Concatenated batch of items including "weights" and
            "batch_indexes" fields denoting the importance-sampling weight
            of each sampled transition and its original index in the buffer.
|
||||
"""
|
||||
# If we don't have any samples yet in this buffer, return None.
|
||||
if len(self) == 0:
|
||||
return None
|
||||
|
||||
assert beta >= 0.0
|
||||
|
||||
idxes = self._sample_proportional(num_items)
|
||||
|
||||
weights = []
|
||||
batch_indexes = []
|
||||
p_min = self._it_min.min() / self._it_sum.sum()
|
||||
max_weight = (p_min * len(self)) ** (-beta)
|
||||
|
||||
for idx in idxes:
|
||||
p_sample = self._it_sum[idx] / self._it_sum.sum()
|
||||
weight = (p_sample * len(self)) ** (-beta)
|
||||
count = self._storage[idx].count
|
||||
# If zero-padded, count will not be the actual batch size of the
|
||||
# data.
|
||||
if (
|
||||
isinstance(self._storage[idx], SampleBatch)
|
||||
and self._storage[idx].zero_padded
|
||||
):
|
||||
actual_size = self._storage[idx].max_seq_len
|
||||
else:
|
||||
actual_size = count
|
||||
weights.extend([weight / max_weight] * actual_size)
|
||||
batch_indexes.extend([idx] * actual_size)
|
||||
self._num_timesteps_sampled += count
|
||||
batch = self._encode_sample(idxes)
|
||||
|
||||
# Note: prioritization is not supported in lockstep replay mode.
|
||||
if isinstance(batch, SampleBatch):
|
||||
batch["weights"] = np.array(weights)
|
||||
batch["batch_indexes"] = np.array(batch_indexes)
|
||||
|
||||
return batch
|
||||
|
||||
@DeveloperAPI
|
||||
def update_priorities(self, idxes: List[int], priorities: List[float]) -> None:
|
||||
"""Update priorities of sampled transitions.
|
||||
|
||||
Sets priority of transition at index idxes[i] in buffer
|
||||
to priorities[i].
|
||||
|
||||
Args:
|
||||
idxes: List of indices of sampled transitions
|
||||
priorities: List of updated priorities corresponding to
|
||||
transitions at the sampled idxes denoted by
|
||||
variable `idxes`.
|
||||
"""
|
||||
# Making sure we don't pass in e.g. a torch tensor.
|
||||
assert isinstance(
|
||||
idxes, (list, np.ndarray)
|
||||
), "ERROR: `idxes` is not a list or np.ndarray, but {}!".format(
|
||||
type(idxes).__name__
|
||||
)
|
||||
assert len(idxes) == len(priorities)
|
||||
for idx, priority in zip(idxes, priorities):
|
||||
assert priority > 0
|
||||
assert 0 <= idx < len(self._storage)
|
||||
delta = priority ** self._alpha - self._it_sum[idx]
|
||||
self._prio_change_stats.push(delta)
|
||||
self._it_sum[idx] = priority ** self._alpha
|
||||
self._it_min[idx] = priority ** self._alpha
|
||||
|
||||
self._max_priority = max(self._max_priority, priority)
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def stats(self, debug: bool = False) -> Dict:
|
||||
"""Returns the stats of this buffer.
|
||||
|
||||
Args:
|
||||
debug: If true, adds sample eviction statistics to the
|
||||
returned stats dict.
|
||||
|
||||
Returns:
|
||||
A dictionary of stats about this buffer.
|
||||
"""
|
||||
parent = ReplayBuffer.stats(self, debug)
|
||||
if debug:
|
||||
parent.update(self._prio_change_stats.stats())
|
||||
return parent
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
||||
Returns:
|
||||
The serializable local state.
|
||||
"""
|
||||
# Get parent state.
|
||||
state = super().get_state()
|
||||
# Add prio weights.
|
||||
state.update(
|
||||
{
|
||||
"sum_segment_tree": self._it_sum.get_state(),
|
||||
"min_segment_tree": self._it_min.get_state(),
|
||||
"max_priority": self._max_priority,
|
||||
}
|
||||
)
|
||||
return state
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
||||
Args:
|
||||
state: The new state to set this buffer. Can be obtained by
|
||||
calling `self.get_state()`.
|
||||
"""
|
||||
super().set_state(state)
|
||||
self._it_sum.set_state(state["sum_segment_tree"])
|
||||
self._it_min.set_state(state["min_segment_tree"])
|
||||
self._max_priority = state["max_priority"]
|
||||
|
||||
|
||||
# Visible for testing.
|
||||
_local_replay_buffer = None
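The deleted `PrioritizedReplayBuffer` above documents a three-step loop: `add()` with an optional priority weight, `sample()` with a beta for importance correction, and `update_priorities()` with new TD-error-derived priorities. A short sketch of that loop against this legacy class (the import path is the pre-commit location shown in the test files below, and the batch contents are placeholders):

import numpy as np
from ray.rllib.execution.buffers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buffer = PrioritizedReplayBuffer(capacity=100, alpha=0.6)
for _ in range(10):
    # weight=None falls back to the current max priority, per add() above.
    buffer.add(SampleBatch({"obs": [np.zeros(4)], "rewards": [0.0]}), weight=None)

batch = buffer.sample(4, beta=0.4)
# sample() attaches "weights" and "batch_indexes" fields, per its docstring.
new_priorities = np.abs(np.random.randn(len(batch["batch_indexes"]))) + 1e-6
buffer.update_priorities(batch["batch_indexes"], new_priorities)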
|
|
@@ -4,7 +4,7 @@ import threading
from typing import Dict, Optional

from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.timer import TimerStat
@@ -3,7 +3,7 @@ from six.moves import queue
import threading

from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.buffers.minibatch_buffer import MinibatchBuffer
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning
@@ -4,8 +4,10 @@ import random
from ray.actor import ActorHandle
from ray.util.iter import from_actors, LocalIterator, _NextValueNotReady
from ray.util.iter_metrics import SharedMetrics
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.utils.replay_buffers.replay_buffer import warn_replay_capacity
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.utils.typing import SampleBatchType
@@ -21,7 +23,7 @@ class StoreToReplayBuffer:
    The batch that was stored is returned.

    Examples:
        >>> from ray.rllib.execution.buffers import multi_agent_replay_buffer
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> from ray.rllib.execution.replay_ops import StoreToReplayBuffer
        >>> from ray.rllib.execution import ParallelRollouts
        >>> actors = [ # doctest: +SKIP
@@ -59,10 +61,10 @@ class StoreToReplayBuffer:

    def __call__(self, batch: SampleBatchType):
        if self.local_actor is not None:
            self.local_actor.add_batch(batch)
            self.local_actor.add(batch)
        else:
            actor = random.choice(self.replay_actors)
            actor.add_batch.remote(batch)
            actor.add.remote(batch)
        return batch

@@ -86,7 +88,7 @@ def Replay(
    per actor.

    Examples:
        >>> from ray.rllib.execution.buffers import multi_agent_replay_buffer
        >>> from ray.rllib.utils.replay_buffers import multi_agent_replay_buffer
        >>> actors = [ # doctest: +SKIP
        ...     multi_agent_replay_buffer.ReplayActor.remote() for _ in range(4)]
        >>> replay_op = Replay(actors=actors) # doctest: +SKIP
@ -1,147 +0,0 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
||||
|
||||
class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
|
||||
"""Tests insertion and mixed sampling of the MixInMultiAgentReplayBuffer."""
|
||||
|
||||
capacity = 10
|
||||
|
||||
def _generate_data(self):
|
||||
return SampleBatch(
|
||||
{
|
||||
"obs": [np.random.random((4,))],
|
||||
"action": [np.random.choice([0, 1])],
|
||||
"reward": [np.random.rand()],
|
||||
"new_obs": [np.random.random((4,))],
|
||||
"done": [np.random.choice([False, True])],
|
||||
}
|
||||
)
|
||||
|
||||
def test_mixin_sampling(self):
|
||||
# 50% replay ratio.
|
||||
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.5)
|
||||
# Add a new batch.
|
||||
batch = self._generate_data()
|
||||
buffer.add_batch(batch)
|
||||
# Expect at least 1 sample to be returned.
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) >= 1)
|
||||
# If we insert and replay n times, expect roughly return batches of
|
||||
# len 2 (replay_ratio=0.5 -> 50% replayed samples -> 1 new and 1 old sample
|
||||
# on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
buffer.add_batch(batch)
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 2.0)
|
||||
|
||||
# 33% replay ratio.
|
||||
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.333)
|
||||
# Expect exactly 0 samples to be returned (buffer empty).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(sample is None)
|
||||
# Add a new batch.
|
||||
batch = self._generate_data()
|
||||
buffer.add_batch(batch)
|
||||
# Expect at least 1 sample to be returned.
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) >= 1)
|
||||
# If we insert-2x and replay n times, expect roughly return batches of
|
||||
# len 3 (replay_ratio=0.33 -> 33% replayed samples -> 2 new and 1 old sample
|
||||
# on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
buffer.add_batch(batch)
|
||||
buffer.add_batch(batch)
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 3.0, delta=0.1)
|
||||
|
||||
# If we insert-1x and replay n times, expect roughly return batches of
|
||||
# len 1.5 (replay_ratio=0.33 -> 33% replayed samples -> 1 new and 0.5 old
|
||||
# samples on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
buffer.add_batch(batch)
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 1.5, delta=0.1)
|
||||
|
||||
# 90% replay ratio.
|
||||
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.9)
|
||||
# Expect exactly 0 samples to be returned (buffer empty).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(sample is None)
|
||||
# Add a new batch.
|
||||
batch = self._generate_data()
|
||||
buffer.add_batch(batch)
|
||||
# Expect at least 2 samples to be returned (new one plus at least one
|
||||
# replay sample).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) >= 2)
|
||||
# If we insert and replay n times, expect roughly return batches of
|
||||
# len 10 (replay_ratio=0.9 -> 90% replayed samples -> 1 new and 9 old
|
||||
# samples on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
buffer.add_batch(batch)
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 10.0, delta=0.1)
|
||||
|
||||
# 0% replay ratio -> Only new samples.
|
||||
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=0.0)
|
||||
# Add a new batch.
|
||||
batch = self._generate_data()
|
||||
buffer.add_batch(batch)
|
||||
# Expect exactly 1 sample to be returned.
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) == 1)
|
||||
# Expect exactly 0 sample to be returned (nothing new to be returned;
|
||||
# no replay allowed (replay_ratio=0.0)).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(sample is None)
|
||||
# If we insert and replay n times, expect roughly return batches of
|
||||
# len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples
|
||||
# on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
buffer.add_batch(batch)
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 1.0)
|
||||
|
||||
        # 100% replay ratio -> Only old (replayed) samples.
|
||||
buffer = MixInMultiAgentReplayBuffer(capacity=self.capacity, replay_ratio=1.0)
|
||||
# Expect exactly 0 samples to be returned (buffer empty).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(sample is None)
|
||||
# Add a new batch.
|
||||
batch = self._generate_data()
|
||||
buffer.add_batch(batch)
|
||||
# Expect exactly 1 sample to be returned (the new batch).
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) == 1)
|
||||
# Another replay -> Expect exactly 1 sample to be returned.
|
||||
sample = buffer.replay()
|
||||
self.assertTrue(len(sample) == 1)
|
||||
# If we replay n times, expect roughly return batches of
|
||||
# len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples
|
||||
# on average in each returned value).
|
||||
results = []
|
||||
for _ in range(100):
|
||||
sample = buffer.replay()
|
||||
results.append(len(sample))
|
||||
self.assertAlmostEqual(np.mean(results), 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -1,308 +0,0 @@
|
|||
from collections import Counter
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.execution.buffers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestPrioritizedReplayBuffer(unittest.TestCase):
|
||||
"""
|
||||
Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer.
|
||||
"""
|
||||
|
||||
capacity = 10
|
||||
alpha = 1.0
|
||||
beta = 1.0
|
||||
max_priority = 1.0
|
||||
|
||||
def _generate_data(self):
|
||||
return SampleBatch(
|
||||
{
|
||||
"obs_t": [np.random.random((4,))],
|
||||
"action": [np.random.choice([0, 1])],
|
||||
"reward": [np.random.rand()],
|
||||
"obs_tp1": [np.random.random((4,))],
|
||||
"done": [np.random.choice([False, True])],
|
||||
}
|
||||
)
|
||||
|
||||
def test_sequence_size(self):
|
||||
# Seq-len=1.
|
||||
memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
|
||||
for _ in range(200):
|
||||
memory.add(self._generate_data(), weight=None)
|
||||
assert len(memory._storage) == 100, len(memory._storage)
|
||||
assert memory.stats()["added_count"] == 200, memory.stats()
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
|
||||
new_memory.set_state(state)
|
||||
assert len(new_memory._storage) == 100, len(new_memory._storage)
|
||||
assert new_memory.stats()["added_count"] == 200, new_memory.stats()
|
||||
|
||||
# Seq-len=5.
|
||||
memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
|
||||
for _ in range(40):
|
||||
memory.add(
|
||||
SampleBatch.concat_samples([self._generate_data() for _ in range(5)]),
|
||||
weight=None,
|
||||
)
|
||||
assert len(memory._storage) == 20, len(memory._storage)
|
||||
assert memory.stats()["added_count"] == 200, memory.stats()
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(capacity=100, alpha=0.1)
|
||||
new_memory.set_state(state)
|
||||
assert len(new_memory._storage) == 20, len(new_memory._storage)
|
||||
assert new_memory.stats()["added_count"] == 200, new_memory.stats()
|
||||
|
||||
def test_add(self):
|
||||
memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)
|
||||
|
||||
# Assert indices 0 before insert.
|
||||
self.assertEqual(len(memory), 0)
|
||||
self.assertEqual(memory._next_idx, 0)
|
||||
|
||||
# Insert single record.
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=0.5)
|
||||
self.assertTrue(len(memory) == 1)
|
||||
self.assertTrue(memory._next_idx == 1)
|
||||
|
||||
# Insert single record.
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=0.1)
|
||||
self.assertTrue(len(memory) == 2)
|
||||
self.assertTrue(memory._next_idx == 0)
|
||||
|
||||
# Insert over capacity.
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=1.0)
|
||||
self.assertTrue(len(memory) == 2)
|
||||
self.assertTrue(memory._next_idx == 1)
|
||||
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(capacity=2, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
self.assertTrue(len(new_memory) == 2)
|
||||
self.assertTrue(new_memory._next_idx == 1)
|
||||
|
||||
def test_update_priorities(self):
|
||||
memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
|
||||
# Insert n samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=1.0)
|
||||
self.assertTrue(len(memory) == i + 1)
|
||||
self.assertTrue(memory._next_idx == i + 1)
|
||||
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
self.assertTrue(len(new_memory) == num_records)
|
||||
self.assertTrue(new_memory._next_idx == num_records)
|
||||
|
||||
# Fetch records, their indices and weights.
|
||||
batch = memory.sample(3, beta=self.beta)
|
||||
weights = batch["weights"]
|
||||
indices = batch["batch_indexes"]
|
||||
check(weights, np.ones(shape=(3,)))
|
||||
self.assertEqual(3, len(indices))
|
||||
self.assertTrue(len(memory) == num_records)
|
||||
self.assertTrue(memory._next_idx == num_records)
|
||||
|
||||
# Update weight of indices 0, 2, 3, 4 to very small.
|
||||
memory.update_priorities(
|
||||
np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01])
|
||||
)
|
||||
# Expect to sample almost only index 1
|
||||
# (which still has a weight of 1.0).
|
||||
for _ in range(10):
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(970 < np.sum(indices) < 1100)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(970 < np.sum(indices) < 1100)
|
||||
|
||||
# Update weight of indices 0 and 1 to >> 0.01.
|
||||
# Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s).
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(np.array([0, 1]), np.array([rand, rand]))
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
# Expect biased to higher values due to some 2s, 3s, and 4s.
|
||||
self.assertTrue(400 < np.sum(indices) < 800)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(400 < np.sum(indices) < 800)
|
||||
|
||||
# Update weights to be 1:2.
|
||||
# Expect to sample double as often index 1 over index 0
|
||||
# plus very few times indices 2, 3, or 4.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 2]))
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
# print(np.sum(indices))
|
||||
self.assertTrue(600 < np.sum(indices) < 850)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(600 < np.sum(indices) < 850)
|
||||
|
||||
# Update weights to be 1:4.
|
||||
# Expect to sample quadruple as often index 1 over index 0
|
||||
# plus very few times indices 2, 3, or 4.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 4]))
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(750 < np.sum(indices) < 950)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(750 < np.sum(indices) < 950)
|
||||
|
||||
# Update weights to be 1:9.
|
||||
# Expect to sample 9 times as often index 1 over index 0.
|
||||
# plus very few times indices 2, 3, or 4.
|
||||
for _ in range(10):
|
||||
rand = np.random.random() + 0.2
|
||||
memory.update_priorities(np.array([0, 1]), np.array([rand, rand * 9]))
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(850 < np.sum(indices) < 1100)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
self.assertTrue(850 < np.sum(indices) < 1100)
|
||||
|
||||
# Insert n more samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=1.0)
|
||||
self.assertTrue(len(memory) == i + 6)
|
||||
self.assertTrue(memory._next_idx == (i + 6) % self.capacity)
|
||||
|
||||
# Update all weights to be 1.0 to 10.0 and sample a >100 batch.
|
||||
memory.update_priorities(
|
||||
np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
np.array([0.001, 0.1, 2.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0]),
|
||||
)
|
||||
counts = Counter()
|
||||
for _ in range(10):
|
||||
batch = memory.sample(np.random.randint(100, 600), beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
print(counts)
|
||||
# Expect an approximately correct distribution of indices.
|
||||
self.assertTrue(
|
||||
counts[9]
|
||||
>= counts[8]
|
||||
>= counts[7]
|
||||
>= counts[6]
|
||||
>= counts[5]
|
||||
>= counts[4]
|
||||
>= counts[3]
|
||||
>= counts[2]
|
||||
>= counts[1]
|
||||
>= counts[0]
|
||||
)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=self.alpha)
|
||||
new_memory.set_state(state)
|
||||
counts = Counter()
|
||||
for _ in range(10):
|
||||
batch = new_memory.sample(np.random.randint(100, 600), beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
print(counts)
|
||||
self.assertTrue(
|
||||
counts[9]
|
||||
>= counts[8]
|
||||
>= counts[7]
|
||||
>= counts[6]
|
||||
>= counts[5]
|
||||
>= counts[4]
|
||||
>= counts[3]
|
||||
>= counts[2]
|
||||
>= counts[1]
|
||||
>= counts[0]
|
||||
)
|
||||
|
||||
def test_alpha_parameter(self):
|
||||
# Test sampling from a PR with a very small alpha (should behave just
|
||||
# like a regular ReplayBuffer).
|
||||
memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
|
||||
|
||||
# Insert n samples.
|
||||
num_records = 5
|
||||
for i in range(num_records):
|
||||
data = self._generate_data()
|
||||
memory.add(data, weight=np.random.rand())
|
||||
self.assertTrue(len(memory) == i + 1)
|
||||
self.assertTrue(memory._next_idx == i + 1)
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
|
||||
new_memory.set_state(state)
|
||||
self.assertTrue(len(new_memory) == num_records)
|
||||
self.assertTrue(new_memory._next_idx == num_records)
|
||||
|
||||
# Fetch records, their indices and weights.
|
||||
batch = memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
counts = Counter()
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
print(counts)
|
||||
# Expect an approximately uniform distribution of indices.
|
||||
self.assertTrue(any(100 < i < 300 for i in counts.values()))
|
||||
# Test get_state/set_state.
|
||||
state = memory.get_state()
|
||||
new_memory = PrioritizedReplayBuffer(self.capacity, alpha=0.01)
|
||||
new_memory.set_state(state)
|
||||
batch = new_memory.sample(1000, beta=self.beta)
|
||||
indices = batch["batch_indexes"]
|
||||
counts = Counter()
|
||||
for i in indices:
|
||||
counts[i] += 1
|
||||
self.assertTrue(any(100 < i < 300 for i in counts.values()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -1,102 +0,0 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
|
||||
|
||||
class TestSegmentTree(unittest.TestCase):
|
||||
def test_tree_set(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 4.0)
|
||||
assert np.isclose(tree.sum(0, 2), 0.0)
|
||||
assert np.isclose(tree.sum(0, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, -1), 1.0)
|
||||
assert np.isclose(tree.sum(2, 4), 4.0)
|
||||
assert np.isclose(tree.sum(2), 4.0)
|
||||
|
||||
def test_tree_set_overlap(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[2] = 3.0
|
||||
|
||||
assert np.isclose(tree.sum(), 3.0)
|
||||
assert np.isclose(tree.sum(2, 3), 3.0)
|
||||
assert np.isclose(tree.sum(2, -1), 3.0)
|
||||
assert np.isclose(tree.sum(2, 4), 3.0)
|
||||
assert np.isclose(tree.sum(2), 3.0)
|
||||
assert np.isclose(tree.sum(1, 2), 0.0)
|
||||
|
||||
def test_prefixsum_idx(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.0) == 2
|
||||
assert tree.find_prefixsum_idx(0.5) == 2
|
||||
assert tree.find_prefixsum_idx(0.99) == 2
|
||||
assert tree.find_prefixsum_idx(1.01) == 3
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(4.00) == 3
|
||||
|
||||
def test_prefixsum_idx2(self):
|
||||
tree = SumSegmentTree(4)
|
||||
|
||||
tree[0] = 0.5
|
||||
tree[1] = 1.0
|
||||
tree[2] = 1.0
|
||||
tree[3] = 3.0
|
||||
|
||||
assert tree.find_prefixsum_idx(0.00) == 0
|
||||
assert tree.find_prefixsum_idx(0.55) == 1
|
||||
assert tree.find_prefixsum_idx(0.99) == 1
|
||||
assert tree.find_prefixsum_idx(1.51) == 2
|
||||
assert tree.find_prefixsum_idx(3.00) == 3
|
||||
assert tree.find_prefixsum_idx(5.50) == 3
|
||||
|
||||
def test_max_interval_tree(self):
|
||||
tree = MinSegmentTree(4)
|
||||
|
||||
tree[0] = 1.0
|
||||
tree[2] = 0.5
|
||||
tree[3] = 3.0
|
||||
|
||||
assert np.isclose(tree.min(), 0.5)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.5)
|
||||
assert np.isclose(tree.min(0, -1), 0.5)
|
||||
assert np.isclose(tree.min(2, 4), 0.5)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 0.7
|
||||
|
||||
assert np.isclose(tree.min(), 0.7)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 0.7)
|
||||
assert np.isclose(tree.min(0, -1), 0.7)
|
||||
assert np.isclose(tree.min(2, 4), 0.7)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
tree[2] = 4.0
|
||||
|
||||
assert np.isclose(tree.min(), 1.0)
|
||||
assert np.isclose(tree.min(0, 2), 1.0)
|
||||
assert np.isclose(tree.min(0, 3), 1.0)
|
||||
assert np.isclose(tree.min(0, -1), 1.0)
|
||||
assert np.isclose(tree.min(2, 4), 3.0)
|
||||
assert np.isclose(tree.min(2, 3), 4.0)
|
||||
assert np.isclose(tree.min(2, -1), 4.0)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -42,10 +42,15 @@ class TestEagerSupportPG(unittest.TestCase):
|
|||
ray.shutdown()
|
||||
|
||||
def test_simple_q(self):
|
||||
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
|
||||
check_support(
|
||||
"SimpleQ",
|
||||
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
|
||||
)
|
||||
|
||||
def test_dqn(self):
|
||||
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
|
||||
check_support(
|
||||
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
|
||||
)
|
||||
|
||||
def test_ddpg(self):
|
||||
check_support("DDPG", {"num_workers": 0})
|
||||
|
@ -84,10 +89,15 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
|
|||
ray.shutdown()
|
||||
|
||||
def test_simple_q(self):
|
||||
check_support("SimpleQ", {"num_workers": 0, "learning_starts": 0})
|
||||
check_support(
|
||||
"SimpleQ",
|
||||
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
|
||||
)
|
||||
|
||||
def test_dqn(self):
|
||||
check_support("DQN", {"num_workers": 0, "learning_starts": 0})
|
||||
check_support(
|
||||
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
|
||||
)
|
||||
|
||||
def test_ddpg(self):
|
||||
check_support("DDPG", {"num_workers": 0})
|
||||
|
@ -103,7 +113,7 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
|
|||
"APEX",
|
||||
{
|
||||
"num_workers": 2,
|
||||
"learning_starts": 0,
|
||||
"replay_buffer_config": {"learning_starts": 0},
|
||||
"num_gpus": 0,
|
||||
"min_time_s_per_reporting": 1,
|
||||
"min_sample_timesteps_per_reporting": 100,
|
||||
|
@ -114,7 +124,9 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
|
|||
)
|
||||
|
||||
def test_sac(self):
|
||||
check_support("SAC", {"num_workers": 0, "learning_starts": 0})
|
||||
check_support(
|
||||
"SAC", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -23,9 +23,8 @@ from ray.rllib.execution.train_ops import (
|
|||
ComputeGradients,
|
||||
AverageGradients,
|
||||
)
|
||||
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
|
||||
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
|
||||
MultiAgentReplayBuffer,
|
||||
ReplayActor,
|
||||
)
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.util.iter import LocalIterator, from_range
|
||||
|
@ -216,40 +215,41 @@ class TestExecution(unittest.TestCase):
|
|||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=0.0001,
|
||||
)
|
||||
assert buf.replay() is None
|
||||
assert len(buf.sample(100)) == 0
|
||||
|
||||
workers = make_workers(0)
|
||||
a = ParallelRollouts(workers, mode="bulk_sync")
|
||||
b = a.for_each(StoreToReplayBuffer(local_buffer=buf))
|
||||
|
||||
next(b)
|
||||
assert buf.replay() is None # learning hasn't started yet
|
||||
assert len(buf.sample(100)) == 0 # learning hasn't started yet
|
||||
next(b)
|
||||
assert buf.replay().count == 100
|
||||
assert buf.sample(100).count == 100
|
||||
|
||||
replay_op = Replay(local_buffer=buf)
|
||||
assert next(replay_op).count == 100
|
||||
|
||||
def test_store_to_replay_actor(self):
|
||||
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
|
||||
actor = ReplayActor.remote(
|
||||
num_shards=1,
|
||||
learning_starts=200,
|
||||
buffer_size=1000,
|
||||
capacity=1000,
|
||||
replay_batch_size=100,
|
||||
prioritized_replay_alpha=0.6,
|
||||
prioritized_replay_beta=0.4,
|
||||
prioritized_replay_eps=0.0001,
|
||||
)
|
||||
assert ray.get(actor.replay.remote()) is None
|
||||
assert len(ray.get(actor.sample.remote(100))) == 0
|
||||
|
||||
workers = make_workers(0)
|
||||
a = ParallelRollouts(workers, mode="bulk_sync")
|
||||
b = a.for_each(StoreToReplayBuffer(actors=[actor]))
|
||||
|
||||
next(b)
|
||||
assert ray.get(actor.replay.remote()) is None # learning hasn't started
|
||||
assert len(ray.get(actor.sample.remote(100))) == 0 # learning hasn't started
|
||||
next(b)
|
||||
assert ray.get(actor.replay.remote()).count == 100
|
||||
assert ray.get(actor.sample.remote(100)).count == 100
|
||||
|
||||
replay_op = Replay(actors=[actor])
|
||||
assert next(replay_op).count == 100
|
||||
|
|
|
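The execution test above shows the renamed constructor argument (buffer_size -> capacity) and that sampling returns an empty batch until learning_starts timesteps have been added. A condensed sketch of the same pattern, using only names that appear in the test (treat the exact constructor signature as version-dependent):

import ray
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

ray.init()
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
actor = ReplayActor.remote(
    num_shards=1,
    learning_starts=200,
    capacity=1000,          # was `buffer_size` before this change
    replay_batch_size=100,
)
# Learning has not started yet, so sampling yields an empty result.
assert len(ray.get(actor.sample.remote(100))) == 0
# Once enough timesteps have been stored, sample() returns real batches.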
@ -95,9 +95,11 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|||
"num_workers": 2,
|
||||
"min_sample_timesteps_per_reporting": 100,
|
||||
"num_gpus": 0,
|
||||
"replay_buffer_config": {"capacity": 1000},
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
"learning_starts": 10,
|
||||
},
|
||||
"min_time_s_per_reporting": 1,
|
||||
"learning_starts": 10,
|
||||
"target_network_update_freq": 100,
|
||||
"optimizer": {
|
||||
"num_replay_buffer_shards": 1,
|
||||
|
@ -111,10 +113,12 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|||
{
|
||||
"num_workers": 2,
|
||||
"min_sample_timesteps_per_reporting": 100,
|
||||
"replay_buffer_config": {"capacity": 1000},
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
"learning_starts": 10,
|
||||
},
|
||||
"num_gpus": 0,
|
||||
"min_time_s_per_reporting": 1,
|
||||
"learning_starts": 10,
|
||||
"target_network_update_freq": 100,
|
||||
"use_state_preprocessor": True,
|
||||
},
|
||||
|
@ -125,9 +129,11 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|||
"DDPG",
|
||||
{
|
||||
"min_sample_timesteps_per_reporting": 1,
|
||||
"replay_buffer_config": {"capacity": 1000},
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
"learning_starts": 500,
|
||||
},
|
||||
"use_state_preprocessor": True,
|
||||
"learning_starts": 500,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -136,7 +142,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|||
"DQN",
|
||||
{
|
||||
"min_sample_timesteps_per_reporting": 1,
|
||||
"replay_buffer_config": {"capacity": 1000},
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -145,7 +153,9 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
|
|||
"SAC",
|
||||
{
|
||||
"num_workers": 0,
|
||||
"replay_buffer_config": {"capacity": 1000},
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
},
|
||||
"normalize_actions": False,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -182,18 +182,27 @@ class TestSupportedSpacesOffPolicy(unittest.TestCase):
|
|||
{
|
||||
"exploration_config": {"ou_base_scale": 100.0},
|
||||
"min_sample_timesteps_per_reporting": 1,
|
||||
"buffer_size": 1000,
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
},
|
||||
"use_state_preprocessor": True,
|
||||
},
|
||||
check_bounds=True,
|
||||
)
|
||||
|
||||
def test_dqn(self):
|
||||
config = {"min_sample_timesteps_per_reporting": 1, "buffer_size": 1000}
|
||||
config = {
|
||||
"min_sample_timesteps_per_reporting": 1,
|
||||
"replay_buffer_config": {
|
||||
"capacity": 1000,
|
||||
},
|
||||
}
|
||||
check_support("DQN", config, tfe=True)
|
||||
|
||||
def test_sac(self):
|
||||
check_support("SAC", {"buffer_size": 1000}, check_bounds=True)
|
||||
check_support(
|
||||
"SAC", {"replay_buffer_config": {"capacity": 1000}}, check_bounds=True
|
||||
)
|
||||
|
||||
|
||||
class TestSupportedSpacesEvolutionAlgos(unittest.TestCase):
|
||||
|
|
|
@@ -86,11 +86,13 @@ apex:
    lr: .0001
    adam_epsilon: .00015
    hiddens: [512]
    buffer_size: 1000000
    exploration_config:
      epsilon_timesteps: 200000
      final_epsilon: 0.01
    prioritized_replay_alpha: 0.5
    replay_buffer_config:
      type: MultiAgentPrioritizedReplayBuffer
      prioritized_replay_alpha: 0.5
      capacity: 1000000
    num_gpus: 1
    num_workers: 8
    num_envs_per_worker: 8
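For reference, a sketch of the same Ape-X settings expressed as a Python config dict in the new format (values copied from the YAML above; the tune.run usage is an assumption based on the trainer name used elsewhere in this diff):

apex_config = {
    "num_gpus": 1,
    "num_workers": 8,
    "num_envs_per_worker": 8,
    "exploration_config": {
        "epsilon_timesteps": 200000,
        "final_epsilon": 0.01,
    },
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 1000000,
        "prioritized_replay_alpha": 0.5,
    },
}
# e.g. tune.run("APEX", config={"env": "BreakoutNoFrameskip-v4", **apex_config})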
@ -125,19 +127,19 @@ atari-basic-dqn:
|
|||
dueling: false
|
||||
num_atoms: 1
|
||||
noisy: false
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 20000
|
||||
capacity: 1000000
|
||||
n_step: 1
|
||||
target_network_update_freq: 8000
|
||||
lr: .0000625
|
||||
adam_epsilon: .00015
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
final_epsilon: 0.01
|
||||
prioritized_replay_alpha: 0.5
|
||||
num_gpus: 0.2
|
||||
min_sample_timesteps_per_reporting: 10000
|
||||
|
|
|
@ -26,11 +26,12 @@ halfcheetah_bc:
|
|||
no_done_at_end: false
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 10
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 0
|
||||
min_train_timesteps_per_reporting: 1000
|
||||
learning_starts: 10
|
||||
optimization:
|
||||
actor_learning_rate: 0.0001
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -28,11 +28,12 @@ halfcheetah_cql:
|
|||
no_done_at_end: false
|
||||
n_step: 3
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 256
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 0
|
||||
min_train_timesteps_per_reporting: 1000
|
||||
learning_starts: 256
|
||||
optimization:
|
||||
actor_learning_rate: 0.0001
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -26,11 +26,12 @@ hopper_bc:
|
|||
no_done_at_end: false
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 10
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 0
|
||||
min_train_timesteps_per_reporting: 1000
|
||||
learning_starts: 10
|
||||
optimization:
|
||||
actor_learning_rate: 0.0001
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -26,11 +26,12 @@ hopper_cql:
|
|||
no_done_at_end: false
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 10
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 0
|
||||
min_train_timesteps_per_reporting: 1000
|
||||
learning_starts: 10
|
||||
optimization:
|
||||
actor_learning_rate: 0.0001
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -22,6 +22,7 @@ pendulum-cql:
|
|||
twin_q: true
|
||||
train_batch_size: 2000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 0
|
||||
bc_iters: 100
|
||||
|
||||
|
|
|
@ -31,10 +31,12 @@ halfcheetah-ddpg:
|
|||
|
||||
# === Replay buffer ===
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 10000
|
||||
prioritized_replay_alpha: 0.6
|
||||
prioritized_replay_beta: 0.4
|
||||
prioritized_replay_eps: 0.000001
|
||||
worker_side_prioritization: false
|
||||
clip_rewards: False
|
||||
|
||||
# === Optimization ===
|
||||
|
@ -50,7 +52,6 @@ halfcheetah-ddpg:
|
|||
# === Parallelism ===
|
||||
num_workers: 0
|
||||
num_gpus_per_worker: 0
|
||||
worker_side_prioritization: false
|
||||
|
||||
# === Evaluation ===
|
||||
evaluation_interval: 5
|
||||
|
|
|
@ -23,10 +23,12 @@ ddpg-halfcheetahbulletenv-v0:
|
|||
target_network_update_freq: 0
|
||||
tau: 0.001
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 15000
|
||||
prioritized_replay_alpha: 0.6
|
||||
prioritized_replay_beta: 0.4
|
||||
prioritized_replay_eps: 0.000001
|
||||
worker_side_prioritization: false
|
||||
clip_rewards: false
|
||||
actor_lr: 0.001
|
||||
critic_lr: 0.001
|
||||
|
@ -39,4 +41,3 @@ ddpg-halfcheetahbulletenv-v0:
|
|||
num_workers: 0
|
||||
num_gpus: 1
|
||||
num_gpus_per_worker: 0
|
||||
worker_side_prioritization: false
|
||||
|
|
|
@ -26,19 +26,20 @@ ddpg-hopperbulletenv-v0:
|
|||
target_network_update_freq: 0
|
||||
tau: 0.001
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 10000
|
||||
prioritized_replay_alpha: 0.6
|
||||
prioritized_replay_beta: 0.4
|
||||
prioritized_replay_eps: 0.000001
|
||||
worker_side_prioritization: False
|
||||
learning_starts: 500
|
||||
clip_rewards: False
|
||||
actor_lr: 0.001
|
||||
critic_lr: 0.001
|
||||
use_huber: False
|
||||
huber_threshold: 1.0
|
||||
l2_reg: 0.000001
|
||||
learning_starts: 500
|
||||
rollout_fragment_length: 1
|
||||
train_batch_size: 48
|
||||
num_workers: 0
|
||||
num_gpus_per_worker: 0
|
||||
worker_side_prioritization: False
|
||||
num_gpus_per_worker: 0
|
|
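The DDPG configs above illustrate the same move for the prioritized-replay knobs: prioritized_replay_alpha/beta/eps, worker_side_prioritization, and learning_starts now live inside replay_buffer_config, alongside the buffer type. A before/after sketch with values taken from the Hopper config above:

old_style = {
    "buffer_size": 10000,
    "prioritized_replay": True,
    "prioritized_replay_alpha": 0.6,
    "prioritized_replay_beta": 0.4,
    "prioritized_replay_eps": 0.000001,
    "learning_starts": 500,
    "worker_side_prioritization": False,
}
new_style = {
    "replay_buffer_config": {
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 10000,
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 0.000001,
        "learning_starts": 500,
        "worker_side_prioritization": False,
    },
}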
@ -16,7 +16,8 @@ invertedpendulum-td3:
|
|||
critic_hiddens: [32, 32]
|
||||
|
||||
# === Exploration ===
|
||||
learning_starts: 1000
|
||||
replay_buffer_config:
|
||||
learning_starts: 1000
|
||||
exploration_config:
|
||||
random_timesteps: 1000
|
||||
|
||||
|
|
|
@ -9,4 +9,5 @@ memory-leak-test-ddpg:
|
|||
env_config:
|
||||
config:
|
||||
static_samples: true
|
||||
buffer_size: 500 # use small buffer to catch memory leaks
|
||||
replay_buffer_config:
|
||||
capacity: 500 # use small buffer to catch memory leaks
|
||||
|
|
|
@ -32,10 +32,12 @@ mountaincarcontinuous-ddpg:
|
|||
|
||||
# === Replay buffer ===
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 50000
|
||||
prioritized_replay_alpha: 0.6
|
||||
prioritized_replay_beta: 0.4
|
||||
prioritized_replay_eps: 0.000001
|
||||
worker_side_prioritization: False
|
||||
clip_rewards: False
|
||||
|
||||
# === Optimization ===
|
||||
|
@ -51,7 +53,6 @@ mountaincarcontinuous-ddpg:
|
|||
# === Parallelism ===
|
||||
num_workers: 0
|
||||
num_gpus_per_worker: 0
|
||||
worker_side_prioritization: False
|
||||
|
||||
# === Evaluation ===
|
||||
evaluation_interval: 5
|
||||
|
|
|
@ -18,10 +18,11 @@ mujoco-td3:
|
|||
# Works for both torch and tf.
|
||||
framework: tf
|
||||
# === Exploration ===
|
||||
learning_starts: 10000
|
||||
exploration_config:
|
||||
random_timesteps: 10000
|
||||
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 10000
|
||||
# === Evaluation ===
|
||||
evaluation_interval: 10
|
||||
evaluation_num_episodes: 10
|
||||
|
|
|
@ -17,13 +17,14 @@ pendulum-ddpg-fake-gpus:
|
|||
final_scale: 0.02
|
||||
min_sample_timesteps_per_reporting: 600
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 10000
|
||||
worker_side_prioritization: false
|
||||
learning_starts: 500
|
||||
clip_rewards: false
|
||||
use_huber: true
|
||||
learning_starts: 500
|
||||
train_batch_size: 64
|
||||
num_workers: 0
|
||||
worker_side_prioritization: false
|
||||
actor_lr: 0.0001
|
||||
critic_lr: 0.0001
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ pendulum-ddpg:
|
|||
|
||||
# === Replay buffer ===
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 10000
|
||||
worker_side_prioritization: False
|
||||
clip_rewards: False
|
||||
|
||||
# === Optimization ===
|
||||
|
@ -49,4 +51,3 @@ pendulum-ddpg:
|
|||
|
||||
# === Parallelism ===
|
||||
num_workers: 0
|
||||
worker_side_prioritization: False
|
||||
|
|
|
@ -9,7 +9,10 @@ pendulum-td3-fake-gpus:
|
|||
framework: tf
|
||||
actor_hiddens: [64, 64]
|
||||
critic_hiddens: [64, 64]
|
||||
learning_starts: 5000
|
||||
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 5000
|
||||
exploration_config:
|
||||
random_timesteps: 5000
|
||||
evaluation_interval: 10
|
||||
|
|
|
@ -12,6 +12,8 @@ pendulum-td3:
|
|||
actor_hiddens: [64, 64]
|
||||
critic_hiddens: [64, 64]
|
||||
# === Exploration ===
|
||||
learning_starts: 5000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 5000
|
||||
exploration_config:
|
||||
random_timesteps: 5000
|
||||
|
|
|
@ -15,11 +15,12 @@ apex-breakoutnoframeskip-v4:
|
|||
lr: .0001
|
||||
adam_epsilon: .00015
|
||||
hiddens: [512]
|
||||
buffer_size: 1000000
|
||||
replay_buffer_config:
|
||||
capacity: 1000000
|
||||
prioritized_replay_alpha: 0.5
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
final_epsilon: 0.01
|
||||
prioritized_replay_alpha: 0.5
|
||||
num_gpus: 1
|
||||
num_workers: 8
|
||||
num_envs_per_worker: 8
|
||||
|
|
|
@ -11,19 +11,19 @@ atari-dist-dqn:
|
|||
dueling: false
|
||||
num_atoms: 51
|
||||
noisy: false
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
capacity: 1000000
|
||||
learning_starts: 20000
|
||||
n_step: 1
|
||||
target_network_update_freq: 8000
|
||||
lr: .0000625
|
||||
adam_epsilon: .00015
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
final_epsilon: 0.01
|
||||
prioritized_replay_alpha: 0.5
|
||||
num_gpus: 0.2
|
||||
min_sample_timesteps_per_reporting: 10000
|
||||
|
|
|
@ -15,19 +15,19 @@ atari-basic-dqn:
|
|||
dueling: false
|
||||
num_atoms: 1
|
||||
noisy: false
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 20000
|
||||
capacity: 1000000
|
||||
n_step: 1
|
||||
target_network_update_freq: 8000
|
||||
lr: .0000625
|
||||
adam_epsilon: .00015
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
final_epsilon: 0.01
|
||||
prioritized_replay_alpha: 0.5
|
||||
num_gpus: 0.2
|
||||
min_sample_timesteps_per_reporting: 10000
|
||||
|
|
|
@ -15,19 +15,19 @@ dueling-ddqn:
|
|||
dueling: true
|
||||
num_atoms: 1
|
||||
noisy: false
|
||||
prioritized_replay: false
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
learning_starts: 20000
|
||||
capacity: 1000000
|
||||
n_step: 1
|
||||
target_network_update_freq: 8000
|
||||
lr: .0000625
|
||||
adam_epsilon: .00015
|
||||
hiddens: [512]
|
||||
learning_starts: 20000
|
||||
buffer_size: 1000000
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
epsilon_timesteps: 200000
|
||||
final_epsilon: 0.01
|
||||
prioritized_replay_alpha: 0.5
|
||||
num_gpus: 0.2
|
||||
min_sample_timesteps_per_reporting: 10000
|
||||
|
|
|
@ -17,13 +17,15 @@ cartpole-apex-dqn-training-itr:
|
|||
# Make this work with only 5 CPUs and 0 GPUs:
|
||||
num_workers: 3
|
||||
optimizer:
|
||||
num_replay_buffer_shards: 2
|
||||
num_replay_buffer_shards: 2
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 20000
|
||||
learning_starts: 1000
|
||||
|
||||
num_gpus: 0
|
||||
|
||||
min_time_s_per_reporting: 5
|
||||
target_network_update_freq: 500
|
||||
learning_starts: 1000
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
buffer_size: 20000
|
||||
training_intensity: 4
|
|
@@ -9,4 +9,5 @@ memory-leak-test-dqn:
    env_config:
      config:
        static_samples: true
    buffer_size: 500 # use small buffer to catch memory leaks
    replay_buffer_config:
      capacity: 500 # use small buffer to catch memory leaks
@ -15,6 +15,8 @@ pong-apex:
|
|||
num_envs_per_worker: 8
|
||||
lr: .00005
|
||||
train_batch_size: 64
|
||||
buffer_size: 1000000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 1000000
|
||||
gamma: 0.99
|
||||
training_intensity: 16
|
||||
|
|
|
@ -11,8 +11,10 @@ pong-deterministic-dqn:
|
|||
num_gpus: 1
|
||||
gamma: 0.99
|
||||
lr: .0001
|
||||
learning_starts: 10000
|
||||
buffer_size: 50000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 50000
|
||||
learning_starts: 10000
|
||||
rollout_fragment_length: 4
|
||||
train_batch_size: 32
|
||||
exploration_config:
|
||||
|
|
|
@@ -9,16 +9,17 @@ pong-deterministic-rainbow:
    gamma: 0.99
    lr: .0001
    hiddens: [512]
    learning_starts: 10000
    buffer_size: 50000
    rollout_fragment_length: 4
    train_batch_size: 32
    exploration_config:
      epsilon_timesteps: 2
      final_epsilon: 0.0
    target_network_update_freq: 500
    prioritized_replay: True
    prioritized_replay_alpha: 0.5
    replay_buffer_config:
      type: MultiAgentPrioritizedReplayBuffer
      prioritized_replay_alpha: 0.5
      learning_starts: 10000
      capacity: 50000
    n_step: 3
    gpu: True
    model:
@ -10,6 +10,7 @@ stateless-cartpole-r2d2:
|
|||
num_workers: 0
|
||||
# R2D2 settings.
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
replay_burn_in: 20
|
||||
zero_init_states: true
|
||||
#dueling: false
|
||||
|
|
|
@ -10,6 +10,7 @@ stateless-cartpole-r2d2:
|
|||
num_workers: 0
|
||||
# R2D2 settings.
|
||||
replay_buffer_config:
|
||||
type: MultiAgentReplayBuffer
|
||||
replay_burn_in: 20
|
||||
zero_init_states: true
|
||||
#dueling: false
|
||||
|
|
|
@ -31,23 +31,19 @@ atari-sac-tf-and-torch:
|
|||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
replay_buffer_config:
|
||||
_enable_replay_buffer_api: False
|
||||
type: MultiAgentReplayBuffer
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
capacity: 1000000
|
||||
# How many steps of the model to sample before learning starts.
|
||||
learning_starts: 1500
|
||||
learning_starts: 100000
|
||||
# If True prioritized replay buffer will be used.
|
||||
prioritized_replay: True
|
||||
prioritized_replay_alpha: 0.6
|
||||
prioritized_replay_beta: 0.4
|
||||
prioritized_replay_eps: 1e-6
|
||||
|
||||
train_batch_size: 64
|
||||
min_sample_timesteps_per_reporting: 4
|
||||
# Paper uses 20k random timesteps, which is not exactly the same, but
|
||||
# seems to work nevertheless. We use 100k here for the longer Atari
|
||||
# runs (DQN style: filling up the buffer a bit before learning).
|
||||
learning_starts: 100000
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -12,9 +12,10 @@ cartpole-sac:
|
|||
horizon: 200
|
||||
soft_horizon: true
|
||||
n_step: 3
|
||||
prioritized_replay: true
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 256
|
||||
initial_alpha: 0.2
|
||||
learning_starts: 256
|
||||
clip_actions: false
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
optimization:
|
||||
|
|
|
@ -19,11 +19,12 @@ halfcheetah-pybullet-sac:
|
|||
no_done_at_end: false
|
||||
n_step: 3
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
learning_starts: 10000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 10000
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -20,11 +20,12 @@ halfcheetah_sac:
|
|||
no_done_at_end: true
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
learning_starts: 10000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 10000
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -9,4 +9,5 @@ memory-leak-test-sac:
|
|||
env_config:
|
||||
config:
|
||||
static_samples: true
|
||||
buffer_size: 500 # use small buffer to catch memory leaks
|
||||
replay_buffer_config:
|
||||
capacity: 500 # use small buffer to catch memory leaks
|
||||
|
|
|
@ -27,12 +27,13 @@ mspacman-sac-tf:
|
|||
no_done_at_end: False
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 64
|
||||
min_sample_timesteps_per_reporting: 4
|
||||
# Paper uses 20k random timesteps, which is not exactly the same, but
|
||||
# seems to work nevertheless.
|
||||
learning_starts: 20000
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 20000
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -21,11 +21,12 @@ pendulum-sac-fake-gpus:
|
|||
no_done_at_end: true
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
learning_starts: 256
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 256
|
||||
num_workers: 0
|
||||
metrics_smoothing_episodes: 5
|
||||
|
||||
|
|
|
@ -23,11 +23,12 @@ pendulum-sac:
|
|||
no_done_at_end: true
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
learning_starts: 256
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 256
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -30,11 +30,12 @@ transformed-actions-pendulum-sac-dummy-torch:
|
|||
no_done_at_end: true
|
||||
n_step: 1
|
||||
rollout_fragment_length: 1
|
||||
prioritized_replay: true
|
||||
train_batch_size: 256
|
||||
target_network_update_freq: 1
|
||||
min_sample_timesteps_per_reporting: 1000
|
||||
learning_starts: 256
|
||||
replay_buffer_config:
|
||||
type: MultiAgentPrioritizedReplayBuffer
|
||||
learning_starts: 256
|
||||
optimization:
|
||||
actor_learning_rate: 0.0003
|
||||
critic_learning_rate: 0.0003
|
||||
|
|
|
@ -9,7 +9,7 @@ from ray.rllib.policy.sample_batch import (
|
|||
SampleBatch,
|
||||
MultiAgentBatch,
|
||||
)
|
||||
from ray.rllib.utils.annotations import override, ExperimentalAPI
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.replay_buffers.multi_agent_prioritized_replay_buffer import (
|
||||
MultiAgentPrioritizedReplayBuffer,
|
||||
)
|
||||
|
@ -21,13 +21,14 @@ from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
|
|||
ReplayMode,
|
||||
)
|
||||
from ray.rllib.utils.typing import PolicyID, SampleBatchType
|
||||
from ray.rllib.execution.buffers.replay_buffer import _ALL_POLICIES
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
|
||||
from ray.util.debug import log_once
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
|
||||
"""This buffer adds replayed samples to a stream of new experiences.
|
||||
|
||||
|
@ -168,7 +169,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
|
|||
|
||||
self.last_added_batches = collections.defaultdict(list)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentPrioritizedReplayBuffer)
|
||||
def add(self, batch: SampleBatchType, **kwargs) -> None:
|
||||
"""Adds a batch to the appropriate policy's replay buffer.
|
||||
|
@ -244,7 +245,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
|
|||
|
||||
self._num_added += batch.count
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentReplayBuffer)
|
||||
def sample(
|
||||
self, num_items: int, policy_id: PolicyID = DEFAULT_POLICY_ID, **kwargs
|
||||
|
@ -353,7 +354,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
|
|||
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentPrioritizedReplayBuffer)
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
@ -368,7 +369,7 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
|
|||
parent.update(data)
|
||||
return parent
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentPrioritizedReplayBuffer)
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
|
|
@ -2,9 +2,8 @@ from typing import Dict
|
|||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
|
||||
from ray.rllib.utils.annotations import override, ExperimentalAPI
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
|
||||
MultiAgentReplayBuffer,
|
||||
ReplayMode,
|
||||
|
@ -16,12 +15,15 @@ from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
|
|||
from ray.rllib.utils.typing import PolicyID, SampleBatchType
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.util.debug import log_once
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
|
||||
@DeveloperAPI
|
||||
class MultiAgentPrioritizedReplayBuffer(
|
||||
MultiAgentReplayBuffer, PrioritizedReplayBuffer
|
||||
):
|
||||
"""A prioritized replay buffer shard for multiagent setups.
|
||||
|
||||
This buffer is meant to be run in parallel to distribute experiences
|
||||
|
@ -141,7 +143,7 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
|
|||
self.prioritized_replay_eps = prioritized_replay_eps
|
||||
self.update_priorities_timer = TimerStat()
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentReplayBuffer)
|
||||
def _add_to_underlying_buffer(
|
||||
self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
|
||||
|
@ -206,7 +208,8 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
|
|||
else:
|
||||
self.replay_buffers[policy_id].add(batch, **kwargs)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(PrioritizedReplayBuffer)
|
||||
def update_priorities(self, prio_dict: Dict) -> None:
|
||||
"""Updates the priorities of underlying replay buffers.
|
||||
|
||||
|
@ -225,7 +228,7 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
|
|||
batch_indexes, new_priorities
|
||||
)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(MultiAgentReplayBuffer)
|
||||
def stats(self, debug: bool = False) -> Dict:
|
||||
"""Returns the stats of this buffer and all underlying buffers.
|
||||
|
@ -249,6 +252,3 @@ class MultiAgentPrioritizedReplayBuffer(MultiAgentReplayBuffer):
|
|||
{"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
|
||||
)
|
||||
return stat
|
||||
|
||||
|
||||
ReplayActor = ray.remote(num_cpus=0)(MultiAgentPrioritizedReplayBuffer)
|
||||
|
|
|
@ -3,11 +3,10 @@ import collections
|
|||
from typing import Any, Dict, Optional
|
||||
from enum import Enum
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES
|
||||
from ray.rllib.policy.rnn_sequencing import timeslice_along_seq_lens_with_overlap
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override, ExperimentalAPI
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.utils.typing import PolicyID, SampleBatchType
|
||||
|
@ -15,17 +14,18 @@ from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit
|
|||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
class ReplayMode(Enum):
|
||||
LOCKSTEP = "lockstep"
|
||||
INDEPENDENT = "independent"
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def merge_dicts_with_warning(args_on_init, args_on_call):
|
||||
"""Merge argument dicts, overwriting args_on_call with warning.
|
||||
|
||||
|
@ -50,7 +50,7 @@ def merge_dicts_with_warning(args_on_init, args_on_call):
|
|||
return {**args_on_init, **args_on_call}
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
class MultiAgentReplayBuffer(ReplayBuffer):
|
||||
"""A replay buffer shard for multiagent setups.
|
||||
|
||||
|
@ -183,7 +183,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
"""Returns the number of items currently stored in this buffer."""
|
||||
return sum(len(buffer._storage) for buffer in self.replay_buffers.values())
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@Deprecated(old="replay", new="sample", error=False)
|
||||
def replay(self, num_items: int = None, **kwargs) -> Optional[SampleBatchType]:
|
||||
"""Deprecated in favor of new ReplayBuffer API."""
|
||||
|
@ -191,7 +191,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
num_items = self.replay_batch_size
|
||||
return self.sample(num_items, **kwargs)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def add(self, batch: SampleBatchType, **kwargs) -> None:
|
||||
"""Adds a batch to the appropriate policy's replay buffer.
|
||||
|
@ -229,7 +229,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
self._add_to_underlying_buffer(policy_id, sample_batch, **kwargs)
|
||||
self._num_added += batch.count
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def _add_to_underlying_buffer(
|
||||
self, policy_id: PolicyID, batch: SampleBatchType, **kwargs
|
||||
) -> None:
|
||||
|
@ -266,7 +266,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
else:
|
||||
self.replay_buffers[policy_id].add(batch, **kwargs)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def sample(
|
||||
self, num_items: int, policy_id: Optional[PolicyID] = None, **kwargs
|
||||
|
@ -310,7 +310,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
samples[policy_id] = replay_buffer.sample(num_items, **kwargs)
|
||||
return MultiAgentBatch(samples, sum(s.count for s in samples.values()))
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def stats(self, debug: bool = False) -> Dict:
|
||||
"""Returns the stats of this buffer and all underlying buffers.
|
||||
|
@ -332,7 +332,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
)
|
||||
return stat
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
@ -345,7 +345,7 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
state["replay_buffers"][policy_id] = replay_buffer.get_state()
|
||||
return state
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
@ -358,6 +358,3 @@ class MultiAgentReplayBuffer(ReplayBuffer):
|
|||
buffer_states = state["replay_buffers"]
|
||||
for policy_id in buffer_states.keys():
|
||||
self.replay_buffers[policy_id].set_state(buffer_states[policy_id])
|
||||
|
||||
|
||||
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
|
||||
|
|
|
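A hedged sketch of local (non-actor) use of the MultiAgentReplayBuffer defined above; the SampleBatch contents are made up, and wrapping a plain SampleBatch under the default policy id is assumed from the add() docstring rather than shown in this diff:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

buf = MultiAgentReplayBuffer(capacity=1000, learning_starts=0)
buf.add(SampleBatch({"obs": [0, 1], "actions": [0, 1], "rewards": [0.0, 1.0]}))
ma_batch = buf.sample(2)   # a MultiAgentBatch, keyed by policy id
print(buf.stats())         # per-policy buffer stats
print(len(buf))            # total items across underlying buffers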
@ -8,13 +8,14 @@ import psutil # noqa E402
|
|||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override, ExperimentalAPI
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.metrics.window_stat import WindowStat
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""This buffer implements Prioritized Experience Replay
|
||||
|
||||
|
@ -23,7 +24,6 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
the full paper.
|
||||
"""
|
||||
|
||||
@ExperimentalAPI
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int = 10000,
|
||||
|
@ -58,7 +58,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
self._max_priority = 1.0
|
||||
self._prio_change_stats = WindowStat("reprio", 1000)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
|
||||
"""Add a batch of experiences to self._storage with weight.
|
||||
|
@ -90,7 +90,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
res.append(idx)
|
||||
return res
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def sample(
|
||||
self, num_items: int, beta: float, **kwargs
|
||||
|
@ -157,7 +157,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
|
||||
return batch
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def update_priorities(self, idxes: List[int], priorities: List[float]) -> None:
|
||||
"""Update priorities of items at given indices.
|
||||
|
||||
|
@ -186,7 +186,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
|
||||
self._max_priority = max(self._max_priority, priority)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def stats(self, debug: bool = False) -> Dict:
|
||||
"""Returns the stats of this buffer.
|
||||
|
@ -203,7 +203,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
parent.update(self._prio_change_stats.stats())
|
||||
return parent
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
@ -223,7 +223,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
)
|
||||
return state
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
|
|
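A short usage sketch for the PrioritizedReplayBuffer shown above. The add/sample/update_priorities signatures are taken from this diff; the constructor's alpha argument and the "batch_indexes" field of the sampled batch are assumptions about the surrounding API, not shown here:

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.prioritized_replay_buffer import (
    PrioritizedReplayBuffer,
)

buf = PrioritizedReplayBuffer(capacity=100, alpha=0.6)
buf.add(SampleBatch({"obs": [1], "rewards": [0.5]}))

batch = buf.sample(1, beta=0.4)      # beta: importance-sampling exponent
idxes = batch["batch_indexes"]       # indices assigned by the buffer (assumed key name)
buf.update_priorities(idxes, [0.9])  # e.g. abs. TD-errors + prioritized_replay_eps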
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import platform
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Callable
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
|
@ -12,11 +12,11 @@ import psutil # noqa E402
|
|||
|
||||
from ray.util.debug import log_once
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import ExperimentalAPI
|
||||
from ray.rllib.utils.deprecation import Deprecated
|
||||
from ray.rllib.utils.metrics.window_stat import WindowStat
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
|
||||
from ray.rllib.utils.typing import SampleBatchType, T
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
from ray.util.iter import ParallelIteratorWorker
|
||||
|
||||
# Constant that represents all policies in lockstep replay mode.
|
||||
_ALL_POLICIES = "__all__"
|
||||
|
@ -24,7 +24,7 @@ _ALL_POLICIES = "__all__"
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
class StorageUnit(Enum):
|
||||
TIMESTEPS = "timesteps"
|
||||
SEQUENCES = "sequences"
|
||||
|
@ -32,8 +32,31 @@ class StorageUnit(Enum):
|
|||
FRAGMENTS = "fragments"
|
||||
|
||||
|
||||
@ExperimentalAPI
|
||||
class ReplayBuffer:
|
||||
@DeveloperAPI
|
||||
def warn_replay_capacity(*, item: SampleBatchType, num_items: int) -> None:
|
||||
"""Warn if the configured replay buffer capacity is too large."""
|
||||
if log_once("replay_capacity"):
|
||||
item_size = item.size_bytes()
|
||||
psutil_mem = psutil.virtual_memory()
|
||||
total_gb = psutil_mem.total / 1e9
|
||||
mem_size = num_items * item_size / 1e9
|
||||
msg = (
|
||||
"Estimated max memory usage for replay buffer is {} GB "
|
||||
"({} batches of size {}, {} bytes each), "
|
||||
"available system memory is {} GB".format(
|
||||
mem_size, num_items, item.count, item_size, total_gb
|
||||
)
|
||||
)
|
||||
if mem_size > total_gb:
|
||||
raise ValueError(msg)
|
||||
elif mem_size > 0.2 * total_gb:
|
||||
logger.warning(msg)
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ReplayBuffer(ParallelIteratorWorker):
|
||||
def __init__(
|
||||
self, capacity: int = 10000, storage_unit: str = "timesteps", **kwargs
|
||||
):
|
||||
|
@ -96,11 +119,17 @@ class ReplayBuffer:
|
|||
|
||||
self.batch_size = None
|
||||
|
||||
def gen_replay():
|
||||
while True:
|
||||
yield self.replay()
|
||||
|
||||
ParallelIteratorWorker.__init__(self, gen_replay, False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of items currently stored in this buffer."""
|
||||
return len(self._storage)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def add(self, batch: SampleBatchType, **kwargs) -> None:
|
||||
"""Adds a batch of experiences to this buffer.
|
||||
|
||||
|
@ -155,7 +184,7 @@ class ReplayBuffer:
|
|||
elif self._storage_unit == StorageUnit.FRAGMENTS:
|
||||
self._add_single_batch(batch, **kwargs)
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def _add_single_batch(self, item: SampleBatchType, **kwargs) -> None:
|
||||
"""Add a SampleBatch of experiences to self._storage.
|
||||
|
||||
|
@ -189,7 +218,7 @@ class ReplayBuffer:
|
|||
else:
|
||||
self._next_idx += 1
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def sample(self, num_items: int, **kwargs) -> Optional[SampleBatchType]:
|
||||
"""Samples `num_items` items from this buffer.
|
||||
|
||||
|
@ -220,7 +249,7 @@ class ReplayBuffer:
|
|||
self._num_timesteps_sampled += sample.count
|
||||
return sample
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def stats(self, debug: bool = False) -> dict:
|
||||
"""Returns the stats of this buffer.
|
||||
|
||||
|
@ -243,7 +272,7 @@ class ReplayBuffer:
|
|||
data.update(self._evicted_hit_stats.stats())
|
||||
return data
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""Returns all local state.
|
||||
|
||||
|
@ -254,7 +283,7 @@ class ReplayBuffer:
|
|||
state.update(self.stats(debug=False))
|
||||
return state
|
||||
|
||||
@ExperimentalAPI
|
||||
@DeveloperAPI
|
||||
def set_state(self, state: Dict[str, Any]) -> None:
|
||||
"""Restores all local state to the provided `state`.
|
||||
|
||||
|
@ -272,6 +301,7 @@ class ReplayBuffer:
|
|||
self._num_timesteps_sampled = state["sampled_count"]
|
||||
self._est_size_bytes = state["est_size_bytes"]
|
||||
|
||||
@DeveloperAPI
|
||||
def _encode_sample(self, idxes: List[int]) -> SampleBatchType:
"""Fetches concatenated samples at given indices from the storage."""
|
||||
samples = []
|
||||
|
@ -288,6 +318,7 @@ class ReplayBuffer:
|
|||
out.decompress_if_needed()
|
||||
return out
|
||||
|
||||
@DeveloperAPI
|
||||
def get_host(self) -> str:
|
||||
"""Returns the computer's network name.
|
||||
|
||||
|
@ -297,10 +328,35 @@ class ReplayBuffer:
|
|||
"""
|
||||
return platform.node()
|
||||
|
||||
@DeveloperAPI
|
||||
def apply(
|
||||
self,
|
||||
func: Callable[["ReplayBuffer", Optional[Any], Optional[Any]], T],
|
||||
*_args,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""Calls the given function with this ReplayBuffer instance.
|
||||
|
||||
This is useful if we want to apply a function to a set of remote actors.
|
||||
|
||||
Args:
|
||||
func: A callable that accepts the replay buffer itself, args and kwargs
|
||||
*_args: Any args to pass to func
|
||||
**kwargs: Any kwargs to pass to func
|
||||
|
||||
Returns:
|
||||
The return value of the function call.
|
||||
"""
|
||||
return func(self, *_args, **kwargs)
|
||||
|
||||
@Deprecated(old="ReplayBuffer.add_batch()", new="ReplayBuffer.add()", error=False)
|
||||
def add_batch(self, *args, **kwargs):
|
||||
return self.add(*args, **kwargs)
|
||||
|
||||
@Deprecated(old="ReplayBuffer.replay()", new="ReplayBuffer.sample()", error=False)
|
||||
def replay(self, *args, **kwargs):
|
||||
return self.sample(*args, **kwargs)
|
||||
@Deprecated(
|
||||
old="ReplayBuffer.replay(num_items)",
|
||||
new="ReplayBuffer.sample(num_items)",
|
||||
error=False,
|
||||
)
|
||||
def replay(self, num_items):
|
||||
return self.sample(num_items)
|
||||
|
|
|
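The new ReplayBuffer.apply() helper above simply calls func(self, *args, **kwargs); its main use is running ad-hoc code on remote buffer actors without adding dedicated methods. A sketch, with the actor wrapping following the ray.remote(num_cpus=0)(...) pattern used elsewhere in this diff:

import ray
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer

ray.init()
ReplayBufferActor = ray.remote(num_cpus=0)(ReplayBuffer)
shards = [ReplayBufferActor.remote(capacity=1000) for _ in range(2)]

# Query each shard's current size via apply() instead of a dedicated RPC.
sizes = ray.get([s.apply.remote(lambda buf: len(buf)) for s in shards])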
@ -6,8 +6,10 @@ import ray # noqa F401
|
|||
import psutil # noqa E402
|
||||
|
||||
from ray.rllib.utils.annotations import ExperimentalAPI, override
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
|
||||
from ray.rllib.execution.buffers.replay_buffer import warn_replay_capacity
|
||||
from ray.rllib.utils.replay_buffers.replay_buffer import (
|
||||
ReplayBuffer,
|
||||
warn_replay_capacity,
|
||||
)
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
|
||||
|
||||
|
|
|
@ -4,8 +4,10 @@ from ray.rllib.utils.annotations import override
|
|||
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
|
||||
from ray.rllib.utils.replay_buffers.utils import warn_replay_buffer_capacity
|
||||
from ray.rllib.utils.typing import SampleBatchType
|
||||
from ray.util.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class SimpleReplayBuffer(ReplayBuffer):
|
||||
"""Simple replay buffer that operates over entire batches."""
|
||||
|
||||
|
@ -15,6 +17,7 @@ class SimpleReplayBuffer(ReplayBuffer):
|
|||
self.replay_batches = []
|
||||
self.replay_index = 0
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def add(self, batch: SampleBatchType, **kwargs) -> None:
|
||||
warn_replay_buffer_capacity(item=batch, capacity=self.capacity)
|
||||
|
@ -26,10 +29,12 @@ class SimpleReplayBuffer(ReplayBuffer):
|
|||
self.replay_index += 1
|
||||
self.replay_index %= self.capacity
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def sample(self, num_items: int, **kwargs) -> SampleBatchType:
|
||||
return random.choice(self.replay_batches)
|
||||
|
||||
@DeveloperAPI
|
||||
@override(ReplayBuffer)
|
||||
def __len__(self):
|
||||
return len(self.replay_batches)
|
||||
|
|
|
@ -2,10 +2,6 @@ import logging
|
|||
import psutil
|
||||
from typing import Optional, Any
|
||||
|
||||
from ray.rllib.execution import MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer
|
||||
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
|
||||
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
|
||||
)
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import deprecation_warning
|
||||
from ray.rllib.utils.annotations import ExperimentalAPI
|
||||
|
@ -47,10 +43,7 @@ def update_priorities_in_replay_buffer(
|
|||
utility.
|
||||
"""
|
||||
# Only update priorities if buffer supports them.
|
||||
if (
|
||||
type(replay_buffer) is LegacyMultiAgentReplayBuffer
|
||||
and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0) > 0.0
|
||||
) or isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
|
||||
if isinstance(replay_buffer, MultiAgentPrioritizedReplayBuffer):
|
||||
# Go through training results for the different policies (maybe multi-agent).
|
||||
prio_dict = {}
|
||||
for policy_id, info in train_results.items():
|
||||
|
@ -145,46 +138,57 @@ def validate_buffer_config(config: dict):
|
|||
if config.get("replay_buffer_config", None) is None:
|
||||
config["replay_buffer_config"] = {}
|
||||
|
||||
prioritized_replay = config.get("prioritized_replay")
|
||||
prioritized_replay = config.get("prioritized_replay", DEPRECATED_VALUE)
|
||||
if prioritized_replay != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['prioritized_replay']",
|
||||
help="Replay prioritization specified at new location config["
|
||||
"'replay_buffer_config']["
|
||||
"'prioritized_replay'] will be overwritten.",
|
||||
error=False,
|
||||
old="config['prioritized_replay'] or config['replay_buffer_config']["
|
||||
"'prioritized_replay']",
|
||||
help="Replay prioritization specified by config key. RLlib's new replay "
|
||||
"buffer API requires setting `config["
|
||||
"'replay_buffer_config']['type']`, e.g. `config["
|
||||
"'replay_buffer_config']['type'] = "
|
||||
"'MultiAgentPrioritizedReplayBuffer'` to change the default "
|
||||
"behaviour.",
|
||||
error=True,
|
||||
)
|
||||
config["replay_buffer_config"]["prioritized_replay"] = prioritized_replay
|
||||
|
||||
capacity = config.get("buffer_size", DEPRECATED_VALUE)
|
||||
if capacity == DEPRECATED_VALUE:
|
||||
capacity = config["replay_buffer_config"].get("buffer_size", DEPRECATED_VALUE)
|
||||
if capacity != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['buffer_size']",
|
||||
help="Buffer size specified at new location config["
|
||||
"'replay_buffer_config']["
|
||||
"'capacity'] will be overwritten.",
|
||||
error=False,
|
||||
old="config['buffer_size'] or config['replay_buffer_config']["
|
||||
"'buffer_size']",
|
||||
new="config['replay_buffer_config']['capacity']",
|
||||
error=True,
|
||||
)
|
||||
|
||||
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
|
||||
if replay_burn_in != DEPRECATED_VALUE:
|
||||
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
|
||||
deprecation_warning(
|
||||
old="config['burn_in']",
|
||||
help="config['replay_buffer_config']['replay_burn_in']",
|
||||
)
|
||||
config["replay_buffer_config"]["capacity"] = capacity
|
||||
|
||||
# Deprecation of old-style replay buffer args
|
||||
# Warn before checking if we need a local buffer, so that algorithms
|
||||
# without a local buffer also get warned.
|
||||
deprecated_replay_buffer_keys = [
|
||||
keys_with_deprecated_positions = [
|
||||
"prioritized_replay_alpha",
|
||||
"prioritized_replay_beta",
|
||||
"prioritized_replay_eps",
|
||||
"no_local_replay_buffer",
|
||||
"replay_batch_size",
|
||||
"replay_zero_init_states",
|
||||
"learning_starts",
|
||||
"replay_buffer_shards_colocated_with_driver",
|
||||
]
|
||||
for k in deprecated_replay_buffer_keys:
|
||||
for k in keys_with_deprecated_positions:
|
||||
if config.get(k, DEPRECATED_VALUE) != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config[{}]".format(k),
|
||||
help="config['replay_buffer_config'][{}] should be used "
|
||||
"for Q-Learning algorithms. Ignore this warning if "
|
||||
"you are not using a Q-Learning algorithm and still "
|
||||
"provide {}."
|
||||
"".format(k, k),
|
||||
old="config['{}']".format(k),
|
||||
help="config['replay_buffer_config']['{}']" "".format(k),
|
||||
error=False,
|
||||
)
|
||||
# Copy values over to new location in config to support new
|
||||
|
@ -192,143 +196,61 @@ def validate_buffer_config(config: dict):
|
|||
if config.get("replay_buffer_config") is not None:
|
||||
config["replay_buffer_config"][k] = config[k]
|
||||
|
||||
# Old Ape-X configs may contain no_local_replay_buffer
|
||||
no_local_replay_buffer = config.get("no_local_replay_buffer", False)
|
||||
if no_local_replay_buffer:
|
||||
replay_mode = config["multiagent"].get("replay_mode", DEPRECATED_VALUE)
|
||||
if replay_mode != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['no_local_replay_buffer']",
|
||||
help="no_local_replay_buffer specified at new location config["
|
||||
"'replay_buffer_config']["
|
||||
"'capacity'] will be overwritten.",
|
||||
old="config['multiagent']['replay_mode']",
|
||||
help="config['replay_buffer_config']['replay_mode']",
|
||||
error=False,
|
||||
)
|
||||
config["replay_buffer_config"][
|
||||
"no_local_replay_buffer"
|
||||
] = no_local_replay_buffer
|
||||
config["replay_buffer_config"]["replay_mode"] = replay_mode
|
||||
|
||||
# TODO (Artur):
|
||||
if config["replay_buffer_config"].get("no_local_replay_buffer", False):
|
||||
return
|
||||
# Can't use DEPRECATED_VALUE here because this is also a deliberate
|
||||
# value set for some algorithms
|
||||
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
|
||||
replay_sequence_length = config.get("replay_sequence_length", None)
|
||||
if replay_sequence_length is not None:
|
||||
config["replay_buffer_config"][
|
||||
"replay_sequence_length"
|
||||
] = replay_sequence_length
|
||||
deprecation_warning(
|
||||
old="config['replay_sequence_length']",
|
||||
help="Replay sequence length specified at new "
|
||||
"location config['replay_buffer_config']["
|
||||
"'replay_sequence_length'] will be overwritten.",
|
||||
error=False,
|
||||
)
|
||||
|
||||
replay_buffer_config = config["replay_buffer_config"]
|
||||
assert (
|
||||
"type" in replay_buffer_config
|
||||
), "Can not instantiate ReplayBuffer from config without 'type' key."
|
||||
|
||||
replay_burn_in = config.get("burn_in", DEPRECATED_VALUE)
|
||||
if replay_burn_in != DEPRECATED_VALUE:
|
||||
config["replay_buffer_config"]["replay_burn_in"] = replay_burn_in
|
||||
deprecation_warning(
|
||||
old="config['burn_in']",
|
||||
help="Burn in specified at new location config["
|
||||
"'replay_buffer_config']["
|
||||
"'replay_burn_in'] will be overwritten.",
|
||||
)
|
||||
|
||||
# Check if old replay buffer should be instantiated
|
||||
buffer_type = config["replay_buffer_config"]["type"]
|
||||
if not config["replay_buffer_config"].get("_enable_replay_buffer_api", False):
|
||||
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
|
||||
# Prepend old-style buffers' path
|
||||
assert buffer_type == "MultiAgentReplayBuffer", (
|
||||
"Without "
|
||||
"ReplayBuffer "
|
||||
"API, only "
|
||||
"MultiAgentReplayBuffer "
|
||||
"is supported!"
|
||||
)
|
||||
# Create valid full [module].[class] string for from_config
|
||||
buffer_type = "ray.rllib.execution.MultiAgentReplayBuffer"
|
||||
else:
|
||||
assert buffer_type in [
|
||||
"ray.rllib.execution.MultiAgentReplayBuffer",
|
||||
Legacy_MultiAgentReplayBuffer,
|
||||
], (
|
||||
"Without ReplayBuffer API, only " "MultiAgentReplayBuffer is supported!"
|
||||
)
|
||||
|
||||
config["replay_buffer_config"]["type"] = buffer_type
|
||||
|
||||
# Remove from config, so it's not passed into the buffer c'tor
|
||||
config["replay_buffer_config"].pop("_enable_replay_buffer_api", None)
|
||||
|
||||
# We need to deprecate the old-style location of the following
|
||||
# buffer arguments and make users put them into the
|
||||
# "replay_buffer_config" field of their config.
|
||||
replay_batch_size = config.get("replay_batch_size", DEPRECATED_VALUE)
|
||||
if replay_batch_size != DEPRECATED_VALUE:
|
||||
config["replay_buffer_config"]["replay_batch_size"] = replay_batch_size
|
||||
deprecation_warning(
|
||||
old="config['replay_batch_size']",
|
||||
help="Replay batch size specified at new "
|
||||
"location config['replay_buffer_config']["
|
||||
"'replay_batch_size'] will be overwritten.",
|
||||
error=False,
|
||||
)
|
||||
|
||||
replay_mode = config.get("replay_mode", DEPRECATED_VALUE)
|
||||
if replay_mode != DEPRECATED_VALUE:
|
||||
config["replay_buffer_config"]["replay_mode"] = replay_mode
|
||||
deprecation_warning(
|
||||
old="config['multiagent']['replay_mode']",
|
||||
help="Replay sequence length specified at new "
|
||||
"location config['replay_buffer_config']["
|
||||
"'replay_mode'] will be overwritten.",
|
||||
error=False,
|
||||
)
|
||||
|
||||
# Can't use DEPRECATED_VALUE here because this is also a deliberate
|
||||
# value set for some algorithms
|
||||
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
|
||||
replay_sequence_length = config.get("replay_sequence_length", None)
|
||||
if replay_sequence_length is not None:
|
||||
config["replay_buffer_config"][
|
||||
"replay_sequence_length"
|
||||
] = replay_sequence_length
|
||||
deprecation_warning(
|
||||
old="config['replay_sequence_length']",
|
||||
help="Replay sequence length specified at new "
|
||||
"location config['replay_buffer_config']["
|
||||
"'replay_sequence_length'] will be overwritten.",
|
||||
error=False,
|
||||
)
|
||||
|
||||
replay_zero_init_states = config.get(
|
||||
"replay_zero_init_states", DEPRECATED_VALUE
|
||||
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
|
||||
# Create valid full [module].[class] string for from_config
|
||||
config["replay_buffer_config"]["type"] = (
|
||||
"ray.rllib.utils.replay_buffers." + buffer_type
|
||||
)
|
||||
if replay_zero_init_states != DEPRECATED_VALUE:
|
||||
config["replay_buffer_config"][
|
||||
"replay_zero_init_states"
|
||||
] = replay_zero_init_states
|
||||
deprecation_warning(
|
||||
old="config['replay_zero_init_states']",
|
||||
help="Replay zero init states specified at new location "
|
||||
"config["
|
||||
"'replay_buffer_config']["
|
||||
"'replay_zero_init_states'] will be overwritten.",
|
||||
error=False,
|
||||
)
|
||||
|
||||
# TODO (Artur): Move this logic into config objects
|
||||
if config["replay_buffer_config"].get("prioritized_replay", False):
|
||||
is_prioritized_buffer = True
|
||||
else:
|
||||
is_prioritized_buffer = False
|
||||
# This triggers non-prioritization in old-style replay buffer
|
||||
config["replay_buffer_config"]["prioritized_replay_alpha"] = 0.0
|
||||
else:
|
||||
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
|
||||
# Create valid full [module].[class] string for from_config
|
||||
config["replay_buffer_config"]["type"] = (
|
||||
"ray.rllib.utils.replay_buffers." + buffer_type
|
||||
)
|
||||
test_buffer = from_config(buffer_type, config["replay_buffer_config"])
|
||||
if hasattr(test_buffer, "update_priorities"):
|
||||
is_prioritized_buffer = True
|
||||
else:
|
||||
is_prioritized_buffer = False
|
||||
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
|
||||
# Fall back to train batch size if no replay batch size was provided
|
||||
logger.info(
|
||||
"No value for key `replay_batch_size` in replay_buffer_config. "
|
||||
"config['replay_buffer_config']['replay_batch_size'] will be "
|
||||
"automatically set to config['train_batch_size']"
|
||||
)
|
||||
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
|
||||
|
||||
if is_prioritized_buffer:
|
||||
# Instantiate a dummy buffer to fail early on misconfiguration and to resolve
|
||||
# the inferred buffer class.
|
||||
dummy_buffer = from_config(buffer_type, config["replay_buffer_config"])
|
||||
|
||||
config["replay_buffer_config"]["type"] = type(dummy_buffer)
|
||||
|
||||
if hasattr(dummy_buffer, "update_priorities"):
|
||||
if config["multiagent"]["replay_mode"] == "lockstep":
|
||||
raise ValueError(
|
||||
"Prioritized replay is not supported when replay_mode=lockstep."
|
||||
|
@ -339,20 +261,12 @@ def validate_buffer_config(config: dict):
|
|||
"replay_sequence_length > 1."
|
||||
)
|
||||
else:
|
||||
if config.get("worker_side_prioritization"):
|
||||
if config["replay_buffer_config"].get("worker_side_prioritization"):
|
||||
raise ValueError(
|
||||
"Worker side prioritization is not supported when "
|
||||
"prioritized_replay=False."
|
||||
)
|
||||
|
||||
if config["replay_buffer_config"].get("replay_batch_size", None) is None:
|
||||
# Fall back to train batch size if no replay batch size was provided
|
||||
config["replay_buffer_config"]["replay_batch_size"] = config["train_batch_size"]
|
||||
|
||||
# Pop prioritized replay because it's not a valid parameter for older
|
||||
# replay buffers
|
||||
config["replay_buffer_config"].pop("prioritized_replay", None)
|
||||
|
||||
|
||||
def warn_replay_buffer_capacity(*, item: SampleBatchType, capacity: int) -> None:
|
||||
"""Warn if the configured replay buffer capacity is too large for machine's memory.
|
||||
|
|
|
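A condensed sketch of how the rewritten validate_buffer_config() above resolves a bare buffer type string before instantiating it via from_config (logic copied from the diff; the surrounding config is illustrative):

config = {"replay_buffer_config": {"type": "MultiAgentPrioritizedReplayBuffer"}}

buffer_type = config["replay_buffer_config"]["type"]
if isinstance(buffer_type, str) and buffer_type.find(".") == -1:
    # Prepend the new package path so from_config can import the class.
    config["replay_buffer_config"]["type"] = (
        "ray.rllib.utils.replay_buffers." + buffer_type
    )

assert config["replay_buffer_config"]["type"] == (
    "ray.rllib.utils.replay_buffers.MultiAgentPrioritizedReplayBuffer"
)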
@@ -561,22 +561,6 @@ def check_train_results(train_results):
        for pid, policy_stats in learner_info.items():
            if pid == "batch_count":
                continue
            # Expect td-errors to be per batch-item.
            if "td_error" in policy_stats:
                configured_b = train_results["config"]["train_batch_size"]
                actual_b = policy_stats["td_error"].shape[0]
                # R2D2 case.
                if (configured_b - actual_b) / actual_b > 0.1:
                    assert (
                        configured_b
                        / (
                            train_results["config"]["model"]["max_seq_len"]
                            + train_results["config"]["replay_buffer_config"][
                                "replay_burn_in"
                            ]
                        )
                        == actual_b
                    )

            # Make sure each policy has the LEARNER_STATS_KEY under it.
            assert LEARNER_STATS_KEY in policy_stats