[RLlib] Move learning_starts logic from buffers into training_step(). (#26032)

This commit is contained in:
parent c855469845
commit 0dceddb912
82 changed files with 437 additions and 331 deletions
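Taken together, the hunks below all apply one migration: `learning_starts` moves out of `replay_buffer_config` and becomes the top-level `num_steps_sampled_before_learning_starts` setting, and each algorithm's `training_step()` now checks the sampled-steps counter itself instead of relying on the buffer to withhold samples. A minimal sketch of the user-facing change, using DQN as an example (import path and concrete values are illustrative, not copied from this diff):

    from ray.rllib.algorithms.dqn import DQNConfig

    # Old pattern (pre-commit): the replay buffer gated learning via
    # replay_buffer_config={"learning_starts": 1000}.
    # New pattern (post-commit): the gate is a top-level training setting.
    config = (
        DQNConfig()
        .rollouts(num_rollout_workers=0)
        .training(
            replay_buffer_config={"type": "MultiAgentReplayBuffer", "capacity": 50000},
            num_steps_sampled_before_learning_starts=1000,
        )
    )

Inside training_step(), the gate is implemented as a counter comparison (cur_ts > num_steps_sampled_before_learning_starts) rather than the buffer returning None before enough steps have been sampled.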
@@ -45,8 +45,8 @@ run_experiments(
"num_gpus": 0,
"replay_buffer_config": {
"capacity": 10000,
"learning_starts": 0,
},
"num_steps_sampled_before_learning_starts": 0,
"rollout_fragment_length": 1,
"train_batch_size": 1,
"min_iter_time_s": 10,

@@ -30,8 +30,8 @@ ddpg-hopperbulletenv-v0:
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
learning_starts: 500
worker_side_prioritization: false
num_steps_sampled_before_learning_starts: 500
clip_rewards: false
actor_lr: 0.001
critic_lr: 0.001

@@ -20,8 +20,8 @@ dqn-breakoutnoframeskip-v4:
replay_buffer_config:
type: MultiAgentReplayBuffer
capacity: 1000000
learning_starts: 20000
prioritized_replay_alpha: 0.5
num_steps_sampled_before_learning_starts: 20000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:

@@ -25,8 +25,8 @@ sac-halfcheetahbulletenv-v0:
target_network_update_freq: 1
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
learning_starts: 10000
type: MultiAgentPrioritizedReplayBuffer
num_steps_sampled_before_learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -9,7 +9,6 @@ td3-halfcheetahbulletenv-v0:
time_total_s: 3600
config:
num_gpus: 1
replay_buffer_config:
learning_starts: 10000
num_steps_sampled_before_learning_starts: 10000
exploration_config:
random_timesteps: 10000
@@ -1215,7 +1215,7 @@ py_test(
"--env", "Pendulum-v1",
"--run", "APEX_DDPG",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"replay_buffer_config\": {\"learning_starts\": 100}, \"min_time_s_per_iteration\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"num_steps_sampled_before_learning_starts\": 100, \"min_time_s_per_iteration\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--ray-num-cpus", "4",
]
)

@@ -1272,7 +1272,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"replay_buffer_config\": {\"learning_starts\": 0}, \"off_policy_estimation_methods\": {\"wis\": {\"type\": \"wis\"}, \"is\": {\"type\": \"is\"}}, \"exploration_config\": {\"type\": \"SoftQ\"}}'"
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"num_steps_sampled_before_learning_starts\": 0, \"off_policy_estimation_methods\": {\"wis\": {\"type\": \"wis\"}, \"is\": {\"type\": \"is\"}}, \"exploration_config\": {\"type\": \"SoftQ\"}}'"
]
)

@@ -1284,7 +1284,7 @@ py_test(
"--env", "PongDeterministic-v4",
"--run", "DQN",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"replay_buffer_config\": {\"capacity\": 10000, \"learning_starts\": 10000}, \"rollout_fragment_length\": 4, \"target_network_update_freq\": 1000, \"gamma\": 0.99}'"
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"replay_buffer_config\": {\"capacity\": 10000}, \"num_steps_sampled_before_learning_starts\": 10000, \"rollout_fragment_length\": 4, \"target_network_update_freq\": 1000, \"gamma\": 0.99}'"
]
)

@@ -3159,7 +3159,7 @@ py_test(
tags = ["team:rllib", "exclusive", "examples", "examples_R"],
size = "large",
srcs = ["examples/recommender_system_with_recsim_and_slateq.py"],
args = ["--stop-iters=2", "--learning-starts=100", "--framework=tf2", "--use-tune", "--random-test-episodes=10", "--env-num-candidates=50", "--env-slate-size=2"],
args = ["--stop-iters=2", "--num-steps-sampled-before-learning_starts=100", "--framework=tf2", "--use-tune", "--random-test-episodes=10", "--env-num-candidates=50", "--env-slate-size=2"],
)

py_test(
@@ -105,17 +105,18 @@ class AlphaZeroConfig(AlgorithmConfig):
self.sgd_minibatch_size = 128
self.shuffle_sequences = True
self.num_sgd_iter = 30
self.learning_starts = 1000
self.replay_buffer_config = {
"type": "ReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
# When to start returning samples (in batches, not timesteps!).
"learning_starts": 500,
# Choosing `fragments` here makes it so that the buffer stores entire
# batches, instead of sequences, episodes or timesteps.
"storage_unit": "fragments",
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1000
self.lr_schedule = None
self.vf_share_layers = False
self.mcts_config = {

@@ -169,6 +170,7 @@ class AlphaZeroConfig(AlgorithmConfig):
vf_share_layers: Optional[bool] = None,
mcts_config: Optional[dict] = None,
ranked_rewards: Optional[dict] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "AlphaZeroConfig":
"""Sets the training related configuration.

@@ -221,6 +223,10 @@ class AlphaZeroConfig(AlgorithmConfig):
mcts_config: MCTS specific settings.
ranked_rewards: Settings for the ranked reward (r2) algorithm
from: https://arxiv.org/pdf/1807.01672.pdf
num_steps_sampled_before_learning_starts: Number of timesteps to collect
from rollout workers before we start sampling from replay buffers for
learning. Whether we count this in agent steps or environment steps
depends on config["multiagent"]["count_steps_by"].
Returns:
This updated AlgorithmConfig object.

@@ -244,6 +250,10 @@ class AlphaZeroConfig(AlgorithmConfig):
self.mcts_config = mcts_config
if ranked_rewards is not None:
self.ranked_rewards = ranked_rewards
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self

@@ -344,9 +354,19 @@ class AlphaZero(Algorithm):
self.local_replay_buffer.add(batch)
if self.local_replay_buffer is not None:
train_batch = self.local_replay_buffer.sample(
self.config["train_batch_size"]
)
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self._by_agent_steps
else NUM_ENV_STEPS_SAMPLED
]
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
train_batch = self.local_replay_buffer.sample(
self.config["train_batch_size"]
)
else:
train_batch = None
else:
train_batch = concat_samples(new_sample_batches)
@@ -82,8 +82,6 @@ class ApexDDPGConfig(DDPGConfig):
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# How many steps of the model to sample before learning starts.
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick

@@ -99,6 +97,10 @@ class ApexDDPGConfig(DDPGConfig):
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 50000
self.target_network_update_freq = 500000
self.training_intensity = 1
# __sphinx_doc_end__

@@ -25,7 +25,7 @@ class TestApexDDPG(unittest.TestCase):
.rollouts(num_rollout_workers=2)
.reporting(min_sample_timesteps_per_iteration=100)
.training(
replay_buffer_config={"learning_starts": 0},
num_steps_sampled_before_learning_starts=0,
optimizer={"num_replay_buffer_shards": 1},
)
.environment(env="Pendulum-v1")
@@ -132,6 +132,10 @@ class ApexDQNConfig(DQNConfig):
self.train_batch_size = 512
self.target_network_update_freq = 500000
self.training_intensity = 1
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 50000
# max number of inflight requests to each sampling worker
# see the AsyncRequestsManager class for more details

@@ -161,7 +165,6 @@ class ApexDQNConfig(DQNConfig):
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick

@@ -241,7 +244,6 @@ class ApexDQNConfig(DQNConfig):
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_batch_size": 32,
"replay_sequence_length": 1,

@@ -441,12 +443,19 @@ class ApexDQN(DQN):
# only do this if there are remote workers (config["num_workers"] > 1)
if self.workers.remote_workers():
self.update_workers(worker_samples_collected)
# trigger a sample from the replay actors and enqueue operation to the
# learner thread.
self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
worker_samples_collected
)
self.update_replay_sample_priority()
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
# trigger a sample from the replay actors and enqueue operation to the
# learner thread.
self.sample_from_replay_buffer_place_on_learner_queue_non_blocking(
worker_samples_collected
)
self.update_replay_sample_priority()
return copy.deepcopy(self.learner_thread.learner_info)
@@ -26,9 +26,7 @@ class TestApexDQN(unittest.TestCase):
.rollouts(num_rollout_workers=0)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
num_steps_sampled_before_learning_starts=0,
optimizer={
"num_replay_buffer_shards": 1,
},

@@ -53,9 +51,7 @@ class TestApexDQN(unittest.TestCase):
.rollouts(num_rollout_workers=3)
.resources(num_gpus=0)
.training(
replay_buffer_config={
"learning_starts": 1000,
},
num_steps_sampled_before_learning_starts=0,
optimizer={
"num_replay_buffer_shards": 1,
},

@@ -110,7 +106,6 @@ class TestApexDQN(unittest.TestCase):
replay_buffer_config={
"no_local_replay_buffer": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"learning_starts": 10,
"capacity": 100,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.

@@ -121,6 +116,9 @@ class TestApexDQN(unittest.TestCase):
# Initial lr, doesn't really matter because of the schedule below.
lr=0.2,
lr_schedule=[[0, 0.2], [100, 0.001]],
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning.
num_steps_sampled_before_learning_starts=10,
)
.reporting(
min_sample_timesteps_per_iteration=10,
@@ -55,7 +55,7 @@ class TestCQL(unittest.TestCase):
clip_actions=False,
train_batch_size=2000,
twin_q=True,
replay_buffer_config={"learning_starts": 0},
num_steps_sampled_before_learning_starts=0,
bc_iters=2,
)
.evaluation(
@@ -107,8 +107,6 @@ class DDPGConfig(SimpleQConfig):
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}

@@ -117,6 +115,10 @@ class DDPGConfig(SimpleQConfig):
self.grad_clip = None
self.train_batch_size = 256
self.target_network_update_freq = 0
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1500
# .rollouts()
self.rollout_fragment_length = 1
@@ -34,10 +34,11 @@ class TestDDPG(unittest.TestCase):
def test_ddpg_compilation(self):
"""Test whether DDPG can be built with both frameworks."""
config = ddpg.DDPGConfig()
config.num_workers = 0
config.num_envs_per_worker = 2
config.replay_buffer_config["learning_starts"] = 0
config = (
ddpg.DDPGConfig()
.training(num_steps_sampled_before_learning_starts=0)
.rollouts(num_rollout_workers=0, num_envs_per_worker=2)
)
explore = config.exploration_config.update({"random_timesteps": 100})
config.exploration(exploration_config=explore)

@@ -63,7 +64,12 @@ class TestDDPG(unittest.TestCase):
def test_ddpg_exploration_and_with_random_prerun(self):
"""Tests DDPG's Exploration (w/ random actions for n timesteps)."""
core_config = ddpg.DDPGConfig().rollouts(num_rollout_workers=0)
core_config = (
ddpg.DDPGConfig()
.rollouts(num_rollout_workers=0)
.training(num_steps_sampled_before_learning_starts=0)
)
obs = np.array([0.0, 0.1, -0.1])
# Test against all frameworks.

@@ -125,7 +131,8 @@ class TestDDPG(unittest.TestCase):
def test_ddpg_loss_function(self):
"""Tests DDPG loss function results across all frameworks."""
config = ddpg.DDPGConfig()
config = ddpg.DDPGConfig().training(num_steps_sampled_before_learning_starts=0)
# Run locally.
config.seed = 42
config.num_workers = 0

@@ -138,7 +145,6 @@ class TestDDPG(unittest.TestCase):
config.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
"learning_starts": 0,
}
# Use very simple nets.
config.actor_hiddens = [10]
@@ -217,7 +217,6 @@ class DQNConfig(SimpleQConfig):
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_sequence_length": 1,
}

@@ -370,62 +369,57 @@ class DQN(SimpleQ):
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
for _ in range(sample_and_train_weight):
# Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = sample_min_n_steps_from_buffer(
self.local_replay_buffer,
self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
# Old-style replay buffers return None if learning has not started
if train_batch is None or len(train_batch) == 0:
self.workers.local_worker().set_global_vars(global_vars)
break
# Postprocess batch before we learn on it
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
train_batch = post_fn(train_batch, self.workers, self.config)
# for policy_id, sample_batch in train_batch.policy_batches.items():
# print(len(sample_batch["obs"]))
# print(sample_batch.count)
# Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update replay buffer priorities.
update_priorities_in_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
train_results,
)
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self._by_agent_steps
else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
for _ in range(sample_and_train_weight):
# Sample training batch (MultiAgentBatch) from replay buffer.
train_batch = sample_min_n_steps_from_buffer(
self.local_replay_buffer,
self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
# Postprocess batch before we learn on it
post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
train_batch = post_fn(train_batch, self.workers, self.config)
# for policy_id, sample_batch in train_batch.policy_batches.items():
# print(len(sample_batch["obs"]))
# print(sample_batch.count)
# Learn on training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update replay buffer priorities.
update_priorities_in_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
train_results,
)
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
# Return all collected metrics for the iteration.
return train_results
@@ -24,7 +24,11 @@ class TestDQN(unittest.TestCase):
def test_dqn_compilation(self):
"""Test whether DQN can be built on all frameworks."""
num_iterations = 1
config = dqn.dqn.DQNConfig().rollouts(num_rollout_workers=2)
config = (
dqn.dqn.DQNConfig()
.rollouts(num_rollout_workers=2)
.training(num_steps_sampled_before_learning_starts=0)
)
for _ in framework_iterator(config, with_eager_tracing=True):
# Double-dueling DQN.

@@ -60,7 +64,8 @@ class TestDQN(unittest.TestCase):
dqn.dqn.DQNConfig()
.rollouts(num_rollout_workers=0)
.environment(env_config={"is_slippery": False, "map_name": "4x4"})
)
).training(num_steps_sampled_before_learning_starts=0)
obs = np.array(0)
# Test against all frameworks.
@@ -106,6 +106,10 @@ class DreamerConfig(AlgorithmConfig):
# .training()
self.gamma = 0.99
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 0
# .environment()
self.env_config = {

@@ -134,6 +138,7 @@ class DreamerConfig(AlgorithmConfig):
prefill_timesteps: Optional[int] = None,
explore_noise: Optional[float] = None,
dreamer_model: Optional[dict] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "DreamerConfig":
"""

@@ -153,6 +158,10 @@ class DreamerConfig(AlgorithmConfig):
prefill_timesteps: Prefill timesteps.
explore_noise: Exploration Gaussian noise.
dreamer_model: Custom model config.
num_steps_sampled_before_learning_starts: Number of timesteps to collect
from rollout workers before we start sampling from replay buffers for
learning. Whether we count this in agent steps or environment steps
depends on config["multiagent"]["count_steps_by"].
Returns:

@@ -189,6 +198,10 @@ class DreamerConfig(AlgorithmConfig):
self.explore_noise = explore_noise
if dreamer_model is not None:
self.dreamer_model = dreamer_model
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self

@@ -245,13 +258,20 @@ class DreamerIteration:
self.batch_size = batch_size
def __call__(self, samples):
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
batch = self.episode_buffer.sample(self.batch_size)
# if n == self.dreamer_train_iters - 1:
# batch["log_gif"] = True
fetches = self.worker.learn_on_batch(batch)
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
# Dreamer training loop.
for n in range(self.dreamer_train_iters):
print(f"sub-iteration={n}/{self.dreamer_train_iters}")
batch = self.episode_buffer.sample(self.batch_size)
fetches = self.worker.learn_on_batch(batch)
else:
fetches = {}
# Custom Logging
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]

@@ -378,19 +398,25 @@ class Dreamer(Algorithm):
fetches = {}
# Dreamer training loop.
# Run multiple sub-iterations for each training iteration.
for n in range(dreamer_train_iters):
print(f"sub-iteration={n}/{dreamer_train_iters}")
batch = self.local_replay_buffer.sample(batch_size)
fetches = local_worker.learn_on_batch(batch)
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
if fetches:
# Custom logging.
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self._postprocess_gif(gif)
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
# Dreamer training loop.
# Run multiple sub-iterations for each training iteration.
for n in range(dreamer_train_iters):
print(f"sub-iteration={n}/{dreamer_train_iters}")
batch = self.local_replay_buffer.sample(batch_size)
fetches = local_worker.learn_on_batch(batch)
if fetches:
# Custom logging.
policy_fetches = fetches[DEFAULT_POLICY_ID]["learner_stats"]
if "log_gif" in policy_fetches:
gif = policy_fetches["log_gif"]
policy_fetches["log_gif"] = self._postprocess_gif(gif)
self.local_replay_buffer.add(batch)
@@ -84,12 +84,14 @@ class MADDPGConfig(AlgorithmConfig):
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Force lockstep replay mode for MADDPG.
"replay_mode": "lockstep",
}
self.training_intensity = None
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1024 * 25
self.critic_lr = 1e-2
self.actor_lr = 1e-2
self.target_network_update_freq = 0

@@ -157,7 +159,6 @@ class MADDPGConfig(AlgorithmConfig):
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_sequence_length": 1,
}
@@ -79,17 +79,20 @@ class QMixConfig(SimpleQConfig):
self.lr = 0.0005
self.train_batch_size = 32
self.target_network_update_freq = 500
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1000
self.replay_buffer_config = {
"type": "ReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer in batches
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
# Choosing `fragments` here makes it so that the buffer stores entire
# batches, instead of sequences, episodes or timesteps.
"storage_unit": "fragments",
"learning_starts": 1000,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}

@@ -255,54 +258,53 @@ class QMix(SimpleQ):
# Store new samples in the replay buffer.
self.local_replay_buffer.add(batch)
# Sample n batches from replay buffer until the total number of timesteps
# reaches `train_batch_size`.
train_batch = sample_min_n_steps_from_buffer(
replay_buffer=self.local_replay_buffer,
min_steps=self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)
if train_batch is None:
return {}
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# TODO: Move training steps counter update outside of `train_one_step()` method.
# # Update train step counters.
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
train_results = {}
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
# Sample n batches from replay buffer until the total number of timesteps
# reaches `train_batch_size`.
train_batch = sample_min_n_steps_from_buffer(
replay_buffer=self.local_replay_buffer,
min_steps=self.config["train_batch_size"],
count_by_agent_steps=self._by_agent_steps,
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
update_priorities_in_replay_buffer(
self.local_replay_buffer, self.config, train_batch, train_results
)
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
# Update remote workers' weights and global vars after learning on local worker.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
# Update target network every `target_network_update_freq` sample steps.
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
to_update = self.workers.local_worker().get_policies_to_train()
self.workers.local_worker().foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
update_priorities_in_replay_buffer(
self.local_replay_buffer, self.config, train_batch, train_results
)
# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
# Update remote workers' weights and global vars after learning on local
# worker.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
# Return all collected metrics for the iteration.
return train_results
@@ -70,6 +70,7 @@ class TestR2D2(unittest.TestCase):
lr=5e-4,
zero_init_states=True,
replay_buffer_config={"replay_burn_in": 20},
num_steps_sampled_before_learning_starts=0,
)
.exploration(exploration_config={"epsilon_timesteps": 100000})
)
@@ -64,8 +64,6 @@ class SACConfig(AlgorithmConfig):
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,

@@ -90,6 +88,10 @@ class SACConfig(AlgorithmConfig):
# .training()
self.train_batch_size = 256
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1500
# .reporting()
self.min_time_s_per_iteration = 1

@@ -123,6 +125,7 @@ class SACConfig(AlgorithmConfig):
target_network_update_freq: Optional[int] = None,
_deterministic_loss: Optional[bool] = None,
_use_beta_distribution: Optional[bool] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "SACConfig":
"""Sets the training related configuration.

@@ -166,7 +169,6 @@ class SACConfig(AlgorithmConfig):
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_batch_size": 32,
"replay_sequence_length": 1,

@@ -267,6 +269,10 @@ class SACConfig(AlgorithmConfig):
self._deterministic_loss = _deterministic_loss
if _use_beta_distribution is not None:
self._use_beta_distribution = _use_beta_distribution
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self
@@ -46,6 +46,7 @@ class TestRNNSAC(unittest.TestCase):
"zero_init_states": True,
},
lr=5e-4,
num_steps_sampled_before_learning_starts=0,
)
)
num_iterations = 1
@@ -79,7 +79,10 @@ class TestSAC(unittest.TestCase):
.training(
n_step=3,
twin_q=True,
replay_buffer_config={"learning_starts": 0, "capacity": 40000},
replay_buffer_config={
"capacity": 40000,
},
num_steps_sampled_before_learning_starts=0,
store_buffer_in_checkpoints=True,
train_batch_size=10,
)

@@ -172,7 +175,7 @@ class TestSAC(unittest.TestCase):
_deterministic_loss=True,
q_model_config={"fcnet_hiddens": [10]},
policy_model_config={"fcnet_hiddens": [10]},
replay_buffer_config={"learning_starts": 0},
num_steps_sampled_before_learning_starts=0,
)
.rollouts(num_rollout_workers=0)
.reporting(

@@ -523,7 +526,10 @@ class TestSAC(unittest.TestCase):
config = (
sac.SACConfig()
.training(
replay_buffer_config={"learning_starts": 0, "capacity": 10},
replay_buffer_config={
"capacity": 10,
},
num_steps_sampled_before_learning_starts=0,
train_batch_size=5,
)
.rollouts(
@@ -82,7 +82,6 @@ class SimpleQConfig(AlgorithmConfig):
>>> })
>>> config = SimpleQConfig().rollouts(rollout_fragment_length=32)\
>>> .exploration(exploration_config=explore_config)\
>>> .training(learning_starts=200)
Example:
>>> from ray.rllib.algorithms.simple_q import SimpleQConfig

@@ -106,14 +105,16 @@ class SimpleQConfig(AlgorithmConfig):
# __sphinx_doc_begin__
self.target_network_update_freq = 500
self.replay_buffer_config = {
# How many steps of the model to sample before learning starts.
"learning_starts": 1000,
"type": "MultiAgentReplayBuffer",
"capacity": 50000,
# The number of contiguous environment steps to replay at once. This
# may be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 1000
self.store_buffer_in_checkpoints = False
self.lr_schedule = None
self.adam_epsilon = 1e-8

@@ -166,6 +167,7 @@ class SimpleQConfig(AlgorithmConfig):
lr_schedule: Optional[List[List[Union[int, float]]]] = None,
adam_epsilon: Optional[float] = None,
grad_clip: Optional[int] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "SimpleQConfig":
"""Sets the training related configuration.

@@ -180,7 +182,6 @@ class SimpleQConfig(AlgorithmConfig):
{
"_enable_replay_buffer_api": True,
"type": "MultiAgentReplayBuffer",
"learning_starts": 1000,
"capacity": 50000,
"replay_sequence_length": 1,
}

@@ -221,6 +222,10 @@ class SimpleQConfig(AlgorithmConfig):
timestep 0.
adam_epsilon: Adam optimizer's epsilon hyper parameter.
grad_clip: If not None, clip gradients during optimization at this value.
num_steps_sampled_before_learning_starts: Number of timesteps to collect
from rollout workers before we start sampling from replay buffers for
learning. Whether we count this in agent steps or environment steps
depends on config["multiagent"]["count_steps_by"].
Returns:
This updated AlgorithmConfig object.

@@ -249,6 +254,10 @@ class SimpleQConfig(AlgorithmConfig):
self.adam_epsilon = adam_epsilon
if grad_clip is not None:
self.grad_clip = grad_clip
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self

@@ -341,54 +350,47 @@ class SimpleQ(Algorithm):
global_vars = {
"timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
}
# Use deprecated replay() to support old replay buffers for now
train_batch = self.local_replay_buffer.sample(batch_size)
# If not yet learning, early-out here and do not perform learning, weight-
# synching, or target net updating.
if train_batch is None or len(train_batch) == 0:
self.workers.local_worker().set_global_vars(global_vars)
return {}
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update replay buffer priorities.
update_priorities_in_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
train_results,
)
# TODO: Move training steps counter update outside of `train_one_step()` method.
# # Update train step counters.
# self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
# self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
# Update target network every `target_network_update_freq` sample steps.
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
]
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update weights and global_vars - after learning on the local worker - on all
# remote workers.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
if cur_ts > self.config["num_steps_sampled_before_learning_starts"]:
# Use deprecated replay() to support old replay buffers for now
train_batch = self.local_replay_buffer.sample(batch_size)
# Learn on the training batch.
# Use simple optimizer (only for multi-agent or tf-eager; all other
# cases should use the multi-GPU optimizer, even if only using 1 GPU)
if self.config.get("simple_optimizer") is True:
train_results = train_one_step(self, train_batch)
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update replay buffer priorities.
update_priorities_in_replay_buffer(
self.local_replay_buffer,
self.config,
train_batch,
train_results,
)
last_update = self._counters[LAST_TARGET_UPDATE_TS]
if cur_ts - last_update >= self.config["target_network_update_freq"]:
with self._timers[TARGET_NET_UPDATE_TIMER]:
to_update = local_worker.get_policies_to_train()
local_worker.foreach_policy_to_train(
lambda p, pid: pid in to_update and p.update_target()
)
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
# Update weights and global_vars - after learning on the local worker -
# on all remote workers.
with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
self.workers.sync_weights(global_vars=global_vars)
else:
train_results = {}
# Return all collected metrics for the iteration.
return train_results
@@ -31,8 +31,10 @@ class TestSimpleQ(unittest.TestCase):
def test_simple_q_compilation(self):
"""Test whether SimpleQ can be built on all frameworks."""
# Run locally and with compression
config = simple_q.SimpleQConfig().rollouts(
num_rollout_workers=0, compress_observations=True
config = (
simple_q.SimpleQConfig()
.rollouts(num_rollout_workers=0, compress_observations=True)
.training(num_steps_sampled_before_learning_starts=0)
)
num_iterations = 2

@@ -57,7 +59,8 @@ class TestSimpleQ(unittest.TestCase):
model={
"fcnet_hiddens": [10],
"fcnet_activation": "linear",
}
},
num_steps_sampled_before_learning_starts=0,
)
for fw in framework_iterator(config):
@@ -90,9 +90,11 @@ class SlateQConfig(AlgorithmConfig):
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# How many steps of the model to sample before learning starts.
"learning_starts": 20000,
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 20000
# Override some of AlgorithmConfig's default values with SlateQ-specific values.
self.exploration_config = {

@@ -139,6 +141,7 @@ class SlateQConfig(AlgorithmConfig):
rmsprop_epsilon: Optional[float] = None,
grad_clip: Optional[float] = None,
n_step: Optional[int] = None,
num_steps_sampled_before_learning_starts: Optional[int] = None,
**kwargs,
) -> "SlateQConfig":
"""Sets the training related configuration.

@@ -202,6 +205,10 @@ class SlateQConfig(AlgorithmConfig):
self.grad_clip = grad_clip
if n_step is not None:
self.n_step = n_step
if num_steps_sampled_before_learning_starts is not None:
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
return self
@@ -26,7 +26,7 @@ class TestSlateQ(unittest.TestCase):
config = (
slateq.SlateQConfig()
.environment(env=InterestEvolutionRecSimEnv)
.training(replay_buffer_config={"learning_starts": 1000})
.training(num_steps_sampled_before_learning_starts=1000)
)
num_iterations = 1
@@ -71,9 +71,12 @@ class TD3Config(DDPGConfig):
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config["multiagent"]["count_steps_by"].
self.num_steps_sampled_before_learning_starts = 10000
# .exploration()
# TD3 uses Gaussian Noise by default.
@@ -231,7 +231,7 @@ class TestWorkerFailure(unittest.TestCase):
"min_sample_timesteps_per_iteration": 1000,
"min_time_s_per_iteration": 1,
"explore": False,
"learning_starts": 1000,
"num_steps_sampled_before_learning_starts": 1000,
"target_network_update_freq": 100,
"optimizer": {
"num_replay_buffer_shards": 1,
@@ -95,7 +95,7 @@ if __name__ == "__main__":
"clip_actions": True,
"twin_q": True,
"train_batch_size": 2000,
"learning_starts": 0,
"num_steps_sampled_before_learning_starts": 0,
"bc_iters": 100,
"metrics_num_episodes_for_smoothing": 5,
"evaluation_interval": 1,
@@ -116,7 +116,7 @@ if __name__ == "__main__":
assert r["model"]["foo"] == 42, result
if args.run == "DQN":
extra_config = {"replay_buffer_config": {"learning_starts": 0}}
extra_config = {"num_steps_sampled_before_learning_starts": 0}
else:
extra_config = {}
@@ -27,10 +27,10 @@ if __name__ == "__main__":
"num_workers": 2,
"num_envs_per_worker": 8,
"replay_buffer_config": {
"learning_starts": 1000,
"capacity": int(1e5),
"prioritized_replay_alpha": 0.5,
},
"num_steps_sampled_before_learning_starts": 1000,
"compress_observations": True,
"rollout_fragment_length": 20,
"train_batch_size": 512,
@@ -38,7 +38,7 @@ if __name__ == "__main__":
config["bc_iters"] = 0
config["clip_actions"] = False
config["normalize_actions"] = True
config["replay_buffer_config"]["learning_starts"] = 256
config["num_steps_sampled_before_learning_starts"] = 256
config["rollout_fragment_length"] = 1
# Test without prioritized replay
config["replay_buffer_config"]["type"] = "MultiAgentReplayBuffer"
@@ -45,7 +45,7 @@ parser.add_argument(
choices=["interest-evolution", "interest-exploration", "long-term-satisfaction"],
help=("Select the RecSim env to use."),
)
parser.add_argument("--learning-starts", type=int, default=20000)
parser.add_argument(
"--random-test-episodes",
type=int,

@@ -71,6 +71,15 @@ parser.add_argument(
"`--env-slate-size` from each timestep. These candidates will be "
"sampled by the environment's built-in document sampler model.",
)
parser.add_argument(
"--num-steps-sampled-before-learning_starts",
type=int,
default=20000,
help="Number of timesteps to collect from rollout workers before we start "
"sampling from replay buffers for learning..",
)
parser.add_argument(
"--env-slate-size",
type=int,

@@ -126,9 +135,7 @@ def main():
"num_gpus": args.num_gpus,
"num_workers": args.num_workers,
"env_config": env_config,
"replay_buffer_config": {
"learning_starts": args.learning_starts,
},
"num_steps_sampled_before_learning_starts": args.num_steps_sampled_before_learning_starts,  # noqa E501
}
# Perform a test run on the env with a random agent to see, what
@@ -24,9 +24,9 @@ param_space = {
"type": "MultiAgentReplayBuffer",
"storage_unit": "sequences",
"capacity": 100000,
"learning_starts": 1000,
"replay_burn_in": 4,
},
"num_steps_sampled_before_learning_starts": 1000,
"train_batch_size": 480,
"target_network_update_freq": 480,
"tau": 0.3,
@@ -180,7 +180,7 @@ if __name__ == "__main__":
# Example of using DQN (supports off-policy actions).
config.update(
{
"replay_buffer_config": {"learning_starts": 100},
"num_steps_sampled_before_learning_starts": 100,
"min_sample_timesteps_per_iteration": 200,
"n_step": 3,
"rollout_fragment_length": 4,
@@ -115,7 +115,7 @@ if __name__ == "__main__":
"env_config": {
"actions_are_logits": True,
},
"replay_buffer_config": {"learning_starts": 100},
"num_steps_sampled_before_learning_starts": 100,
"multiagent": {
"policies": {
"pol1": PolicySpec(
@@ -83,9 +83,7 @@ class MyAlgo(Algorithm):
# Call super's `setup` to create rollout workers.
super().setup(config)
# Create local replay buffer.
self.local_replay_buffer = MultiAgentReplayBuffer(
num_shards=1, learning_starts=1000, capacity=50000
)
self.local_replay_buffer = MultiAgentReplayBuffer(num_shards=1, capacity=50000)
@override(Algorithm)
def training_step(self) -> ResultDict:

@@ -93,6 +91,7 @@ class MyAlgo(Algorithm):
# into replay buffer.
ppo_batches = []
num_env_steps = 0
# PPO batch size fixed at 200.
while num_env_steps < 200:
ma_batches = synchronous_parallel_sample(

@@ -112,8 +111,9 @@ class MyAlgo(Algorithm):
# DQN sub-flow.
dqn_train_results = {}
dqn_train_batch = self.local_replay_buffer.sample(num_items=64)
if dqn_train_batch is not None:
if self._counters[NUM_ENV_STEPS_SAMPLED] > 1000:
dqn_train_batch = self.local_replay_buffer.sample(num_items=64)
dqn_train_results = train_one_step(self, dqn_train_batch, ["dqn_policy"])
self._counters["agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
print(
@@ -28,6 +28,7 @@ CONFIGS = {
"optimizer": {
"num_replay_buffer_shards": 1,
},
"num_steps_sampled_before_learning_starts": 0,
},
"ARS": {
"explore": False,

@@ -39,9 +40,11 @@ CONFIGS = {
"DDPG": {
"explore": False,
"min_sample_timesteps_per_iteration": 100,
"num_steps_sampled_before_learning_starts": 0,
},
"DQN": {
"explore": False,
"num_steps_sampled_before_learning_starts": 0,
},
"ES": {
"explore": False,

@@ -59,9 +62,11 @@ CONFIGS = {
},
"SimpleQ": {
"explore": False,
"num_steps_sampled_before_learning_starts": 0,
},
"SAC": {
"explore": False,
"num_steps_sampled_before_learning_starts": 0,
},
}
@@ -22,6 +22,9 @@ def test_custom_resource(algorithm):
"custom_resources_per_worker": {"custom_resource": 0.01},
}
if algorithm == "APEX":
config["num_steps_sampled_before_learning_starts"] = 0
stop = {"training_iteration": 1}
tune.run(
@@ -44,12 +44,19 @@ class TestEagerSupportPG(unittest.TestCase):
def test_simple_q(self):
check_support(
"SimpleQ",
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
{
"num_workers": 0,
"num_steps_sampled_before_learning_starts": 0,
},
)
def test_dqn(self):
check_support(
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
"DQN",
{
"num_workers": 0,
"num_steps_sampled_before_learning_starts": 0,
},
)
def test_ddpg(self):

@@ -91,12 +98,19 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
def test_simple_q(self):
check_support(
"SimpleQ",
{"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}},
{
"num_workers": 0,
"replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
},
)
def test_dqn(self):
check_support(
"DQN", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
"DQN",
{
"num_workers": 0,
"num_steps_sampled_before_learning_starts": 0,
},
)
def test_ddpg(self):

@@ -113,7 +127,7 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
"APEX",
{
"num_workers": 2,
"replay_buffer_config": {"learning_starts": 0},
"replay_buffer_config": {"num_steps_sampled_before_learning_starts": 0},
"num_gpus": 0,
"min_time_s_per_iteration": 1,
"min_sample_timesteps_per_iteration": 100,

@@ -125,7 +139,11 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
def test_sac(self):
check_support(
"SAC", {"num_workers": 0, "replay_buffer_config": {"learning_starts": 0}}
"SAC",
{
"num_workers": 0,
"num_steps_sampled_before_learning_starts": 0,
},
)
@@ -208,20 +208,16 @@ class TestExecution(unittest.TestCase):
def test_store_to_replay_local(self):
buf = MultiAgentReplayBuffer(
num_shards=1,
learning_starts=200,
capacity=1000,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,
prioritized_replay_eps=0.0001,
)
assert len(buf.sample(100)) == 0

workers = make_workers(0)
a = ParallelRollouts(workers, mode="bulk_sync")
b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

next(b)
assert len(buf.sample(100)) == 0 # learning hasn't started yet
next(b)
assert buf.sample(100).count == 100

@@ -232,7 +228,6 @@ class TestExecution(unittest.TestCase):
ReplayActor = ray.remote(num_cpus=0)(MultiAgentReplayBuffer)
actor = ReplayActor.remote(
num_shards=1,
learning_starts=200,
capacity=1000,
prioritized_replay_alpha=0.6,
prioritized_replay_beta=0.4,

@@ -244,8 +239,6 @@ class TestExecution(unittest.TestCase):
a = ParallelRollouts(workers, mode="bulk_sync")
b = a.for_each(StoreToReplayBuffer(actors=[actor]))

next(b)
assert len(ray.get(actor.sample.remote(100))) == 0 # learning hasn't started
next(b)
assert ray.get(actor.sample.remote(100)).count == 100
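In the execution tests above, the `learning_starts=200` constructor argument and the "learning hasn't started yet" assertions disappear because the commit moves that gate out of the buffers and into the algorithm's `training_step()`. A self-contained sketch of the resulting split, using assumed toy names (ToyReplayBuffer and training_step below are illustrative, not RLlib's actual classes):

import random
from typing import Any, Dict, List, Optional

class ToyReplayBuffer:
    """Toy buffer: stores items and samples whenever it holds data (no gating here)."""

    def __init__(self, capacity: int = 1000) -> None:
        self.capacity = capacity
        self.storage: List[Any] = []

    def add(self, item: Any) -> None:
        self.storage.append(item)
        if len(self.storage) > self.capacity:
            self.storage.pop(0)

    def sample(self, n: int) -> List[Any]:
        if not self.storage:
            return []
        return [random.choice(self.storage) for _ in range(n)]

def training_step(
    buffer: ToyReplayBuffer, num_env_steps_sampled: int, config: Dict[str, Any]
) -> Optional[List[Any]]:
    """The learning-starts gate now lives on the algorithm side, not inside the buffer."""
    if num_env_steps_sampled < config["num_steps_sampled_before_learning_starts"]:
        return None  # keep collecting experience; no update yet
    return buffer.sample(config["train_batch_size"])

With this split, a buffer can be sampled as soon as it holds data; whether the algorithm actually trains on those samples is decided by the top-level config key.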
@@ -97,8 +97,8 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"num_gpus": 0,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
"num_steps_sampled_before_learning_starts": 10,
"min_time_s_per_iteration": 1,
"target_network_update_freq": 100,
"optimizer": {

@@ -115,8 +115,8 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"min_sample_timesteps_per_iteration": 100,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
"num_steps_sampled_before_learning_starts": 10,
"num_gpus": 0,
"min_time_s_per_iteration": 1,
"target_network_update_freq": 100,

@@ -131,8 +131,8 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"min_sample_timesteps_per_iteration": 1,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 500,
},
"num_steps_sampled_before_learning_starts": 10,
"use_state_preprocessor": True,
},
)
@@ -21,7 +21,7 @@ cartpole-apex-dqn-training-itr:
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 20000
learning_starts: 1000
num_steps_sampled_before_learning_starts: 1000

num_gpus: 0

@@ -129,8 +129,8 @@ atari-basic-dqn:
noisy: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
num_steps_sampled_before_learning_starts: 20000
n_step: 1
target_network_update_freq: 8000
lr: .0000625

@@ -28,7 +28,7 @@ halfcheetah_bc:
rollout_fragment_length: 1
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
num_steps_sampled_before_learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_iteration: 1000

@@ -30,7 +30,7 @@ halfcheetah_cql:
rollout_fragment_length: 1
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 256
num_steps_sampled_before_learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_iteration: 1000

@@ -28,7 +28,7 @@ hopper_bc:
rollout_fragment_length: 1
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
num_steps_sampled_before_learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_iteration: 1000

@@ -28,7 +28,7 @@ hopper_cql:
rollout_fragment_length: 1
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10
num_steps_sampled_before_learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_iteration: 1000

@@ -37,6 +37,8 @@ halfcheetah-ddpg:
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false

num_steps_sampled_before_learning_starts: 500
clip_rewards: False

# === Optimization ===

@@ -45,7 +47,6 @@ halfcheetah-ddpg:
use_huber: false
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 64
@@ -29,13 +29,13 @@ ddpg-halfcheetahbulletenv-v0:
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false
num_steps_sampled_before_learning_starts: 500
clip_rewards: false
actor_lr: 0.001
critic_lr: 0.001
use_huber: true
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 48
num_workers: 0

@@ -32,7 +32,7 @@ ddpg-hopperbulletenv-v0:
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false
learning_starts: 500
num_steps_sampled_before_learning_starts: 500
clip_rewards: False
actor_lr: 0.001
critic_lr: 0.001

@@ -38,6 +38,7 @@ mountaincarcontinuous-ddpg:
prioritized_replay_beta: 0.4
prioritized_replay_eps: 0.000001
worker_side_prioritization: false
num_steps_sampled_before_learning_starts: 1000
clip_rewards: False

# === Optimization ===

@@ -46,7 +47,6 @@ mountaincarcontinuous-ddpg:
use_huber: False
huber_threshold: 1.0
l2_reg: 0.00001
learning_starts: 1000
rollout_fragment_length: 1
train_batch_size: 64

@@ -20,7 +20,7 @@ pendulum-ddpg-fake-gpus:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: false
learning_starts: 500
num_steps_sampled_before_learning_starts: 500
clip_rewards: false
use_huber: true
train_batch_size: 64

@@ -37,6 +37,7 @@ pendulum-ddpg:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000
worker_side_prioritization: false
num_steps_sampled_before_learning_starts: 500
clip_rewards: False

# === Optimization ===

@@ -45,7 +46,6 @@ pendulum-ddpg:
use_huber: True
huber_threshold: 1.0
l2_reg: 0.000001
learning_starts: 500
rollout_fragment_length: 1
train_batch_size: 64
@@ -14,7 +14,7 @@ atari-dist-dqn:
replay_buffer_config:
type: MultiAgentReplayBuffer
capacity: 1000000
learning_starts: 20000
num_steps_sampled_before_learning_starts: 20000
n_step: 1
target_network_update_freq: 8000
lr: .0000625

@@ -17,8 +17,8 @@ atari-basic-dqn:
noisy: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
num_steps_sampled_before_learning_starts: 20000
n_step: 1
target_network_update_freq: 8000
lr: .0000625

@@ -17,8 +17,8 @@ dueling-ddqn:
noisy: false
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 20000
capacity: 1000000
num_steps_sampled_before_learning_starts: 20000
n_step: 1
target_network_update_freq: 8000
lr: .0000625

@@ -14,7 +14,7 @@ pong-deterministic-dqn:
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 50000
learning_starts: 10000
num_steps_sampled_before_learning_starts: 10000
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:

@@ -18,8 +18,8 @@ pong-deterministic-rainbow:
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
prioritized_replay_alpha: 0.5
learning_starts: 10000
capacity: 50000
num_steps_sampled_before_learning_starts: 10000
n_step: 3
gpu: True
model:

@@ -12,7 +12,7 @@ two-step-game-maddpg:
env_config:
actions_are_logits: true

learning_starts: 200
num_steps_sampled_before_learning_starts: 200

multiagent:
policies:
@@ -34,11 +34,11 @@ atari-sac-tf-and-torch:
type: MultiAgentPrioritizedReplayBuffer
capacity: 1000000
# How many steps of the model to sample before learning starts.
learning_starts: 100000
# If True prioritized replay buffer will be used.
prioritized_replay_alpha: 0.6
prioritized_replay_beta: 0.4
prioritized_replay_eps: 1e-6
num_steps_sampled_before_learning_starts: 100000
train_batch_size: 64
min_sample_timesteps_per_iteration: 4
# Paper uses 20k random timesteps, which is not exactly the same, but

@@ -14,7 +14,7 @@ cartpole-sac:
n_step: 3
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
num_steps_sampled_before_learning_starts: 256
initial_alpha: 0.2
clip_actions: false
min_sample_timesteps_per_iteration: 1000

@@ -24,7 +24,7 @@ halfcheetah-pybullet-sac:
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000
num_steps_sampled_before_learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -25,7 +25,7 @@ halfcheetah_sac:
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000
num_steps_sampled_before_learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -33,7 +33,7 @@ mspacman-sac-tf:
# seems to work nevertheless.
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 20000
num_steps_sampled_before_learning_starts: 20000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -26,7 +26,7 @@ pendulum-sac-fake-gpus:
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
num_steps_sampled_before_learning_starts: 256
num_workers: 0
metrics_smoothing_episodes: 5

@@ -27,7 +27,7 @@ pendulum-sac:
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
num_steps_sampled_before_learning_starts: 256
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003

@@ -35,7 +35,7 @@ transformed-actions-pendulum-sac-dummy-torch:
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
num_steps_sampled_before_learning_starts: 256
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003
@@ -35,12 +35,12 @@ interest-evolution-recsim-env-slateq:

replay_buffer_config:
capacity: 100000
num_steps_sampled_before_learning_starts: 10000

# Double learning rate and batch size.
lr: 0.002
train_batch_size: 64

learning_starts: 10000
target_network_update_freq: 3200

metrics_num_episodes_for_smoothing: 200

@@ -31,10 +31,10 @@ interest-evolution-recsim-env-slateq:

replay_buffer_config:
capacity: 100000
num_steps_sampled_before_learning_starts: 10000

lr: 0.001

learning_starts: 10000
target_network_update_freq: 3200

metrics_num_episodes_for_smoothing: 200

@@ -16,8 +16,7 @@ invertedpendulum-td3:
critic_hiddens: [32, 32]

# === Exploration ===
replay_buffer_config:
learning_starts: 1000
num_steps_sampled_before_learning_starts: 1000
exploration_config:
random_timesteps: 1000

@@ -22,7 +22,7 @@ mujoco-td3:
random_timesteps: 10000
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 10000
num_steps_sampled_before_learning_starts: 10000
# === Evaluation ===
evaluation_interval: 10
evaluation_duration: 10

@@ -12,7 +12,7 @@ pendulum-td3-fake-gpus:

replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 5000
num_steps_sampled_before_learning_starts: 5000
exploration_config:
random_timesteps: 5000
evaluation_interval: 10

@@ -14,6 +14,6 @@ pendulum-td3:
# === Exploration ===
replay_buffer_config:
type: MultiAgentReplayBuffer
learning_starts: 5000
num_steps_sampled_before_learning_starts: 5000
exploration_config:
random_timesteps: 5000
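The tuned-example YAML hunks above only move the key; the rest of each config is unchanged. For reference, roughly how the updated atari-basic-dqn example maps onto a Python dict config passed to Tune (an untested sketch; the env name and stopping criterion are illustrative, the numeric values mirror the YAML above):

from ray import tune

tune.run(
    "DQN",
    stop={"training_iteration": 1},
    config={
        "env": "BreakoutNoFrameskip-v4",
        "num_steps_sampled_before_learning_starts": 20000,  # new top-level key
        "replay_buffer_config": {
            "type": "MultiAgentReplayBuffer",
            "capacity": 1000000,
            # no "learning_starts" in here anymore
        },
        "n_step": 1,
        "target_network_update_freq": 8000,
        "lr": 0.0000625,
    },
)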
@@ -77,7 +77,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
capacity: int = 10000,
storage_unit: str = "timesteps",
num_shards: int = 1,
learning_starts: int = 1000,
replay_mode: str = "independent",
replay_sequence_override: bool = True,
replay_sequence_length: int = 1,

@@ -99,9 +98,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
are stored in episodes, replay_sequence_length is ignored.
num_shards: The number of buffer shards that exist in total
(including this one).
learning_starts: Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
return None).
replay_mode: One of "independent" or "lockstep". Determines,
whether batches are sampled independently or to an equal
amount.

@@ -152,7 +148,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
capacity=capacity,
storage_unit=storage_unit,
num_shards=num_shards,
learning_starts=learning_starts,
replay_mode=replay_mode,
replay_sequence_override=replay_sequence_override,
replay_sequence_length=replay_sequence_length,

@@ -266,9 +261,6 @@ class MultiAgentMixInReplayBuffer(MultiAgentPrioritizedReplayBuffer):
# Merge kwargs, overwriting standard call arguments
kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

if self._num_added < self.replay_starts:
return MultiAgentBatch({}, 0)

def mix_batches(_policy_id):
"""Mixes old with new samples.

@@ -40,7 +40,6 @@ class MultiAgentPrioritizedReplayBuffer(
capacity: int = 10000,
storage_unit: str = "timesteps",
num_shards: int = 1,
learning_starts: int = 1000,
replay_mode: str = "independent",
replay_sequence_override: bool = True,
replay_sequence_length: int = 1,

@@ -63,9 +62,6 @@ class MultiAgentPrioritizedReplayBuffer(
ignored.
num_shards: The number of buffer shards that exist in total
(including this one).
learning_starts: Number of timesteps after which a call to
`replay()` will yield samples (before that, `replay()` will
return None).
replay_mode: One of "independent" or "lockstep". Determines,
whether batches are sampled independently or to an equal
amount.

@@ -136,7 +132,6 @@ class MultiAgentPrioritizedReplayBuffer(
capacity=shard_capacity,
storage_unit=storage_unit,
replay_sequence_override=replay_sequence_override,
learning_starts=learning_starts,
replay_mode=replay_mode,
replay_sequence_length=replay_sequence_length,
replay_burn_in=replay_burn_in,

@@ -66,7 +66,6 @@ class MultiAgentReplayBuffer(ReplayBuffer):
capacity: int = 10000,
storage_unit: str = "timesteps",
num_shards: int = 1,
learning_starts: int = 1000,
replay_mode: str = "independent",
replay_sequence_override: bool = True,
replay_sequence_length: int = 1,

@@ -84,9 +83,6 @@ class MultiAgentReplayBuffer(ReplayBuffer):
are stored in episodes, replay_sequence_length is ignored.
num_shards: The number of buffer shards that exist in total
(including this one).
learning_starts: Number of timesteps after which a call to
`sample()` will yield samples (before that, `sample()` will
return None).
replay_mode: One of "independent" or "lockstep". Determines,
whether batches are sampled independently or to an equal
amount.

@@ -121,7 +117,6 @@ class MultiAgentReplayBuffer(ReplayBuffer):
else:
self.underlying_buffer_call_args = {}
self.replay_sequence_override = replay_sequence_override
self.replay_starts = learning_starts // num_shards
self.replay_mode = replay_mode
self.replay_sequence_length = replay_sequence_length
self.replay_burn_in = replay_burn_in

@@ -318,8 +313,6 @@ class MultiAgentReplayBuffer(ReplayBuffer):
# Merge kwargs, overwriting standard call arguments
kwargs = merge_dicts_with_warning(self.underlying_buffer_call_args, kwargs)

if self._num_added < self.replay_starts:
return MultiAgentBatch({}, 0)
with self.replay_timer:
# Lockstep mode: Sample from all policies at the same time an
# equal amount of steps.
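With the `learning_starts` constructor argument, its docstring entry, and the `self._num_added < self.replay_starts` early return removed above, the multi-agent buffers no longer withhold samples on their own. A hedged usage sketch against the buffer API visible in this diff (the import path and the SampleBatch column names are assumptions, not verified here):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer,
)

# No learning_starts argument anymore; gating is the algorithm's job.
buffer = MultiAgentReplayBuffer(
    capacity=100,
    storage_unit="timesteps",
    replay_mode="independent",
    num_shards=1,
)

# Add a tiny single-timestep batch (columns assumed to be the usual SampleBatch fields).
buffer.add(
    SampleBatch(
        {
            "obs": [0],
            "actions": [0],
            "rewards": [1.0],
            "new_obs": [1],
            "dones": [True],
        }
    )
)

# The buffer returns samples as soon as it holds data; whether training actually
# happens is decided by num_steps_sampled_before_learning_starts in training_step().
print(buffer.sample(2))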
@@ -51,7 +51,6 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
capacity=self.capacity,
storage_unit="episodes",
replay_ratio=0.5,
learning_starts=0,
)

# If we insert and replay n times, expect roughly return batches of

@@ -75,7 +74,6 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
capacity=100,
storage_unit="sequences",
replay_ratio=0.5,
learning_starts=0,
replay_sequence_length=2,
replay_sequence_override=True,
)

@@ -99,7 +97,6 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):
capacity=self.capacity,
storage_unit="timesteps",
replay_ratio=0.333,
learning_starts=0,
)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.sample(10)

@@ -132,7 +129,8 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):

# 90% replay ratio.
buffer = MultiAgentMixInReplayBuffer(
capacity=self.capacity, replay_ratio=0.9, learning_starts=0
capacity=self.capacity,
replay_ratio=0.9,
)

# If we insert and replay n times, expect roughly return batches of

@@ -148,7 +146,8 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):

# 0% replay ratio -> Only new samples.
buffer = MultiAgentMixInReplayBuffer(
capacity=self.capacity, replay_ratio=0.0, learning_starts=0
capacity=self.capacity,
replay_ratio=0.0,
)
# Add a new batch.
batch = self._generate_single_timesteps()

@@ -175,7 +174,8 @@ class TestMixInMultiAgentReplayBuffer(unittest.TestCase):

# 100% replay ratio -> Only new samples.
buffer = MultiAgentMixInReplayBuffer(
capacity=self.capacity, replay_ratio=1.0, learning_starts=0
capacity=self.capacity,
replay_ratio=1.0,
)
# Expect exactly 0 samples to be returned (buffer empty).
sample = buffer.sample(1)

@@ -79,7 +79,9 @@ class TestMultiAgentPrioritizedReplayBuffer(unittest.TestCase):

# Test lockstep mode with different policy ids using MultiAgentBatches
buffer = MultiAgentPrioritizedReplayBuffer(
capacity=10, replay_mode="independent", learning_starts=0, num_shards=1
capacity=10,
replay_mode="independent",
num_shards=1,
)

self._add_multi_agent_batch_to_buffer(buffer, num_policies=1, num_batches=1)

@@ -99,7 +101,6 @@ class TestMultiAgentPrioritizedReplayBuffer(unittest.TestCase):
buffer = MultiAgentPrioritizedReplayBuffer(
capacity=buffer_size,
replay_mode="lockstep",
learning_starts=0,
num_shards=1,
)

@@ -131,7 +132,6 @@ class TestMultiAgentPrioritizedReplayBuffer(unittest.TestCase):
buffer = MultiAgentPrioritizedReplayBuffer(
capacity=buffer_size,
replay_mode="independent",
learning_starts=0,
num_shards=1,
)

@@ -173,7 +173,6 @@ class TestMultiAgentPrioritizedReplayBuffer(unittest.TestCase):
prioritized_replay_beta=self.beta,
replay_mode="independent",
replay_sequence_length=2,
learning_starts=0,
num_shards=1,
)

@@ -222,7 +221,6 @@ class TestMultiAgentPrioritizedReplayBuffer(unittest.TestCase):
prioritized_replay_alpha=self.alpha,
prioritized_replay_beta=self.beta,
replay_mode="independent",
learning_starts=0,
num_shards=1,
)
new_buffer.set_state(state)
@@ -85,7 +85,9 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):

# Test lockstep mode with different policy ids using MultiAgentBatches
buffer = MultiAgentReplayBuffer(
capacity=10, replay_mode="independent", learning_starts=0, num_shards=1
capacity=10,
replay_mode="independent",
num_shards=1,
)

self._add_multi_agent_batch_to_buffer(buffer, num_policies=1, num_batches=1)

@@ -105,7 +107,6 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="lockstep",
learning_starts=0,
num_shards=1,
)

@@ -142,7 +143,6 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
replay_mode="independent",
storage_unit="sequences",
replay_sequence_length=2,
learning_starts=0,
num_shards=1,
)

@@ -188,7 +188,7 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="independent",
learning_starts=0,
num_steps_sampled_before_learning_starts=0,
num_shards=1,
)

@@ -234,7 +234,6 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="lockstep",
learning_starts=0,
num_shards=1,
underlying_buffer_config=replay_buffer_config,
)

@@ -280,7 +279,6 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="independent",
learning_starts=0,
num_shards=1,
underlying_buffer_config=prioritized_replay_buffer_config,
)

@@ -302,7 +300,6 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="independent",
learning_starts=0,
num_shards=1,
)

@@ -315,7 +312,7 @@ class TestMultiAgentReplayBuffer(unittest.TestCase):
another_buffer = MultiAgentReplayBuffer(
capacity=buffer_size,
replay_mode="independent",
learning_starts=0,
num_steps_sampled_before_learning_starts=0,
num_shards=1,
)
@@ -238,7 +238,6 @@ def validate_buffer_config(config: dict) -> None:
"prioritized_replay_eps",
"no_local_replay_buffer",
"replay_zero_init_states",
"learning_starts",
"replay_buffer_shards_colocated_with_driver",
]
for k in keys_with_deprecated_positions:

@@ -262,6 +261,19 @@ def validate_buffer_config(config: dict) -> None:
)
config["replay_buffer_config"]["replay_mode"] = replay_mode

learning_starts = config.get(
"learning_starts",
config.get("replay_buffer_config", {}).get("learning_starts", DEPRECATED_VALUE),
)
if learning_starts != DEPRECATED_VALUE:
deprecation_warning(
old="config['learning_starts'] or"
"config['replay_buffer_config']['learning_starts']",
help="config['num_steps_sampled_before_learning_starts']",
error=False,
)
config["num_steps_sampled_before_learning_starts"] = learning_starts

# Can't use DEPRECATED_VALUE here because this is also a deliberate
# value set for some algorithms
# TODO: (Artur): Compare to DEPRECATED_VALUE on deprecation
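The `validate_buffer_config()` hunk above adds a backward-compatibility shim: an old-style `learning_starts` (top-level or inside `replay_buffer_config`) triggers a deprecation warning and is copied into the new key. A standalone sketch of that translation logic (simplified; the sentinel and the warning helper below stand in for RLlib's internals):

import warnings

DEPRECATED_VALUE = -1  # sentinel standing in for RLlib's DEPRECATED_VALUE

def translate_learning_starts(config: dict) -> dict:
    """Map old 'learning_starts' locations onto the new top-level key."""
    learning_starts = config.get(
        "learning_starts",
        config.get("replay_buffer_config", {}).get("learning_starts", DEPRECATED_VALUE),
    )
    if learning_starts != DEPRECATED_VALUE:
        warnings.warn(
            "'learning_starts' is deprecated; use "
            "'num_steps_sampled_before_learning_starts' instead.",
            DeprecationWarning,
        )
        config["num_steps_sampled_before_learning_starts"] = learning_starts
    return config

# Example: an old-style config gets translated in place.
cfg = translate_learning_starts({"replay_buffer_config": {"learning_starts": 500}})
assert cfg["num_steps_sampled_before_learning_starts"] == 500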