[RLlib] Multi-GPU for tf-DQN/PG/A2C. (#13393)

Sven Mika 2021-03-08 15:41:27 +01:00 committed by GitHub
parent b0bf44b154
commit 732197e23a
41 changed files with 385 additions and 134 deletions

View file

@ -8,30 +8,30 @@ RLlib Algorithms
Available Algorithms - Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
=================== ========== ======================= ================== =========== =============================================================
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support
=================== ========== ======================= ================== =========== =============================================================
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`ARS`_ tf + torch **Yes** **Yes** No
`BC`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
`ES`_ tf + torch **Yes** **Yes** No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes**
`APEX-DDPG`_ tf + torch No **Yes** **Yes**
`Dreamer`_ torch No **Yes** No `+RNN`_
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes**
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes**
`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`MAML`_ tf + torch No **Yes** No
`MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_
`MBMPO`_ torch No **Yes** No
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
`R2D2`_ tf + torch **Yes** `+parametric`_ No **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+autoreg`_
`SAC`_ tf + torch **Yes** **Yes** **Yes**
`SlateQ`_ torch **Yes** No No
`LinUCB`_, `LinTS`_ torch **Yes** `+parametric`_ No **Yes**
`AlphaZero`_ torch **Yes** `+parametric`_ No No
=================== ========== ======================= ================== =========== =============================================================
=================== ========== ======================= ================== =========== ============================================================= =========
Algorithm Frameworks Discrete Actions Continuous Actions Multi-Agent Model Support Multi-GPU
=================== ========== ======================= ================== =========== ============================================================= =========
`A2C, A3C`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_ tf (A2C)
`ARS`_ tf + torch **Yes** **Yes** No No
`BC`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ No
`ES`_ tf + torch **Yes** **Yes** No No
`DDPG`_, `TD3`_ tf + torch No **Yes** **Yes** No
`APEX-DDPG`_ tf + torch No **Yes** **Yes** No
`Dreamer`_ torch No **Yes** No `+RNN`_ No
`DQN`_, `Rainbow`_ tf + torch **Yes** `+parametric`_ No **Yes** tf
`APEX-DQN`_ tf + torch **Yes** `+parametric`_ No **Yes** No
`IMPALA`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_ tf
`MAML`_ tf + torch No **Yes** No No
`MARWIL`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_ No
`MBMPO`_ torch No **Yes** No No
`PG`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_ tf
`PPO`_, `APPO`_ tf + torch **Yes** `+parametric`_ **Yes** **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_ tf
`R2D2`_ tf + torch **Yes** `+parametric`_ No **Yes** `+RNN`_, `+LSTM auto-wrapping`_, `+autoreg`_ No
`SAC`_ tf + torch **Yes** **Yes** **Yes** No
`SlateQ`_ torch **Yes** No No No
`LinUCB`_, `LinTS`_ torch **Yes** `+parametric`_ No **Yes** No
`AlphaZero`_ torch **Yes** `+parametric`_ No No No
=================== ========== ======================= ================== =========== ============================================================= =========
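The new "Multi-GPU" column documents the tf-based multi-GPU support this commit adds for A2C, DQN, and PG (PPO and IMPALA already supported it). As a minimal sketch of how a user would exercise the new path (not part of this diff; `_fake_gpus` lets it run on a CPU-only machine):

    import ray
    from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG

    config = DEFAULT_CONFIG.copy()
    config["framework"] = "tf"   # multi-GPU is currently tf-only for DQN
    config["num_gpus"] = 2       # two trainer GPUs (or two fake-GPU towers)
    config["_fake_gpus"] = True  # simulate the GPUs on CPU (debugging only)

    ray.init()
    trainer = DQNTrainer(config=config, env="CartPole-v0")
    print(trainer.train()["episode_reward_mean"])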
Multi-Agent only Methods

View file

@ -173,7 +173,6 @@ py_test(
"tuned_examples/dqn/cartpole-dqn.yaml",
"tuned_examples/dqn/cartpole-dqn-softq.yaml",
"tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
"tuned_examples/dqn/cartpole-r2d2.yaml",
],
args = ["--yaml-dir=tuned_examples/dqn"]
)
@ -188,7 +187,6 @@ py_test(
"tuned_examples/dqn/cartpole-dqn.yaml",
"tuned_examples/dqn/cartpole-dqn-softq.yaml",
"tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
"tuned_examples/dqn/cartpole-r2d2.yaml",
],
args = ["--yaml-dir=tuned_examples/dqn", "--framework=torch"]
)
@ -603,7 +601,7 @@ py_test(
py_test(
name = "test_pg",
tags = ["agents_dir"],
size = "small",
size = "medium",
srcs = ["agents/pg/tests/test_pg.py"]
)

View file

@ -7,7 +7,7 @@ from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
ApplyGradients, TrainOneStep
ApplyGradients, TrainTFMultiGPU, TrainOneStep
from ray.rllib.utils import merge_dicts
A2C_DEFAULT_CONFIG = merge_dicts(
@ -47,11 +47,23 @@ def execution_plan(workers, config):
.for_each(ApplyGradients(workers)))
else:
# In normal mode, we execute one SGD step per each train batch.
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
shuffle_sequences=True,
_fake_gpus=config["_fake_gpus"],
framework=config.get("framework"))
train_op = rollouts.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"][
"count_steps_by"])).for_each(TrainOneStep(workers))
"count_steps_by"])).for_each(train_step_op)
return StandardMetricsReporting(train_op, workers, config)
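Note that passing `sgd_minibatch_size=config["train_batch_size"]` and `num_sgd_iter=1` preserves A2C's semantics of exactly one SGD step per train batch; `TrainTFMultiGPU` merely splits that batch across the GPU towers. A quick sketch of the resulting per-tower batch size (the 500 is an assumed `train_batch_size`, and the rounding mirrors the op's divisibility handling):

    train_batch_size = 500   # assumed value, for illustration only
    num_gpus = 2

    # Make the total batch divisible by the number of towers, then split it.
    total = (train_batch_size // num_gpus) * num_gpus
    per_device_batch_size = total // num_gpus
    print(total, per_device_batch_size)  # -> 500 250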

View file

@ -1,3 +1,4 @@
import copy
import unittest
import ray
@ -33,6 +34,30 @@ class TestA2C(unittest.TestCase):
check_compute_single_action(trainer)
trainer.stop()
def test_a2c_fake_multi_gpu_learning(self):
"""Test whether A2CTrainer can learn CartPole w/ faked multi-GPU."""
config = copy.deepcopy(a3c.a2c.A2C_DEFAULT_CONFIG)
# Fake GPU setup.
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"
# Mimic tuned_example for A2C CartPole.
config["lr"] = 0.001
trainer = a3c.A2CTrainer(config=config, env="CartPole-v0")
num_iterations = 100
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
if results["episode_reward_mean"] > 100.0:
learnt = True
break
assert learnt, "A2C multi-GPU (with fake-GPUs) did not learn CartPole!"
trainer.stop()
def test_a2c_exec_impl(ray_start_regular):
config = {"min_iter_time_s": 0}
for _ in framework_iterator(config):

View file

@ -46,6 +46,8 @@ CQL_DEFAULT_CONFIG = merge_dicts(
def validate_config(config: TrainerConfigDict):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for CQL!")
if config["framework"] == "tf":
raise ValueError("Tensorflow CQL not implemented yet!")

View file

@ -151,6 +151,8 @@ DEFAULT_CONFIG = with_common_config({
def validate_config(config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")
if config["model"]["custom_model"]:
logger.warning(
"Setting use_state_preprocessor=True since a custom model "

View file

@ -123,6 +123,8 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
model_out_tp1, _ = model(input_dict_next, [], None)
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
policy.target_q_func_vars = policy.target_model.variables()
# Policy network evaluation.
policy_t = model.get_policy_output(model_out_t)
policy_tp1 = \

View file

@ -17,8 +17,8 @@ import copy
from typing import Tuple
import ray
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn import DQNTrainer, calculate_rr_weights
from ray.rllib.agents.dqn.dqn import calculate_rr_weights, \
DEFAULT_CONFIG as DQN_CONFIG, DQNTrainer, validate_config
from ray.rllib.agents.dqn.learner_thread import LearnerThread
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
@ -195,7 +195,14 @@ def apex_execution_plan(workers: WorkerSet,
selected_workers=selected_workers).for_each(add_apex_metrics)
def apex_validate_config(config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
validate_config(config)
ApexTrainer = DQNTrainer.with_updates(
name="APEX",
default_config=APEX_DEFAULT_CONFIG,
validate_config=apex_validate_config,
execution_plan=apex_execution_plan)

View file

@ -22,7 +22,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
TrainTFMultiGPU
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
@ -170,6 +171,13 @@ def validate_config(config: TrainerConfigDict) -> None:
raise ValueError("Prioritized replay is not supported when "
"replay_sequence_length > 1.")
# Multi-agent mode and multi-GPU optimizer.
if config["multiagent"]["policies"] and not config["simple_optimizer"]:
logger.info(
"In multi-agent mode, policies will be optimized sequentially "
"by the multi-GPU optimizer. Consider setting "
"simple_optimizer=True if this doesn't work for you.")
def execution_plan(workers: WorkerSet,
config: TrainerConfigDict) -> LocalIterator[dict]:
@ -231,9 +239,22 @@ def execution_plan(workers: WorkerSet,
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step, and then we decide whether to update the target network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
shuffle_sequences=True,
_fake_gpus=config["_fake_gpus"],
framework=config.get("framework"))
replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(lambda x: post_fn(x, workers, config)) \
.for_each(TrainOneStep(workers)) \
.for_each(train_step_op) \
.for_each(update_prio) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))
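The `post_fn` hook above comes from the optional `before_learn_on_batch` config entry and is called with the replayed batch, the worker set, and the config before the (now possibly multi-GPU) train step. A hypothetical callback, purely for illustration:

    # Illustrative only: scale the rewards of every policy batch before training.
    def scale_rewards(multi_agent_batch, workers, config):
        for policy_batch in multi_agent_batch.policy_batches.values():
            policy_batch["rewards"] = policy_batch["rewards"] * 0.1
        return multi_agent_batch

    config["before_learn_on_batch"] = scale_rewards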

View file

@ -164,7 +164,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
else:
num_outputs = action_space.n
policy.q_model = ModelCatalog.get_model_v2(
q_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
@ -206,7 +206,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
getattr(policy, "exploration", None), ParameterNoise)
or config["exploration_config"]["type"] == "ParameterNoise")
return policy.q_model
return q_model
def get_distribution_inputs_and_class(policy: Policy,
@ -240,7 +240,7 @@ def build_q_losses(policy: Policy, model, _,
# q network evaluation
q_t, q_logits_t, q_dist_t, _ = compute_q_values(
policy,
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
model, {"obs": train_batch[SampleBatch.CUR_OBS]},
state_batches=None,
explore=False)
@ -265,7 +265,7 @@ def build_q_losses(policy: Policy, model, _,
if config["double_q"]:
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net, _ = compute_q_values(
policy, policy.q_model,
policy, model,
{"obs": train_batch[SampleBatch.NEXT_OBS]},
state_batches=None,
explore=False)

View file

@ -161,7 +161,7 @@ def build_q_model_and_distribution(
isinstance(getattr(policy, "exploration", None), ParameterNoise)
or config["exploration_config"]["type"] == "ParameterNoise")
policy.q_model = ModelCatalog.get_model_v2(
model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
@ -180,7 +180,7 @@ def build_q_model_and_distribution(
# generically into ModelCatalog.
add_layer_norm=add_layer_norm)
policy.q_func_vars = policy.q_model.variables()
policy.q_func_vars = model.variables()
policy.target_q_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
@ -203,7 +203,7 @@ def build_q_model_and_distribution(
policy.target_q_func_vars = policy.target_q_model.variables()
return policy.q_model, TorchCategorical
return model, TorchCategorical
def get_distribution_inputs_and_class(
@ -241,7 +241,7 @@ def build_q_losses(policy: Policy, model, _,
# Q-network evaluation.
q_t, q_logits_t, q_probs_t, _ = compute_q_values(
policy,
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
model, {"obs": train_batch[SampleBatch.CUR_OBS]},
explore=False,
is_training=True)
@ -267,7 +267,7 @@ def build_q_losses(policy: Policy, model, _,
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net, _ = compute_q_values(
policy,
policy.q_model,
model,
{"obs": train_batch[SampleBatch.NEXT_OBS]},
explore=False,
is_training=True)

View file

@ -78,7 +78,7 @@ def r2d2_loss(policy: Policy, model, _,
# Q-network evaluation (at t).
q, _, _, _ = compute_q_values(
policy,
policy.q_model,
model,
train_batch,
state_batches=state_batches,
seq_lens=train_batch.get("seq_lens"),

View file

@ -86,7 +86,7 @@ def r2d2_loss(policy: Policy, model, _,
# Q-network evaluation (at t).
q, _, _, _ = compute_q_values(
policy,
policy.q_model,
model,
train_batch,
state_batches=state_batches,
seq_lens=train_batch.get("seq_lens"),

View file

@ -22,7 +22,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
@ -139,9 +140,21 @@ def execution_plan(workers: WorkerSet,
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer))
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = TrainTFMultiGPU(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
shuffle_sequences=True,
_fake_gpus=config["_fake_gpus"],
framework=config.get("framework"))
# (2) Read and train on experiences from the replay buffer.
replay_op = Replay(local_buffer=local_replay_buffer) \
.for_each(TrainOneStep(workers)) \
.for_each(train_step_op) \
.for_each(UpdateTargetNetwork(
workers, config["target_network_update_freq"]))

View file

@ -79,7 +79,7 @@ def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
raise UnsupportedSpaceException(
"Action space {} is not supported for DQN.".format(action_space))
policy.q_model = ModelCatalog.get_model_v2(
model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=action_space.n,
@ -95,10 +95,10 @@ def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
framework=config["framework"],
name=Q_TARGET_SCOPE)
policy.q_func_vars = policy.q_model.variables()
policy.q_func_vars = model.variables()
policy.target_q_func_vars = policy.target_q_model.variables()
return policy.q_model
return model
def get_distribution_inputs_and_class(
@ -114,6 +114,7 @@ def get_distribution_inputs_and_class(
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
policy.q_values = q_vals
policy.q_func_vars = q_model.variables()
return policy.q_values, (TorchCategorical
if policy.config["framework"] == "torch" else
Categorical), [] # state-outs
@ -135,10 +136,7 @@ def build_q_losses(policy: Policy, model: ModelV2,
"""
# q network evaluation
q_t = compute_q_values(
policy,
policy.q_model,
train_batch[SampleBatch.CUR_OBS],
explore=False)
policy, policy.model, train_batch[SampleBatch.CUR_OBS], explore=False)
# target q network evaluation
q_tp1 = compute_q_values(

View file

@ -38,7 +38,7 @@ class TargetNetworkMixin:
# target Q network.
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
(self.q_func_vars, self.target_q_func_vars)
self.target_q_model.load_state_dict(self.q_model.state_dict())
self.target_q_model.load_state_dict(self.model.state_dict())
self.update_target = do_update
@ -67,7 +67,7 @@ def build_q_losses(policy: Policy, model, dist_class,
# q network evaluation
q_t = compute_q_values(
policy,
policy.q_model,
policy.model,
train_batch[SampleBatch.CUR_OBS],
explore=False,
is_training=True)

View file

@ -1,3 +1,4 @@
import copy
import numpy as np
import unittest
@ -50,6 +51,34 @@ class TestDQN(unittest.TestCase):
check_compute_single_action(trainer)
trainer.stop()
def test_dqn_fake_multi_gpu_learning(self):
"""Test whether DQNTrainer can learn CartPole w/ faked multi-GPU."""
config = copy.deepcopy(dqn.DEFAULT_CONFIG)
# Fake GPU setup.
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"
# Double batch size (2 GPUs).
config["train_batch_size"] = 64
# Mimic tuned_example for DQN CartPole.
config["n_step"] = 3
config["model"]["fcnet_hiddens"] = [64]
config["model"]["fcnet_activation"] = "linear"
trainer = dqn.DQNTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
if results["episode_reward_mean"] > 100.0:
learnt = True
break
assert learnt, "DQN multi-GPU (with fake-GPUs) did not learn CartPole!"
trainer.stop()
def test_dqn_exploration_and_soft_q_config(self):
"""Tests, whether a DQN Agent outputs exploration/softmaxed actions."""
config = dqn.DEFAULT_CONFIG.copy()

View file

@ -1,3 +1,4 @@
import copy
import numpy as np
import unittest
@ -44,6 +45,28 @@ class TestSimpleQ(unittest.TestCase):
check_compute_single_action(trainer)
def test_simple_q_fake_multi_gpu_learning(self):
"""Test whether SimpleQTrainer learns CartPole w/ fake GPUs."""
config = copy.deepcopy(dqn.SIMPLE_Q_DEFAULT_CONFIG)
# Fake GPU setup.
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
if results["episode_reward_mean"] > 75.0:
learnt = True
break
assert learnt, "SimpleQ multi-GPU (with fake-GPUs) did not " \
"learn CartPole!"
trainer.stop()
def test_simple_q_loss_function(self):
"""Tests the Simple-Q loss function results on all frameworks."""
config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()

View file

@ -247,6 +247,8 @@ def get_policy_class(config):
def validate_config(config):
config["action_repeat"] = config["env_config"]["frame_skip"]
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for Dreamer!")
if config["framework"] != "torch":
raise ValueError("Dreamer not supported in Tensorflow yet!")
if config["batch_mode"] != "complete_episodes":

View file

@ -178,6 +178,8 @@ def get_policy_class(config):
def validate_config(config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for ES/ARS!")
if config["num_workers"] <= 0:
raise ValueError("`num_workers` must be > 0 for ES!")
if config["evaluation_config"]["num_envs_per_worker"] != 1:

View file

@ -219,6 +219,8 @@ def get_policy_class(config):
def validate_config(config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MAML!")
if config["inner_adaptation_steps"] <= 0:
raise ValueError("Inner Adaptation Steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:

View file

@ -70,9 +70,15 @@ def execution_plan(workers, config):
return StandardMetricsReporting(train_op, workers, config)
def validate_config(config):
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")
MARWILTrainer = build_trainer(
name="MARWIL",
default_config=DEFAULT_CONFIG,
default_policy=MARWILTFPolicy,
get_policy_class=get_policy_class,
validate_config=validate_config,
execution_plan=execution_plan)

View file

@ -416,6 +416,8 @@ def validate_config(config):
Raises:
ValueError: In case something is wrong with the config.
"""
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MB-MPO!")
if config["framework"] != "torch":
logger.warning("MB-MPO only supported in PyTorch so far! Switching to "
"`framework=torch`.")

View file

@ -1,3 +1,4 @@
import copy
import numpy as np
import unittest
@ -24,13 +25,42 @@ class TestPG(unittest.TestCase):
config["num_workers"] = 0
num_iterations = 2
for _ in framework_iterator(config):
for fw in framework_iterator(config):
# For tf, build with fake-GPUs.
config["_fake_gpus"] = fw == "tf"
config["num_gpus"] = 2 if fw == "tf" else 0
trainer = pg.PGTrainer(config=config, env="CartPole-v0")
for i in range(num_iterations):
print(trainer.train())
check_compute_single_action(
trainer, include_prev_action_reward=True)
def test_pg_fake_multi_gpu_learning(self):
"""Test whether PGTrainer can learn CartPole w/ faked multi-GPU."""
config = copy.deepcopy(pg.DEFAULT_CONFIG)
# Fake GPU setup.
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"
# Mimic tuned_example for PG CartPole.
config["model"]["fcnet_hiddens"] = [64]
config["model"]["fcnet_activation"] = "linear"
trainer = pg.PGTrainer(config=config, env="CartPole-v0")
num_iterations = 200
learnt = False
for i in range(num_iterations):
results = trainer.train()
print("reward={}".format(results["episode_reward_mean"]))
# Make this test quite short (75.0).
if results["episode_reward_mean"] > 75.0:
learnt = True
break
assert learnt, "PG multi-GPU (with fake-GPUs) did not learn CartPole!"
trainer.stop()
def test_pg_loss_functions(self):
"""Tests the PG loss function math."""
config = pg.DEFAULT_CONFIG.copy()

View file

@ -20,7 +20,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
StandardizeFields, SelectExperiences
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
@ -86,13 +86,6 @@ DEFAULT_CONFIG = with_common_config({
"batch_mode": "truncate_episodes",
# Which observation filter to apply to the observation.
"observation_filter": "NoFilter",
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
# the default optimizer.
"simple_optimizer": False,
# Whether to fake GPUs (using CPUs).
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
"_fake_gpus": False,
# Deprecated keys:
# Share layers for value function. If you set this to True, it's important
@ -139,16 +132,8 @@ def validate_config(config: TrainerConfigDict) -> None:
"function (to estimate the return at the end of the truncated "
"trajectory). Consider setting batch_mode=complete_episodes.")
# Multi-gpu not supported for PyTorch and tf-eager.
if config["framework"] in ["tf2", "tfe", "torch"]:
config["simple_optimizer"] = True
# Performance warning, if "simple" optimizer used with (static-graph) tf.
elif config["simple_optimizer"]:
logger.warning(
"Using the simple minibatch optimizer. This will significantly "
"reduce performance, consider simple_optimizer=False.")
# Multi-agent mode and multi-GPU optimizer.
elif config["multiagent"]["policies"] and not config["simple_optimizer"]:
if config["multiagent"]["policies"] and not config["simple_optimizer"]:
logger.info(
"In multi-agent mode, policies will be optimized sequentially "
"by the multi-GPU optimizer. Consider setting "
@ -184,12 +169,14 @@ class UpdateKL:
def __call__(self, fetches):
def update(pi, pi_id):
assert "kl" not in fetches, (
"kl should be nested under policy id key", fetches)
assert LEARNER_STATS_KEY not in fetches, \
("{} should be nested under policy id key".format(
LEARNER_STATS_KEY), fetches)
if pi_id in fetches:
assert "kl" in fetches[pi_id], (fetches, pi_id)
kl = fetches[pi_id][LEARNER_STATS_KEY].get("kl")
assert kl is not None, (fetches, pi_id)
# Make the actual `Policy.update_kl()` call.
pi.update_kl(fetches[pi_id]["kl"])
pi.update_kl(kl)
else:
logger.warning("No data for {}, not updating kl".format(pi_id))
@ -205,9 +192,11 @@ def warn_about_bad_reward_scales(config, result):
# Warn about excessively high VF loss.
learner_stats = result["info"]["learner"]
if DEFAULT_POLICY_ID in learner_stats:
scaled_vf_loss = (config["vf_loss_coeff"] *
learner_stats[DEFAULT_POLICY_ID]["vf_loss"])
policy_loss = learner_stats[DEFAULT_POLICY_ID]["policy_loss"]
scaled_vf_loss = config["vf_loss_coeff"] * \
learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"]
policy_loss = learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"policy_loss"]
if config.get("model", {}).get("vf_share_layers") and \
scaled_vf_loss > 100:
logger.warning(
@ -272,13 +261,10 @@ def execution_plan(workers: WorkerSet,
else:
train_op = rollouts.for_each(
TrainTFMultiGPU(
workers,
workers=workers,
sgd_minibatch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
rollout_fragment_length=config["rollout_fragment_length"],
num_envs_per_worker=config["num_envs_per_worker"],
train_batch_size=config["train_batch_size"],
shuffle_sequences=config["shuffle_sequences"],
_fake_gpus=config["_fake_gpus"],
framework=config.get("framework")))
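Both `UpdateKL` and `warn_about_bad_reward_scales` now expect learner stats to be nested one level deeper, under `LEARNER_STATS_KEY` per policy, which is the layout the multi-GPU train op produces. A rough sketch of that layout (stat names and values are illustrative, not from this diff):

    fetches = {
        "default_policy": {
            "learner_stats": {          # LEARNER_STATS_KEY
                "kl": 0.012,
                "policy_loss": -0.034,
                "vf_loss": 0.21,
            },
        },
    }
    kl = fetches["default_policy"]["learner_stats"]["kl"]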

View file

@ -2,6 +2,7 @@ import unittest
import ray
import ray.rllib.agents.ppo as ppo
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.test_utils import check_compute_single_action, \
framework_iterator
@ -40,7 +41,8 @@ class TestDDPPO(unittest.TestCase):
trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
for _ in range(num_iterations):
result = trainer.train()
lr = result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]
lr = result["info"]["learner"][DEFAULT_POLICY_ID][
LEARNER_STATS_KEY]["cur_lr"]
trainer.stop()
assert lr == 0.0, "lr should anneal to 0.0"

View file

@ -90,7 +90,7 @@ class TestPPO(unittest.TestCase):
for _ in framework_iterator(config):
for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
print("Env={}".format(env))
for lstm in [True, False]:
for lstm in [False, True]:
print("LSTM={}".format(lstm))
config["model"]["use_lstm"] = lstm
config["model"]["lstm_use_prev_action"] = lstm
@ -111,7 +111,7 @@ class TestPPO(unittest.TestCase):
config["num_gpus"] = 2
config["_fake_gpus"] = True
config["framework"] = "tf"
# Mimick tuned_example for PPO CartPole.
# Mimic tuned_example for PPO CartPole.
config["num_workers"] = 1
config["lr"] = 0.0003
config["observation_filter"] = "MeanStdFilter"
@ -127,7 +127,7 @@ class TestPPO(unittest.TestCase):
for i in range(num_iterations):
results = trainer.train()
print(results)
if results["episode_reward_mean"] > 150:
if results["episode_reward_mean"] > 75.0:
learnt = True
break
assert learnt, "PPO multi-GPU (with fake-GPUs) did not learn CartPole!"

View file

@ -167,6 +167,9 @@ def validate_config(config: TrainerConfigDict) -> None:
Raises:
ValueError: In case something is wrong with the config.
"""
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for SAC!")
if config["use_state_preprocessor"] != DEPRECATED_VALUE:
deprecation_warning(
old="config['use_state_preprocessor']", error=False)

View file

@ -136,6 +136,9 @@ DEFAULT_CONFIG = with_common_config({
def validate_config(config: TrainerConfigDict) -> None:
"""Checks the config based on settings"""
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
if config["framework"] != "torch":
raise ValueError("SlateQ only runs on PyTorch")

View file

@ -303,9 +303,14 @@ COMMON_CONFIG: TrainerConfigDict = {
# === Resource Settings ===
# Number of GPUs to allocate to the trainer process. Note that not all
# algorithms can take advantage of trainer GPUs. This can be fractional
# (e.g., 0.3 GPUs).
# algorithms can take advantage of trainer GPUs. Support for multi-GPU
# is currently only available for tf-[PPO/IMPALA/DQN/PG].
# This can be fractional (e.g., 0.3 GPUs).
"num_gpus": 0,
# Set to True for debugging (multi-)GPU functionality on a CPU machine.
# GPU towers will be simulated by graphs located on CPUs in this case.
# Use `num_gpus` to test for different numbers of fake GPUs.
"_fake_gpus": False,
# Number of CPUs to allocate per worker.
"num_cpus_per_worker": 1,
# Number of GPUs to allocate per worker. This can be fractional. This is
@ -404,6 +409,13 @@ COMMON_CONFIG: TrainerConfigDict = {
# Define logger-specific configuration to be used inside Logger
# Default value None allows overwriting with nested dicts
"logger_config": None,
# Deprecated values.
# Uses the sync samples optimizer instead of the multi-gpu one. This is
# usually slower, but you might want to try it if you run into issues with
# the default optimizer.
# This will be set automatically from now on.
"simple_optimizer": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
@ -1120,6 +1132,25 @@ class Trainer(Trainable):
if model_config is None:
config["model"] = model_config = {}
# Multi-GPU settings.
simple_optim_setting = config.get("simple_optimizer", DEPRECATED_VALUE)
if simple_optim_setting != DEPRECATED_VALUE:
deprecation_warning("simple_optimizer", error=False)
if config.get("num_gpus", 0) > 1:
if config.get("framework") in ["tfe", "tf2", "torch"]:
raise ValueError("`num_gpus` > 1 not supported yet for "
"framework={}!".format(
config.get("framework")))
elif simple_optim_setting is True:
raise ValueError(
"Cannot use `simple_optimizer` if `num_gpus` > 1! "
"Consider `simple_optimizer=False`.")
config["simple_optimizer"] = False
elif simple_optim_setting == DEPRECATED_VALUE:
config["simple_optimizer"] = True
# Trajectory View API settings.
if not config.get("_use_trajectory_view_api"):
traj_view_framestacks = model_config.get("num_framestacks", "auto")
if model_config.get("_time_major"):
@ -1130,6 +1161,7 @@ class Trainer(Trainable):
"iff `_use_trajectory_view_api` is True!")
model_config["num_framestacks"] = 0
# Offline RL settings.
if isinstance(config["input_evaluation"], tuple):
config["input_evaluation"] = list(config["input_evaluation"])
elif not isinstance(config["input_evaluation"], list):
@ -1156,6 +1188,7 @@ class Trainer(Trainable):
"complete_episodes]! Got {}".format(
config["batch_mode"]))
# Check multi-agent batch count mode.
if config["multiagent"].get("count_steps_by", "env_steps") not in \
["env_steps", "agent_steps"]:
raise ValueError(
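The new "Multi-GPU settings" block above can be restated as follows (a standalone sketch, not code from this diff; per-algorithm defaults may still override the result):

    def resolve_simple_optimizer(framework, num_gpus, simple_optimizer=None):
        # `None` stands in for the DEPRECATED_VALUE sentinel here.
        if num_gpus > 1:
            if framework in ("tf2", "tfe", "torch"):
                raise ValueError(
                    "`num_gpus` > 1 not supported yet for framework={}!".format(
                        framework))
            if simple_optimizer is True:
                raise ValueError(
                    "Cannot use `simple_optimizer` if `num_gpus` > 1!")
            return False   # use the tf multi-GPU optimizer
        if simple_optimizer is None:
            return True    # default: TrainOneStep
        return simple_optimizer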

View file

@ -5,7 +5,7 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins
@ -26,7 +26,21 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)).for_each(TrainOneStep(workers))
))
if config.get("simple_optimizer") is True:
train_op = train_op.for_each(TrainOneStep(workers))
else:
train_op = train_op.for_each(
TrainTFMultiGPU(
workers=workers,
sgd_minibatch_size=config.get("sgd_minibatch_size",
config["train_batch_size"]),
num_sgd_iter=config.get("num_sgd_iter", 1),
num_gpus=config["num_gpus"],
shuffle_sequences=config.get("shuffle_sequences", False),
_fake_gpus=config["_fake_gpus"],
framework=config["framework"]))
# Add on the standard episode reward, etc. metrics reporting. This returns
# a LocalIterator[metrics_dict] representing metrics for each train step.
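Because PG- and A2C-style configs have no `sgd_minibatch_size` or `num_sgd_iter` keys, the `config.get(...)` fallbacks above resolve to a single SGD pass over the full train batch. A quick sanity sketch (config values assumed for illustration):

    config = {"train_batch_size": 200, "num_gpus": 2, "_fake_gpus": True,
              "framework": "tf"}  # PG-style: no sgd_minibatch_size / num_sgd_iter
    sgd_minibatch_size = config.get("sgd_minibatch_size", config["train_batch_size"])
    num_sgd_iter = config.get("num_sgd_iter", 1)
    print(sgd_minibatch_size, num_sgd_iter)  # -> 200 1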

View file

@ -63,13 +63,15 @@ if __name__ == "__main__":
# Checkpoint with the lowest policy loss value:
ckpt = results.get_best_checkpoint(
best_trial,
metric="info/learner/default_policy/policy_loss",
metric="info/learner/default_policy/learner_stats/policy_loss",
mode="min")
print("Lowest pol-loss: {}".format(ckpt))
# Checkpoint with the highest value-function loss:
ckpt = results.get_best_checkpoint(
best_trial, metric="info/learner/default_policy/vf_loss", mode="max")
best_trial,
metric="info/learner/default_policy/learner_stats/vf_loss",
mode="max")
print("Highest vf-loss: {}".format(ckpt))
ray.shutdown()

View file

@ -11,6 +11,7 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.framework import try_import_tf
@ -84,7 +85,7 @@ class MyKerasQModel(DistributionalQTFModel):
kernel_initializer=normc_initializer(1.0))(layer_1)
self.base_model = tf.keras.Model(self.inputs, layer_out)
# Implement the core forward method
# Implement the core forward method.
def forward(self, input_dict, state, seq_lens):
model_out = self.base_model(input_dict["obs"])
return model_out, state
@ -95,7 +96,7 @@ class MyKerasQModel(DistributionalQTFModel):
if __name__ == "__main__":
args = parser.parse_args()
ray.init(num_cpus=args.num_cpus or None)
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
ModelCatalog.register_custom_model(
"keras_model", MyVisionNetwork
if args.use_vision_network else MyKerasModel)
@ -107,7 +108,8 @@ if __name__ == "__main__":
def check_has_custom_metric(result):
r = result["result"]["info"]["learner"]
if DEFAULT_POLICY_ID in r:
r = r[DEFAULT_POLICY_ID]
r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY,
r[DEFAULT_POLICY_ID])
assert r["model"]["foo"] == 42, result
if args.run == "DQN":

View file

@ -33,7 +33,6 @@ parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-reward", type=float, default=150)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--simple", action="store_true")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--as-test", action="store_true")
parser.add_argument(
@ -82,7 +81,6 @@ if __name__ == "__main__":
"env_config": {
"num_agents": args.num_agents,
},
"simple_optimizer": args.simple,
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"num_sgd_iter": 10,

View file

@ -73,7 +73,6 @@ if __name__ == "__main__":
config["num_workers"] = args.num_workers
config["rollout_fragment_length"] = 200
config["sgd_minibatch_size"] = 256
config["simple_optimizer"] = True
config["train_batch_size"] = 4000
config["batch_mode"] = "complete_episodes"

View file

@ -251,8 +251,9 @@ class LocalSyncParallelOptimizer:
feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
fetches = {"train": self._train_op}
for tower in self._towers:
fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
for tower_num, tower in enumerate(self._towers):
tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
fetches["tower_{}".format(tower_num)] = tower_fetch
return sess.run(fetches, feed_dict=feed_dict)
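With this change, `optimize()` no longer merges the towers' stats into one dict; each tower's fetches are returned under a `tower_{n}` key so the caller can reduce them explicitly. An illustrative shape of the returned dict for two towers (stat names and values made up):

    batch_fetches = {
        "train": None,  # placeholder for the combined train op's result
        "tower_0": {"learner_stats": {"policy_loss": -0.03, "vf_loss": 0.21}},
        "tower_1": {"learner_stats": {"policy_loss": -0.05, "vf_loss": 0.19}},
    }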

View file

@ -1,7 +1,7 @@
from collections import defaultdict
import logging
import numpy as np
import math
import tree
from typing import List, Tuple, Any
import ray
@ -18,7 +18,7 @@ from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.sgd import do_minibatch_sgd, averaged
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
tf1, tf, tfv = try_import_tf()
@ -106,13 +106,11 @@ class TrainTFMultiGPU:
"""
def __init__(self,
*,
workers: WorkerSet,
sgd_minibatch_size: int,
num_sgd_iter: int,
num_gpus: int,
rollout_fragment_length: int,
num_envs_per_worker: int,
train_batch_size: int,
shuffle_sequences: bool,
policies: List[PolicyID] = frozenset([]),
_fake_gpus: bool = False,
@ -124,7 +122,7 @@ class TrainTFMultiGPU:
self.shuffle_sequences = shuffle_sequences
self.framework = framework
# Collect actual devices to use.
# Collect actual GPU devices to use.
if not num_gpus:
_fake_gpus = True
num_gpus = 1
@ -133,10 +131,13 @@ class TrainTFMultiGPU:
"/{}:{}".format(type_, i) for i in range(int(math.ceil(num_gpus)))
]
# Total batch size (all towers). Make sure it is divisible by the
# number of towers.
self.batch_size = int(sgd_minibatch_size / len(self.devices)) * len(
self.devices)
assert self.batch_size % len(self.devices) == 0
assert self.batch_size >= len(self.devices), "batch size too small"
# Batch size per tower.
self.per_device_batch_size = int(self.batch_size / len(self.devices))
# per-GPU graph copies created below must share vars with the policy
@ -177,8 +178,8 @@ class TrainTFMultiGPU:
metrics = _get_shared_metrics()
load_timer = metrics.timers[LOAD_BATCH_TIMER]
learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
# Load data into GPUs.
with load_timer:
# (1) Load data into GPUs.
num_loaded_tuples = {}
for policy_id, batch in samples.policy_batches.items():
# Not a policy-to-train.
@ -202,8 +203,8 @@ class TrainTFMultiGPU:
self.sess, [tuples[k] for k in data_keys],
[tuples[k] for k in state_keys]))
# Execute minibatch SGD on loaded data.
with learn_timer:
# (2) Execute minibatch SGD on loaded data.
fetches = {}
for policy_id, tuples_per_device in num_loaded_tuples.items():
optimizer = self.optimizers[policy_id]
@ -211,19 +212,25 @@ class TrainTFMultiGPU:
1,
int(tuples_per_device) // int(self.per_device_batch_size))
logger.debug("== sgd epochs for {} ==".format(policy_id))
for i in range(self.num_sgd_iter):
iter_extra_fetches = defaultdict(list)
for _ in range(self.num_sgd_iter):
permutation = np.random.permutation(num_batches)
batch_fetches_all_towers = []
for batch_index in range(num_batches):
batch_fetches = optimizer.optimize(
self.sess, permutation[batch_index] *
self.per_device_batch_size)
for k, v in batch_fetches[LEARNER_STATS_KEY].items():
iter_extra_fetches[k].append(v)
if logger.getEffectiveLevel() <= logging.DEBUG:
avg = averaged(iter_extra_fetches)
logger.debug("{} {}".format(i, avg))
fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
batch_fetches_all_towers.append(
tree.map_structure_with_path(
lambda p, *s: self._all_tower_reduce(p, *s),
*(batch_fetches["tower_{}".format(tower_num)]
for tower_num in range(len(self.devices)))))
# Reduce mean across all minibatch SGD steps (axis=0 to keep
# all shapes as-is).
fetches[policy_id] = tree.map_structure(
lambda *s: np.nanmean(s, axis=0),
*batch_fetches_all_towers)
load_timer.push_units_processed(samples.count)
learn_timer.push_units_processed(samples.count)
@ -240,6 +247,16 @@ class TrainTFMultiGPU:
self.workers.local_worker().set_global_vars(_get_global_vars())
return samples, fetches
def _all_tower_reduce(self, path, *tower_data):
"""Reduces stats across towers based on their stats-dict paths."""
if len(path) == 1 and path[0] == "td_error":
return np.concatenate(tower_data, axis=0)
elif path[-1].startswith("min_"):
return np.nanmin(tower_data)
elif path[-1].startswith("max_"):
return np.nanmax(tower_data)
return np.nanmean(tower_data)
class ComputeGradients:
"""Callable that computes gradients with respect to the policy loss.

View file

@ -160,7 +160,11 @@ class DynamicTFPolicy(TFPolicy):
# Setup self.model.
if existing_model:
self.model = existing_model
if isinstance(existing_model, list):
self.model = existing_model[0]
# TODO: (sven) hack, but works for `target_[q_]?model`.
for i in range(1, len(existing_model)):
setattr(self, existing_model[i][0], existing_model[i][1])
elif make_model:
self.model = make_model(self, obs_space, action_space, config)
else:
@ -389,7 +393,11 @@ class DynamicTFPolicy(TFPolicy):
self.action_space,
self.config,
existing_inputs=input_dict,
existing_model=self.model)
existing_model=[
self.model,
("target_q_model", getattr(self, "target_q_model", None)),
("target_model", getattr(self, "target_model", None)),
])
instance._loss_input_dict = input_dict
loss = instance._do_loss_init(input_dict)

View file

@ -292,6 +292,12 @@ def build_tf_policy(
@override(TFPolicy)
def extra_compute_grad_fetches(self):
if extra_learn_fetches_fn:
# TODO: (sven) in torch, extra_learn_fetches do not exist.
# Hence, things like td_error are returned by the stats_fn
# and end up under the LEARNER_STATS_KEY. We should
# change tf to do this as well. However, this will conflict
# with the handling of LEARNER_STATS_KEY inside the multi-GPU
# train op.
# Auto-add empty learner stats dict if needed.
return dict({
LEARNER_STATS_KEY: {}

View file

@ -120,7 +120,6 @@ class TestRNNSequencing(unittest.TestCase):
"rollout_fragment_length": 10,
"train_batch_size": 10,
"sgd_minibatch_size": 10,
"simple_optimizer": True,
"num_sgd_iter": 1,
"model": {
"custom_model": "rnn",
@ -178,7 +177,6 @@ class TestRNNSequencing(unittest.TestCase):
"rollout_fragment_length": 20,
"train_batch_size": 20,
"sgd_minibatch_size": 10,
"simple_optimizer": False,
"num_sgd_iter": 1,
"model": {
"custom_model": "rnn",
@ -198,15 +196,17 @@ class TestRNNSequencing(unittest.TestCase):
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]:
batch0, batch1 = batch1, batch0 # sort minibatches
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4])
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3])
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3, 3])
self.assertEqual(batch0["sequences"].tolist(), [
[[0], [1], [2], [3]],
[[4], [5], [6], [7]],
[[8], [9], [0], [0]],
])
self.assertEqual(batch1["sequences"].tolist(), [
[[8], [9], [10], [11]],
[[12], [13], [14], [0]],
[[0], [1], [2], [0]],
])
# second epoch: 20 observations get split into 2 minibatches of 8
@ -217,15 +217,17 @@ class TestRNNSequencing(unittest.TestCase):
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3"))
if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]:
batch2, batch3 = batch3, batch2
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4])
self.assertEqual(batch3["seq_lens"].tolist(), [2, 4])
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch3["seq_lens"].tolist(), [4, 4, 2])
self.assertEqual(batch2["sequences"].tolist(), [
[[5], [6], [7], [8]],
[[9], [10], [11], [12]],
[[0], [1], [2], [3]],
[[4], [5], [6], [7]],
[[8], [9], [0], [0]],
])
self.assertEqual(batch3["sequences"].tolist(), [
[[5], [6], [7], [8]],
[[9], [10], [11], [12]],
[[13], [14], [0], [0]],
[[0], [1], [2], [3]],
])

View file

@ -96,7 +96,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)
fetches = {}
fetches = defaultdict(dict)
for policy_id in policies.keys():
if policy_id not in samples.policy_batches:
continue
@ -106,14 +106,13 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
batch[field] = standardized(batch[field])
for i in range(num_sgd_iter):
iter_extra_fetches = defaultdict(list)
learner_stats = defaultdict(list)
for minibatch in minibatches(batch, sgd_minibatch_size):
batch_fetches = (local_worker.learn_on_batch(
MultiAgentBatch({
policy_id: minibatch
}, minibatch.count)))[policy_id]
for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i, averaged(iter_extra_fetches)))
fetches[policy_id] = averaged(iter_extra_fetches)
learner_stats[k].append(v)
fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats)
return fetches