[RLlib] Multi-GPU for tf-DQN/PG/A2C. (#13393)
This commit is contained in: parent b0bf44b154, commit 732197e23a
41 changed files with 385 additions and 134 deletions
@ -8,30 +8,30 @@ RLlib Algorithms

Available Algorithms - Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- =================== ========== ======================= ================== =========== =============================================================
- Algorithm           Frameworks  Discrete Actions         Continuous Actions  Multi-Agent  Model Support
- =================== ========== ======================= ================== =========== =============================================================
- `A2C, A3C`_          tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
- `ARS`_               tf + torch  **Yes**                 **Yes**  No
- `BC`_                tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_
- `ES`_                tf + torch  **Yes**                 **Yes**  No
- `DDPG`_, `TD3`_      tf + torch  No                      **Yes**  **Yes**
- `APEX-DDPG`_         tf + torch  No                      **Yes**  **Yes**
- `Dreamer`_           torch       No                      **Yes**  No       `+RNN`_
- `DQN`_, `Rainbow`_   tf + torch  **Yes** `+parametric`_  No       **Yes**
- `APEX-DQN`_          tf + torch  **Yes** `+parametric`_  No       **Yes**
- `IMPALA`_            tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
- `MAML`_              tf + torch  No                      **Yes**  No
- `MARWIL`_            tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_
- `MBMPO`_             torch       No                      **Yes**  No
- `PG`_                tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
- `PPO`_, `APPO`_      tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_
- `R2D2`_              tf + torch  **Yes** `+parametric`_  No       **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+autoreg`_
- `SAC`_               tf + torch  **Yes**                 **Yes**  **Yes**
- `SlateQ`_            torch       **Yes**                 No       No
- `LinUCB`_, `LinTS`_  torch       **Yes** `+parametric`_  No       **Yes**
- `AlphaZero`_         torch       **Yes** `+parametric`_  No       No
- =================== ========== ======================= ================== =========== =============================================================
+ =================== ========== ======================= ================== =========== ============================================================= =========
+ Algorithm           Frameworks  Discrete Actions         Continuous Actions  Multi-Agent  Model Support                                                Multi-GPU
+ =================== ========== ======================= ================== =========== ============================================================= =========
+ `A2C, A3C`_          tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_  tf (A2C)
+ `ARS`_               tf + torch  **Yes**                 **Yes**  No                                                                    No
+ `BC`_                tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_                                                      No
+ `ES`_                tf + torch  **Yes**                 **Yes**  No                                                                    No
+ `DDPG`_, `TD3`_      tf + torch  No                      **Yes**  **Yes**                                                               No
+ `APEX-DDPG`_         tf + torch  No                      **Yes**  **Yes**                                                               No
+ `Dreamer`_           torch       No                      **Yes**  No       `+RNN`_                                                      No
+ `DQN`_, `Rainbow`_   tf + torch  **Yes** `+parametric`_  No       **Yes**                                                               tf
+ `APEX-DQN`_          tf + torch  **Yes** `+parametric`_  No       **Yes**                                                               No
+ `IMPALA`_            tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_  tf
+ `MAML`_              tf + torch  No                      **Yes**  No                                                                    No
+ `MARWIL`_            tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_                                                      No
+ `MBMPO`_             torch       No                      **Yes**  No                                                                    No
+ `PG`_                tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_  tf
+ `PPO`_, `APPO`_      tf + torch  **Yes** `+parametric`_  **Yes**  **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+Attention`_, `+autoreg`_  tf
+ `R2D2`_              tf + torch  **Yes** `+parametric`_  No       **Yes**  `+RNN`_, `+LSTM auto-wrapping`_, `+autoreg`_                 No
+ `SAC`_               tf + torch  **Yes**                 **Yes**  **Yes**                                                               No
+ `SlateQ`_            torch       **Yes**                 No       No                                                                    No
+ `LinUCB`_, `LinTS`_  torch       **Yes** `+parametric`_  No       **Yes**                                                               No
+ `AlphaZero`_         torch       **Yes** `+parametric`_  No       No                                                                    No
+ =================== ========== ======================= ================== =========== ============================================================= =========

Multi-Agent only Methods
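The overview table now advertises multi-GPU support for the tf versions of DQN, PG and A2C, in addition to PPO/IMPALA. As a hedged sketch (stop criteria and hyperparameters below are placeholders, not part of this commit), enabling it from user code looks like this; `_fake_gpus` is only needed to simulate several GPUs on a CPU-only machine:

    # Sketch: try (fake) multi-GPU training for one of the newly supported tf algos.
    import ray
    from ray import tune

    if __name__ == "__main__":
        ray.init()
        tune.run(
            "DQN",  # also "PG", "A2C", "PPO", "IMPALA" with framework=tf
            stop={"episode_reward_mean": 150.0},
            config={
                "env": "CartPole-v0",
                "framework": "tf",   # multi-GPU is tf-only in this commit
                "num_gpus": 2,       # number of (possibly fake) trainer GPUs
                "_fake_gpus": True,  # simulate the GPUs on CPU for debugging
            })
        ray.shutdown()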
@ -173,7 +173,6 @@ py_test(
        "tuned_examples/dqn/cartpole-dqn.yaml",
        "tuned_examples/dqn/cartpole-dqn-softq.yaml",
        "tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
        "tuned_examples/dqn/cartpole-r2d2.yaml",
    ],
    args = ["--yaml-dir=tuned_examples/dqn"]
)

@ -188,7 +187,6 @@ py_test(
        "tuned_examples/dqn/cartpole-dqn.yaml",
        "tuned_examples/dqn/cartpole-dqn-softq.yaml",
        "tuned_examples/dqn/cartpole-dqn-param-noise.yaml",
        "tuned_examples/dqn/cartpole-r2d2.yaml",
    ],
    args = ["--yaml-dir=tuned_examples/dqn", "--framework=torch"]
)

@ -603,7 +601,7 @@ py_test(
py_test(
    name = "test_pg",
    tags = ["agents_dir"],
-   size = "small",
+   size = "medium",
    srcs = ["agents/pg/tests/test_pg.py"]
)
@ -7,7 +7,7 @@ from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import ComputeGradients, AverageGradients, \
-    ApplyGradients, TrainOneStep
+    ApplyGradients, TrainTFMultiGPU, TrainOneStep
from ray.rllib.utils import merge_dicts

A2C_DEFAULT_CONFIG = merge_dicts(

@ -47,11 +47,23 @@ def execution_plan(workers, config):
            .for_each(ApplyGradients(workers)))
    else:
        # In normal mode, we execute one SGD step per each train batch.
+       if config["simple_optimizer"]:
+           train_step_op = TrainOneStep(workers)
+       else:
+           train_step_op = TrainTFMultiGPU(
+               workers=workers,
+               sgd_minibatch_size=config["train_batch_size"],
+               num_sgd_iter=1,
+               num_gpus=config["num_gpus"],
+               shuffle_sequences=True,
+               _fake_gpus=config["_fake_gpus"],
+               framework=config.get("framework"))
+
        train_op = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"][
-                   "count_steps_by"])).for_each(TrainOneStep(workers))
+                   "count_steps_by"])).for_each(train_step_op)

    return StandardMetricsReporting(train_op, workers, config)
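The same `simple_optimizer` vs. `TrainTFMultiGPU` branch recurs in the A2C, DQN, SimpleQ and default execution plans in this commit. The hypothetical helper below (not part of the commit, just a sketch of the shared pattern) shows the decision in one place, passing only the keyword arguments the new call sites actually use:

    # Hypothetical helper factoring out the repeated train-op choice above.
    from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU

    def make_train_step_op(workers, config, num_sgd_iter=1):
        if config["simple_optimizer"]:
            # One learn_on_batch() call per train batch.
            return TrainOneStep(workers)
        # Static-graph tf only: splits each train batch across GPU towers.
        return TrainTFMultiGPU(
            workers=workers,
            sgd_minibatch_size=config.get("sgd_minibatch_size",
                                          config["train_batch_size"]),
            num_sgd_iter=num_sgd_iter,
            num_gpus=config["num_gpus"],
            shuffle_sequences=True,
            _fake_gpus=config["_fake_gpus"],
            framework=config.get("framework"))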
@ -1,3 +1,4 @@
+ import copy
import unittest

import ray

@ -33,6 +34,30 @@ class TestA2C(unittest.TestCase):
            check_compute_single_action(trainer)
            trainer.stop()

+   def test_a2c_fake_multi_gpu_learning(self):
+       """Test whether A2CTrainer can learn CartPole w/ faked multi-GPU."""
+       config = copy.deepcopy(a3c.a2c.A2C_DEFAULT_CONFIG)
+
+       # Fake GPU setup.
+       config["num_gpus"] = 2
+       config["_fake_gpus"] = True
+
+       config["framework"] = "tf"
+       # Mimic tuned_example for A2C CartPole.
+       config["lr"] = 0.001
+
+       trainer = a3c.A2CTrainer(config=config, env="CartPole-v0")
+       num_iterations = 100
+       learnt = False
+       for i in range(num_iterations):
+           results = trainer.train()
+           print("reward={}".format(results["episode_reward_mean"]))
+           if results["episode_reward_mean"] > 100.0:
+               learnt = True
+               break
+       assert learnt, "A2C multi-GPU (with fake-GPUs) did not learn CartPole!"
+       trainer.stop()
+
    def test_a2c_exec_impl(ray_start_regular):
        config = {"min_iter_time_s": 0}
        for _ in framework_iterator(config):
|
|
|
@ -46,6 +46,8 @@ CQL_DEFAULT_CONFIG = merge_dicts(
|
|||
|
||||
|
||||
def validate_config(config: TrainerConfigDict):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for CQL!")
|
||||
if config["framework"] == "tf":
|
||||
raise ValueError("Tensorflow CQL not implemented yet!")
|
||||
|
||||
|
|
|
@ -151,6 +151,8 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
|
||||
def validate_config(config):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for DDPG!")
|
||||
if config["model"]["custom_model"]:
|
||||
logger.warning(
|
||||
"Setting use_state_preprocessor=True since a custom model "
|
||||
|
|
|
@ -123,6 +123,8 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
|||
model_out_tp1, _ = model(input_dict_next, [], None)
|
||||
target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)
|
||||
|
||||
policy.target_q_func_vars = policy.target_model.variables()
|
||||
|
||||
# Policy network evaluation.
|
||||
policy_t = model.get_policy_output(model_out_t)
|
||||
policy_tp1 = \
|
||||
|
|
|
@ -17,8 +17,8 @@ import copy
|
|||
from typing import Tuple
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, calculate_rr_weights
|
||||
from ray.rllib.agents.dqn.dqn import calculate_rr_weights, \
|
||||
DEFAULT_CONFIG as DQN_CONFIG, DQNTrainer, validate_config
|
||||
from ray.rllib.agents.dqn.learner_thread import LearnerThread
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.execution.common import (STEPS_TRAINED_COUNTER,
|
||||
|
@ -195,7 +195,14 @@ def apex_execution_plan(workers: WorkerSet,
|
|||
selected_workers=selected_workers).for_each(add_apex_metrics)
|
||||
|
||||
|
||||
def apex_validate_config(config):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for APEX-DQN!")
|
||||
validate_config(config)
|
||||
|
||||
|
||||
ApexTrainer = DQNTrainer.with_updates(
|
||||
name="APEX",
|
||||
default_config=APEX_DEFAULT_CONFIG,
|
||||
validate_config=apex_validate_config,
|
||||
execution_plan=apex_execution_plan)
|
||||
|
|
|
@ -22,7 +22,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
- from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
+ from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork, \
+     TrainTFMultiGPU
from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

@ -170,6 +171,13 @@ def validate_config(config: TrainerConfigDict) -> None:
        raise ValueError("Prioritized replay is not supported when "
                         "replay_sequence_length > 1.")

+   # Multi-agent mode and multi-GPU optimizer.
+   if config["multiagent"]["policies"] and not config["simple_optimizer"]:
+       logger.info(
+           "In multi-agent mode, policies will be optimized sequentially "
+           "by the multi-GPU optimizer. Consider setting "
+           "simple_optimizer=True if this doesn't work for you.")
+

def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:

@ -231,9 +239,22 @@ def execution_plan(workers: WorkerSet,
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take a SGD step, and then we decide whether to update the target network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

+   if config["simple_optimizer"]:
+       train_step_op = TrainOneStep(workers)
+   else:
+       train_step_op = TrainTFMultiGPU(
+           workers=workers,
+           sgd_minibatch_size=config["train_batch_size"],
+           num_sgd_iter=1,
+           num_gpus=config["num_gpus"],
+           shuffle_sequences=True,
+           _fake_gpus=config["_fake_gpus"],
+           framework=config.get("framework"))
+
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
-       .for_each(TrainOneStep(workers)) \
+       .for_each(train_step_op) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))
|
|
@ -164,7 +164,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
|
|||
else:
|
||||
num_outputs = action_space.n
|
||||
|
||||
policy.q_model = ModelCatalog.get_model_v2(
|
||||
q_model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
|
@ -206,7 +206,7 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
|
|||
getattr(policy, "exploration", None), ParameterNoise)
|
||||
or config["exploration_config"]["type"] == "ParameterNoise")
|
||||
|
||||
return policy.q_model
|
||||
return q_model
|
||||
|
||||
|
||||
def get_distribution_inputs_and_class(policy: Policy,
|
||||
|
@ -240,7 +240,7 @@ def build_q_losses(policy: Policy, model, _,
|
|||
# q network evaluation
|
||||
q_t, q_logits_t, q_dist_t, _ = compute_q_values(
|
||||
policy,
|
||||
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
|
||||
model, {"obs": train_batch[SampleBatch.CUR_OBS]},
|
||||
state_batches=None,
|
||||
explore=False)
|
||||
|
||||
|
@ -265,7 +265,7 @@ def build_q_losses(policy: Policy, model, _,
|
|||
if config["double_q"]:
|
||||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net, _ = compute_q_values(
|
||||
policy, policy.q_model,
|
||||
policy, model,
|
||||
{"obs": train_batch[SampleBatch.NEXT_OBS]},
|
||||
state_batches=None,
|
||||
explore=False)
|
||||
|
|
|
@ -161,7 +161,7 @@ def build_q_model_and_distribution(
|
|||
isinstance(getattr(policy, "exploration", None), ParameterNoise)
|
||||
or config["exploration_config"]["type"] == "ParameterNoise")
|
||||
|
||||
policy.q_model = ModelCatalog.get_model_v2(
|
||||
model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
|
@ -180,7 +180,7 @@ def build_q_model_and_distribution(
|
|||
# generically into ModelCatalog.
|
||||
add_layer_norm=add_layer_norm)
|
||||
|
||||
policy.q_func_vars = policy.q_model.variables()
|
||||
policy.q_func_vars = model.variables()
|
||||
|
||||
policy.target_q_model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
|
@ -203,7 +203,7 @@ def build_q_model_and_distribution(
|
|||
|
||||
policy.target_q_func_vars = policy.target_q_model.variables()
|
||||
|
||||
return policy.q_model, TorchCategorical
|
||||
return model, TorchCategorical
|
||||
|
||||
|
||||
def get_distribution_inputs_and_class(
|
||||
|
@ -241,7 +241,7 @@ def build_q_losses(policy: Policy, model, _,
|
|||
# Q-network evaluation.
|
||||
q_t, q_logits_t, q_probs_t, _ = compute_q_values(
|
||||
policy,
|
||||
policy.q_model, {"obs": train_batch[SampleBatch.CUR_OBS]},
|
||||
model, {"obs": train_batch[SampleBatch.CUR_OBS]},
|
||||
explore=False,
|
||||
is_training=True)
|
||||
|
||||
|
@ -267,7 +267,7 @@ def build_q_losses(policy: Policy, model, _,
|
|||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net, _ = compute_q_values(
|
||||
policy,
|
||||
policy.q_model,
|
||||
model,
|
||||
{"obs": train_batch[SampleBatch.NEXT_OBS]},
|
||||
explore=False,
|
||||
is_training=True)
|
||||
|
|
|
@ -78,7 +78,7 @@ def r2d2_loss(policy: Policy, model, _,
|
|||
# Q-network evaluation (at t).
|
||||
q, _, _, _ = compute_q_values(
|
||||
policy,
|
||||
policy.q_model,
|
||||
model,
|
||||
train_batch,
|
||||
state_batches=state_batches,
|
||||
seq_lens=train_batch.get("seq_lens"),
|
||||
|
|
|
@ -86,7 +86,7 @@ def r2d2_loss(policy: Policy, model, _,
|
|||
# Q-network evaluation (at t).
|
||||
q, _, _, _ = compute_q_values(
|
||||
policy,
|
||||
policy.q_model,
|
||||
model,
|
||||
train_batch,
|
||||
state_batches=state_batches,
|
||||
seq_lens=train_batch.get("seq_lens"),
|
||||
|
|
|
@ -22,7 +22,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
- from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
+ from ray.rllib.execution.train_ops import TrainTFMultiGPU, TrainOneStep, \
+     UpdateTargetNetwork
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator

@ -139,9 +140,21 @@ def execution_plan(workers: WorkerSet,
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

+   if config["simple_optimizer"]:
+       train_step_op = TrainOneStep(workers)
+   else:
+       train_step_op = TrainTFMultiGPU(
+           workers=workers,
+           sgd_minibatch_size=config["train_batch_size"],
+           num_sgd_iter=1,
+           num_gpus=config["num_gpus"],
+           shuffle_sequences=True,
+           _fake_gpus=config["_fake_gpus"],
+           framework=config.get("framework"))
+
    # (2) Read and train on experiences from the replay buffer.
    replay_op = Replay(local_buffer=local_replay_buffer) \
-       .for_each(TrainOneStep(workers)) \
+       .for_each(train_step_op) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))
|
|
@ -79,7 +79,7 @@ def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
|
|||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for DQN.".format(action_space))
|
||||
|
||||
policy.q_model = ModelCatalog.get_model_v2(
|
||||
model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=action_space.n,
|
||||
|
@ -95,10 +95,10 @@ def build_q_models(policy: Policy, obs_space: gym.spaces.Space,
|
|||
framework=config["framework"],
|
||||
name=Q_TARGET_SCOPE)
|
||||
|
||||
policy.q_func_vars = policy.q_model.variables()
|
||||
policy.q_func_vars = model.variables()
|
||||
policy.target_q_func_vars = policy.target_q_model.variables()
|
||||
|
||||
return policy.q_model
|
||||
return model
|
||||
|
||||
|
||||
def get_distribution_inputs_and_class(
|
||||
|
@ -114,6 +114,7 @@ def get_distribution_inputs_and_class(
|
|||
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
|
||||
|
||||
policy.q_values = q_vals
|
||||
policy.q_func_vars = q_model.variables()
|
||||
return policy.q_values, (TorchCategorical
|
||||
if policy.config["framework"] == "torch" else
|
||||
Categorical), [] # state-outs
|
||||
|
@ -135,10 +136,7 @@ def build_q_losses(policy: Policy, model: ModelV2,
|
|||
"""
|
||||
# q network evaluation
|
||||
q_t = compute_q_values(
|
||||
policy,
|
||||
policy.q_model,
|
||||
train_batch[SampleBatch.CUR_OBS],
|
||||
explore=False)
|
||||
policy, policy.model, train_batch[SampleBatch.CUR_OBS], explore=False)
|
||||
|
||||
# target q network evalution
|
||||
q_tp1 = compute_q_values(
|
||||
|
|
|
@ -38,7 +38,7 @@ class TargetNetworkMixin:
|
|||
# target Q network.
|
||||
assert len(self.q_func_vars) == len(self.target_q_func_vars), \
|
||||
(self.q_func_vars, self.target_q_func_vars)
|
||||
self.target_q_model.load_state_dict(self.q_model.state_dict())
|
||||
self.target_q_model.load_state_dict(self.model.state_dict())
|
||||
|
||||
self.update_target = do_update
|
||||
|
||||
|
@ -67,7 +67,7 @@ def build_q_losses(policy: Policy, model, dist_class,
|
|||
# q network evaluation
|
||||
q_t = compute_q_values(
|
||||
policy,
|
||||
policy.q_model,
|
||||
policy.model,
|
||||
train_batch[SampleBatch.CUR_OBS],
|
||||
explore=False,
|
||||
is_training=True)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
@ -50,6 +51,34 @@ class TestDQN(unittest.TestCase):
|
|||
check_compute_single_action(trainer)
|
||||
trainer.stop()
|
||||
|
||||
def test_dqn_fake_multi_gpu_learning(self):
|
||||
"""Test whether DQNTrainer can learn CartPole w/ faked multi-GPU."""
|
||||
config = copy.deepcopy(dqn.DEFAULT_CONFIG)
|
||||
|
||||
# Fake GPU setup.
|
||||
config["num_gpus"] = 2
|
||||
config["_fake_gpus"] = True
|
||||
|
||||
config["framework"] = "tf"
|
||||
# Double batch size (2 GPUs).
|
||||
config["train_batch_size"] = 64
|
||||
# Mimic tuned_example for DQN CartPole.
|
||||
config["n_step"] = 3
|
||||
config["model"]["fcnet_hiddens"] = [64]
|
||||
config["model"]["fcnet_activation"] = "linear"
|
||||
|
||||
trainer = dqn.DQNTrainer(config=config, env="CartPole-v0")
|
||||
num_iterations = 200
|
||||
learnt = False
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print("reward={}".format(results["episode_reward_mean"]))
|
||||
if results["episode_reward_mean"] > 100.0:
|
||||
learnt = True
|
||||
break
|
||||
assert learnt, "DQN multi-GPU (with fake-GPUs) did not learn CartPole!"
|
||||
trainer.stop()
|
||||
|
||||
def test_dqn_exploration_and_soft_q_config(self):
|
||||
"""Tests, whether a DQN Agent outputs exploration/softmaxed actions."""
|
||||
config = dqn.DEFAULT_CONFIG.copy()
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
@ -44,6 +45,28 @@ class TestSimpleQ(unittest.TestCase):
|
|||
|
||||
check_compute_single_action(trainer)
|
||||
|
||||
def test_simple_q_fake_multi_gpu_learning(self):
|
||||
"""Test whether SimpleQTrainer learns CartPole w/ fake GPUs."""
|
||||
config = copy.deepcopy(dqn.SIMPLE_Q_DEFAULT_CONFIG)
|
||||
|
||||
# Fake GPU setup.
|
||||
config["num_gpus"] = 2
|
||||
config["_fake_gpus"] = True
|
||||
config["framework"] = "tf"
|
||||
|
||||
trainer = dqn.SimpleQTrainer(config=config, env="CartPole-v0")
|
||||
num_iterations = 200
|
||||
learnt = False
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print("reward={}".format(results["episode_reward_mean"]))
|
||||
if results["episode_reward_mean"] > 75.0:
|
||||
learnt = True
|
||||
break
|
||||
assert learnt, "SimpleQ multi-GPU (with fake-GPUs) did not " \
|
||||
"learn CartPole!"
|
||||
trainer.stop()
|
||||
|
||||
def test_simple_q_loss_function(self):
|
||||
"""Tests the Simple-Q loss function results on all frameworks."""
|
||||
config = dqn.SIMPLE_Q_DEFAULT_CONFIG.copy()
|
||||
|
|
|
@ -247,6 +247,8 @@ def get_policy_class(config):
|
|||
|
||||
def validate_config(config):
|
||||
config["action_repeat"] = config["env_config"]["frame_skip"]
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for Dreamer!")
|
||||
if config["framework"] != "torch":
|
||||
raise ValueError("Dreamer not supported in Tensorflow yet!")
|
||||
if config["batch_mode"] != "complete_episodes":
|
||||
|
|
|
@ -178,6 +178,8 @@ def get_policy_class(config):
|
|||
|
||||
|
||||
def validate_config(config):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for ES/ARS!")
|
||||
if config["num_workers"] <= 0:
|
||||
raise ValueError("`num_workers` must be > 0 for ES!")
|
||||
if config["evaluation_config"]["num_envs_per_worker"] != 1:
|
||||
|
|
|
@ -219,6 +219,8 @@ def get_policy_class(config):
|
|||
|
||||
|
||||
def validate_config(config):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for MAML!")
|
||||
if config["inner_adaptation_steps"] <= 0:
|
||||
raise ValueError("Inner Adaptation Steps must be >=1!")
|
||||
if config["maml_optimizer_steps"] <= 0:
|
||||
|
|
|
@ -70,9 +70,15 @@ def execution_plan(workers, config):
|
|||
return StandardMetricsReporting(train_op, workers, config)
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")
|
||||
|
||||
|
||||
MARWILTrainer = build_trainer(
|
||||
name="MARWIL",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=MARWILTFPolicy,
|
||||
get_policy_class=get_policy_class,
|
||||
validate_config=validate_config,
|
||||
execution_plan=execution_plan)
|
||||
|
|
|
@ -416,6 +416,8 @@ def validate_config(config):
|
|||
Raises:
|
||||
ValueError: In case something is wrong with the config.
|
||||
"""
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for MB-MPO!")
|
||||
if config["framework"] != "torch":
|
||||
logger.warning("MB-MPO only supported in PyTorch so far! Switching to "
|
||||
"`framework=torch`.")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
@ -24,13 +25,42 @@ class TestPG(unittest.TestCase):
|
|||
config["num_workers"] = 0
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
for fw in framework_iterator(config):
|
||||
# For tf, build with fake-GPUs.
|
||||
config["_fake_gpus"] = fw == "tf"
|
||||
config["num_gpus"] = 2 if fw == "tf" else 0
|
||||
trainer = pg.PGTrainer(config=config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
check_compute_single_action(
|
||||
trainer, include_prev_action_reward=True)
|
||||
|
||||
def test_pg_fake_multi_gpu_learning(self):
|
||||
"""Test whether PGTrainer can learn CartPole w/ faked multi-GPU."""
|
||||
config = copy.deepcopy(pg.DEFAULT_CONFIG)
|
||||
|
||||
# Fake GPU setup.
|
||||
config["num_gpus"] = 2
|
||||
config["_fake_gpus"] = True
|
||||
|
||||
config["framework"] = "tf"
|
||||
# Mimic tuned_example for PG CartPole.
|
||||
config["model"]["fcnet_hiddens"] = [64]
|
||||
config["model"]["fcnet_activation"] = "linear"
|
||||
|
||||
trainer = pg.PGTrainer(config=config, env="CartPole-v0")
|
||||
num_iterations = 200
|
||||
learnt = False
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print("reward={}".format(results["episode_reward_mean"]))
|
||||
# Make this test quite short (75.0).
|
||||
if results["episode_reward_mean"] > 75.0:
|
||||
learnt = True
|
||||
break
|
||||
assert learnt, "PG multi-GPU (with fake-GPUs) did not learn CartPole!"
|
||||
trainer.stop()
|
||||
|
||||
def test_pg_loss_functions(self):
|
||||
"""Tests the PG loss function math."""
|
||||
config = pg.DEFAULT_CONFIG.copy()
|
||||
|
|
|
@ -20,7 +20,7 @@ from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches, \
    StandardizeFields, SelectExperiences
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
from ray.rllib.execution.metric_ops import StandardMetricsReporting
- from ray.rllib.policy.policy import Policy
+ from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict

@ -86,13 +86,6 @@ DEFAULT_CONFIG = with_common_config({
    "batch_mode": "truncate_episodes",
    # Which observation filter to apply to the observation.
    "observation_filter": "NoFilter",
-   # Uses the sync samples optimizer instead of the multi-gpu one. This is
-   # usually slower, but you might want to try it if you run into issues with
-   # the default optimizer.
-   "simple_optimizer": False,
-   # Whether to fake GPUs (using CPUs).
-   # Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
-   "_fake_gpus": False,

    # Deprecated keys:
    # Share layers for value function. If you set this to True, it's important

@ -139,16 +132,8 @@ def validate_config(config: TrainerConfigDict) -> None:
        "function (to estimate the return at the end of the truncated "
        "trajectory). Consider setting batch_mode=complete_episodes.")

-   # Multi-gpu not supported for PyTorch and tf-eager.
-   if config["framework"] in ["tf2", "tfe", "torch"]:
-       config["simple_optimizer"] = True
-   # Performance warning, if "simple" optimizer used with (static-graph) tf.
-   elif config["simple_optimizer"]:
-       logger.warning(
-           "Using the simple minibatch optimizer. This will significantly "
-           "reduce performance, consider simple_optimizer=False.")
    # Multi-agent mode and multi-GPU optimizer.
-   elif config["multiagent"]["policies"] and not config["simple_optimizer"]:
+   if config["multiagent"]["policies"] and not config["simple_optimizer"]:
        logger.info(
            "In multi-agent mode, policies will be optimized sequentially "
            "by the multi-GPU optimizer. Consider setting "

@ -184,12 +169,14 @@ class UpdateKL:
    def __call__(self, fetches):
        def update(pi, pi_id):
-           assert "kl" not in fetches, (
-               "kl should be nested under policy id key", fetches)
+           assert LEARNER_STATS_KEY not in fetches, \
+               ("{} should be nested under policy id key".format(
+                   LEARNER_STATS_KEY), fetches)
            if pi_id in fetches:
-               assert "kl" in fetches[pi_id], (fetches, pi_id)
+               kl = fetches[pi_id][LEARNER_STATS_KEY].get("kl")
+               assert kl is not None, (fetches, pi_id)
+               # Make the actual `Policy.update_kl()` call.
-               pi.update_kl(fetches[pi_id]["kl"])
+               pi.update_kl(kl)
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

@ -205,9 +192,11 @@ def warn_about_bad_reward_scales(config, result):
    # Warn about excessively high VF loss.
    learner_stats = result["info"]["learner"]
    if DEFAULT_POLICY_ID in learner_stats:
-       scaled_vf_loss = (config["vf_loss_coeff"] *
-                         learner_stats[DEFAULT_POLICY_ID]["vf_loss"])
-       policy_loss = learner_stats[DEFAULT_POLICY_ID]["policy_loss"]
+       scaled_vf_loss = config["vf_loss_coeff"] * \
+           learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["vf_loss"]
+
+       policy_loss = learner_stats[DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
+           "policy_loss"]
        if config.get("model", {}).get("vf_share_layers") and \
                scaled_vf_loss > 100:
            logger.warning(

@ -272,13 +261,10 @@ def execution_plan(workers: WorkerSet,
    else:
        train_op = rollouts.for_each(
            TrainTFMultiGPU(
-               workers,
+               workers=workers,
                sgd_minibatch_size=config["sgd_minibatch_size"],
                num_sgd_iter=config["num_sgd_iter"],
                num_gpus=config["num_gpus"],
-               rollout_fragment_length=config["rollout_fragment_length"],
-               num_envs_per_worker=config["num_envs_per_worker"],
-               train_batch_size=config["train_batch_size"],
                shuffle_sequences=config["shuffle_sequences"],
                _fake_gpus=config["_fake_gpus"],
                framework=config.get("framework")))
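After this change, per-policy learner stats are consistently nested one level deeper under LEARNER_STATS_KEY ("learner_stats"), which is why UpdateKL and warn_about_bad_reward_scales above now index into that sub-dict. A hedged illustration of reading the new layout from a train() result (the helper name is a placeholder; "kl"/"vf_loss" are PPO stats and may not exist for other algorithms):

    # Sketch: reading per-policy learner stats from a Trainer.train() result.
    from ray.rllib.policy.policy import LEARNER_STATS_KEY
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

    def get_learner_stat(result, key, policy_id=DEFAULT_POLICY_ID):
        """Return one learner stat, tolerating the old (flat) layout too."""
        policy_info = result["info"]["learner"][policy_id]
        # New layout: {"learner_stats": {"kl": ..., "vf_loss": ..., ...}}
        stats = policy_info.get(LEARNER_STATS_KEY, policy_info)
        return stats[key]

    # e.g. kl = get_learner_stat(result, "kl")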
|
|
@ -2,6 +2,7 @@ import unittest
|
|||
|
||||
import ray
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.test_utils import check_compute_single_action, \
|
||||
framework_iterator
|
||||
|
@ -40,7 +41,8 @@ class TestDDPPO(unittest.TestCase):
|
|||
trainer = ppo.ddppo.DDPPOTrainer(config=config, env="CartPole-v0")
|
||||
for _ in range(num_iterations):
|
||||
result = trainer.train()
|
||||
lr = result["info"]["learner"][DEFAULT_POLICY_ID]["cur_lr"]
|
||||
lr = result["info"]["learner"][DEFAULT_POLICY_ID][
|
||||
LEARNER_STATS_KEY]["cur_lr"]
|
||||
trainer.stop()
|
||||
assert lr == 0.0, "lr should anneal to 0.0"
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class TestPPO(unittest.TestCase):
|
|||
for _ in framework_iterator(config):
|
||||
for env in ["CartPole-v0", "MsPacmanNoFrameskip-v4"]:
|
||||
print("Env={}".format(env))
|
||||
for lstm in [True, False]:
|
||||
for lstm in [False, True]:
|
||||
print("LSTM={}".format(lstm))
|
||||
config["model"]["use_lstm"] = lstm
|
||||
config["model"]["lstm_use_prev_action"] = lstm
|
||||
|
@ -111,7 +111,7 @@ class TestPPO(unittest.TestCase):
|
|||
config["num_gpus"] = 2
|
||||
config["_fake_gpus"] = True
|
||||
config["framework"] = "tf"
|
||||
# Mimick tuned_example for PPO CartPole.
|
||||
# Mimic tuned_example for PPO CartPole.
|
||||
config["num_workers"] = 1
|
||||
config["lr"] = 0.0003
|
||||
config["observation_filter"] = "MeanStdFilter"
|
||||
|
@ -127,7 +127,7 @@ class TestPPO(unittest.TestCase):
|
|||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
if results["episode_reward_mean"] > 150:
|
||||
if results["episode_reward_mean"] > 75.0:
|
||||
learnt = True
|
||||
break
|
||||
assert learnt, "PPO multi-GPU (with fake-GPUs) did not learn CartPole!"
|
||||
|
|
|
@ -167,6 +167,9 @@ def validate_config(config: TrainerConfigDict) -> None:
|
|||
Raises:
|
||||
ValueError: In case something is wrong with the config.
|
||||
"""
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for SAC!")
|
||||
|
||||
if config["use_state_preprocessor"] != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="config['use_state_preprocessor']", error=False)
|
||||
|
|
|
@ -136,6 +136,9 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
def validate_config(config: TrainerConfigDict) -> None:
|
||||
"""Checks the config based on settings"""
|
||||
if config["num_gpus"] > 1:
|
||||
raise ValueError("`num_gpus` > 1 not yet supported for SlateQ!")
|
||||
|
||||
if config["framework"] != "torch":
|
||||
raise ValueError("SlateQ only runs on PyTorch")
|
||||
|
||||
|
|
|
@ -303,9 +303,14 @@ COMMON_CONFIG: TrainerConfigDict = {
    # === Resource Settings ===
    # Number of GPUs to allocate to the trainer process. Note that not all
-   # algorithms can take advantage of trainer GPUs. This can be fractional
-   # (e.g., 0.3 GPUs).
+   # algorithms can take advantage of trainer GPUs. Support for multi-GPU
+   # is currently only available for tf-[PPO/IMPALA/DQN/PG].
+   # This can be fractional (e.g., 0.3 GPUs).
    "num_gpus": 0,
+   # Set to True for debugging (multi-)?GPU functionality on a CPU machine.
+   # GPU towers will be simulated by graphs located on CPUs in this case.
+   # Use `num_gpus` to test for different numbers of fake GPUs.
+   "_fake_gpus": False,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional. This is

@ -404,6 +409,13 @@ COMMON_CONFIG: TrainerConfigDict = {
    # Define logger-specific configuration to be used inside Logger
    # Default value None allows overwriting with nested dicts
    "logger_config": None,
+
+   # Deprecated values.
+   # Uses the sync samples optimizer instead of the multi-gpu one. This is
+   # usually slower, but you might want to try it if you run into issues with
+   # the default optimizer.
+   # This will be set automatically from now on.
+   "simple_optimizer": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
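The new `_fake_gpus` flag exists purely for debugging the multi-GPU code path on machines without GPUs: each "GPU" tower is just a graph copy placed on CPU. A hedged sketch of using it directly with a Trainer (values are illustrative only):

    # Sketch: exercising the tf multi-GPU code path on a CPU-only machine.
    import ray
    from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG

    if __name__ == "__main__":
        ray.init(local_mode=True)    # optional, easier debugging
        config = DEFAULT_CONFIG.copy()
        config["framework"] = "tf"
        config["num_gpus"] = 2       # two *fake* GPU towers
        config["_fake_gpus"] = True  # place the towers on CPU
        trainer = DQNTrainer(config=config, env="CartPole-v0")
        print(trainer.train()["episode_reward_mean"])
        trainer.stop()
        ray.shutdown()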
@ -1120,6 +1132,25 @@ class Trainer(Trainable):
        if model_config is None:
            config["model"] = model_config = {}

+       # Multi-GPU settings.
+       simple_optim_setting = config.get("simple_optimizer", DEPRECATED_VALUE)
+       if simple_optim_setting != DEPRECATED_VALUE:
+           deprecation_warning("simple_optimizer", error=False)
+
+       if config.get("num_gpus", 0) > 1:
+           if config.get("framework") in ["tfe", "tf2", "torch"]:
+               raise ValueError("`num_gpus` > 1 not supported yet for "
+                                "framework={}!".format(
+                                    config.get("framework")))
+           elif simple_optim_setting is True:
+               raise ValueError(
+                   "Cannot use `simple_optimizer` if `num_gpus` > 1! "
+                   "Consider `simple_optimizer=False`.")
+           config["simple_optimizer"] = False
+       elif simple_optim_setting == DEPRECATED_VALUE:
+           config["simple_optimizer"] = True
+
+       # Trajectory View API settings.
        if not config.get("_use_trajectory_view_api"):
            traj_view_framestacks = model_config.get("num_framestacks", "auto")
            if model_config.get("_time_major"):

@ -1130,6 +1161,7 @@ class Trainer(Trainable):
                "iff `_use_trajectory_view_api` is True!")
            model_config["num_framestacks"] = 0

+       # Offline RL settings.
        if isinstance(config["input_evaluation"], tuple):
            config["input_evaluation"] = list(config["input_evaluation"])
        elif not isinstance(config["input_evaluation"], list):

@ -1156,6 +1188,7 @@ class Trainer(Trainable):
                "complete_episodes]! Got {}".format(
                    config["batch_mode"]))

+       # Check multi-agent batch count mode.
        if config["multiagent"].get("count_steps_by", "env_steps") not in \
                ["env_steps", "agent_steps"]:
            raise ValueError(
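The deprecation logic above resolves `simple_optimizer` automatically: explicit user settings only trigger a warning (or an error when they conflict with `num_gpus` > 1), and otherwise the multi-GPU optimizer is selected exactly when more than one GPU is requested on static-graph tf. A condensed, hedged restatement of that decision table (standalone sketch, no RLlib imports; the real code mutates `config` in place and uses DEPRECATED_VALUE):

    # Sketch of the simple_optimizer resolution done in Trainer._validate_config().
    UNSET = object()  # stands in for DEPRECATED_VALUE

    def resolve_simple_optimizer(num_gpus, framework, user_setting=UNSET):
        if num_gpus > 1:
            if framework in ["tfe", "tf2", "torch"]:
                raise ValueError(
                    "`num_gpus` > 1 not supported yet for framework={}!".format(
                        framework))
            if user_setting is True:
                raise ValueError("Cannot use `simple_optimizer` if `num_gpus` > 1!")
            return False              # force the tf multi-GPU optimizer
        if user_setting is UNSET:
            return True               # single GPU/CPU: keep the simple path
        return bool(user_setting)     # respect an explicit user choice

    assert resolve_simple_optimizer(2, "tf") is False
    assert resolve_simple_optimizer(0, "torch") is True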
|
|
@ -5,7 +5,7 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
- from ray.rllib.execution.train_ops import TrainOneStep
+ from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins

@ -26,7 +26,21 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
-       )).for_each(TrainOneStep(workers))
+       ))
+
+   if config.get("simple_optimizer") is True:
+       train_op = train_op.for_each(TrainOneStep(workers))
+   else:
+       train_op = train_op.for_each(
+           TrainTFMultiGPU(
+               workers=workers,
+               sgd_minibatch_size=config.get("sgd_minibatch_size",
+                                             config["train_batch_size"]),
+               num_sgd_iter=config.get("num_sgd_iter", 1),
+               num_gpus=config["num_gpus"],
+               shuffle_sequences=config.get("shuffle_sequences", False),
+               _fake_gpus=config["_fake_gpus"],
+               framework=config["framework"]))

    # Add on the standard episode reward, etc. metrics reporting. This returns
    # a LocalIterator[metrics_dict] representing metrics for each train step.
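Because the branch now lives in `default_execution_plan`, a trainer built with `build_trainer()` that does not supply its own execution plan inherits tf multi-GPU support automatically. A hedged sketch (the trainer name is a placeholder; PG's policy and config are reused only because they are simple, not because this commit suggests it):

    # Sketch: a custom trainer relying on the default execution plan, which now
    # auto-selects TrainOneStep vs. TrainTFMultiGPU from the resolved config.
    from ray.rllib.agents.trainer_template import build_trainer
    from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
    from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

    MyTrainer = build_trainer(
        name="MyPGLikeTrainer",
        default_config=DEFAULT_CONFIG,
        default_policy=PGTFPolicy,
    )  # no execution_plan given -> default_execution_plan is used

    # MyTrainer(config={"framework": "tf", "num_gpus": 2}, env="CartPole-v0")
    # would then train with two (possibly fake) GPU towers.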
@ -63,13 +63,15 @@ if __name__ == "__main__":
    # Checkpoint with the lowest policy loss value:
    ckpt = results.get_best_checkpoint(
        best_trial,
-       metric="info/learner/default_policy/policy_loss",
+       metric="info/learner/default_policy/learner_stats/policy_loss",
        mode="min")
    print("Lowest pol-loss: {}".format(ckpt))

    # Checkpoint with the highest value-function loss:
    ckpt = results.get_best_checkpoint(
-       best_trial, metric="info/learner/default_policy/vf_loss", mode="max")
+       best_trial,
+       metric="info/learner/default_policy/learner_stats/vf_loss",
+       mode="max")
    print("Highest vf-loss: {}".format(ckpt))

    ray.shutdown()
|
|
@ -11,6 +11,7 @@ from ray.rllib.models import ModelCatalog
|
|||
from ray.rllib.models.tf.misc import normc_initializer
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.visionnet import VisionNetwork as MyVisionNetwork
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
|
@ -84,7 +85,7 @@ class MyKerasQModel(DistributionalQTFModel):
|
|||
kernel_initializer=normc_initializer(1.0))(layer_1)
|
||||
self.base_model = tf.keras.Model(self.inputs, layer_out)
|
||||
|
||||
# Implement the core forward method
|
||||
# Implement the core forward method.
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
model_out = self.base_model(input_dict["obs"])
|
||||
return model_out, state
|
||||
|
@ -95,7 +96,7 @@ class MyKerasQModel(DistributionalQTFModel):
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.init(num_cpus=args.num_cpus or None)
|
||||
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
|
||||
ModelCatalog.register_custom_model(
|
||||
"keras_model", MyVisionNetwork
|
||||
if args.use_vision_network else MyKerasModel)
|
||||
|
@ -107,7 +108,8 @@ if __name__ == "__main__":
|
|||
def check_has_custom_metric(result):
|
||||
r = result["result"]["info"]["learner"]
|
||||
if DEFAULT_POLICY_ID in r:
|
||||
r = r[DEFAULT_POLICY_ID]
|
||||
r = r[DEFAULT_POLICY_ID].get(LEARNER_STATS_KEY,
|
||||
r[DEFAULT_POLICY_ID])
|
||||
assert r["model"]["foo"] == 42, result
|
||||
|
||||
if args.run == "DQN":
|
||||
|
|
|
@ -33,7 +33,6 @@ parser.add_argument("--num-policies", type=int, default=2)
|
|||
parser.add_argument("--stop-iters", type=int, default=200)
|
||||
parser.add_argument("--stop-reward", type=float, default=150)
|
||||
parser.add_argument("--stop-timesteps", type=int, default=100000)
|
||||
parser.add_argument("--simple", action="store_true")
|
||||
parser.add_argument("--num-cpus", type=int, default=0)
|
||||
parser.add_argument("--as-test", action="store_true")
|
||||
parser.add_argument(
|
||||
|
@ -82,7 +81,6 @@ if __name__ == "__main__":
|
|||
"env_config": {
|
||||
"num_agents": args.num_agents,
|
||||
},
|
||||
"simple_optimizer": args.simple,
|
||||
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
||||
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
||||
"num_sgd_iter": 10,
|
||||
|
|
|
@ -73,7 +73,6 @@ if __name__ == "__main__":
|
|||
config["num_workers"] = args.num_workers
|
||||
config["rollout_fragment_length"] = 200
|
||||
config["sgd_minibatch_size"] = 256
|
||||
config["simple_optimizer"] = True
|
||||
config["train_batch_size"] = 4000
|
||||
|
||||
config["batch_mode"] = "complete_episodes"
|
||||
|
|
|
@ -251,8 +251,9 @@ class LocalSyncParallelOptimizer:
        feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())

        fetches = {"train": self._train_op}
-       for tower in self._towers:
-           fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
+       for tower_num, tower in enumerate(self._towers):
+           tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
+           fetches["tower_{}".format(tower_num)] = tower_fetch

        return sess.run(fetches, feed_dict=feed_dict)
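Instead of merging all towers' stats into one flat dict (which silently kept only the last tower's values), `optimize()` now returns one entry per tower. Roughly, the session-run result looks like the illustration below; the stat names and numbers are made up:

    # Illustration only: approximate shape of the dict returned by
    # LocalSyncParallelOptimizer.optimize() after this change.
    example_fetches = {
        "train": None,  # result of running the train op itself
        "tower_0": {"learner_stats": {"policy_loss": 0.12, "vf_loss": 0.8}},
        "tower_1": {"learner_stats": {"policy_loss": 0.10, "vf_loss": 0.9}},
    }
    # TrainTFMultiGPU then reduces the "tower_*" entries into a single
    # per-policy stats dict (see _all_tower_reduce further down).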
@ -1,7 +1,7 @@
from collections import defaultdict
import logging
import numpy as np
import math
import tree
from typing import List, Tuple, Any

import ray

@ -18,7 +18,7 @@ from ray.rllib.execution.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.sgd import do_minibatch_sgd, averaged
+ from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients

tf1, tf, tfv = try_import_tf()

@ -106,13 +106,11 @@ class TrainTFMultiGPU:
    """

    def __init__(self,
+                *,
                 workers: WorkerSet,
                 sgd_minibatch_size: int,
                 num_sgd_iter: int,
                 num_gpus: int,
-                rollout_fragment_length: int,
-                num_envs_per_worker: int,
-                train_batch_size: int,
                 shuffle_sequences: bool,
                 policies: List[PolicyID] = frozenset([]),
                 _fake_gpus: bool = False,

@ -124,7 +122,7 @@ class TrainTFMultiGPU:
        self.shuffle_sequences = shuffle_sequences
        self.framework = framework

-       # Collect actual devices to use.
+       # Collect actual GPU devices to use.
        if not num_gpus:
            _fake_gpus = True
            num_gpus = 1

@ -133,10 +131,13 @@ class TrainTFMultiGPU:
            "/{}:{}".format(type_, i) for i in range(int(math.ceil(num_gpus)))
        ]

+       # Total batch size (all towers). Make sure it is dividable by
+       # num towers.
        self.batch_size = int(sgd_minibatch_size / len(self.devices)) * len(
            self.devices)
        assert self.batch_size % len(self.devices) == 0
        assert self.batch_size >= len(self.devices), "batch size too small"
+       # Batch size per tower.
        self.per_device_batch_size = int(self.batch_size / len(self.devices))

        # per-GPU graph copies created below must share vars with the policy

@ -177,8 +178,8 @@ class TrainTFMultiGPU:
        metrics = _get_shared_metrics()
        load_timer = metrics.timers[LOAD_BATCH_TIMER]
        learn_timer = metrics.timers[LEARN_ON_BATCH_TIMER]
-       # Load data into GPUs.
        with load_timer:
+           # (1) Load data into GPUs.
            num_loaded_tuples = {}
            for policy_id, batch in samples.policy_batches.items():
                # Not a policy-to-train.

@ -202,8 +203,8 @@ class TrainTFMultiGPU:
                    self.sess, [tuples[k] for k in data_keys],
                    [tuples[k] for k in state_keys]))

-       # Execute minibatch SGD on loaded data.
        with learn_timer:
+           # (2) Execute minibatch SGD on loaded data.
            fetches = {}
            for policy_id, tuples_per_device in num_loaded_tuples.items():
                optimizer = self.optimizers[policy_id]

@ -211,19 +212,25 @@ class TrainTFMultiGPU:
                    1,
                    int(tuples_per_device) // int(self.per_device_batch_size))
                logger.debug("== sgd epochs for {} ==".format(policy_id))
-               for i in range(self.num_sgd_iter):
-                   iter_extra_fetches = defaultdict(list)
+               for _ in range(self.num_sgd_iter):
                    permutation = np.random.permutation(num_batches)
+                   batch_fetches_all_towers = []
                    for batch_index in range(num_batches):
                        batch_fetches = optimizer.optimize(
                            self.sess, permutation[batch_index] *
                            self.per_device_batch_size)
-                       for k, v in batch_fetches[LEARNER_STATS_KEY].items():
-                           iter_extra_fetches[k].append(v)
-                   if logger.getEffectiveLevel() <= logging.DEBUG:
-                       avg = averaged(iter_extra_fetches)
-                       logger.debug("{} {}".format(i, avg))
-               fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
+
+                       batch_fetches_all_towers.append(
+                           tree.map_structure_with_path(
+                               lambda p, *s: self._all_tower_reduce(p, *s),
+                               *(batch_fetches["tower_{}".format(tower_num)]
+                                 for tower_num in range(len(self.devices)))))
+
+               # Reduce mean across all minibatch SGD steps (axis=0 to keep
+               # all shapes as-is).
+               fetches[policy_id] = tree.map_structure(
+                   lambda *s: np.nanmean(s, axis=0),
+                   *batch_fetches_all_towers)

        load_timer.push_units_processed(samples.count)
        learn_timer.push_units_processed(samples.count)

@ -240,6 +247,16 @@ class TrainTFMultiGPU:
        self.workers.local_worker().set_global_vars(_get_global_vars())
        return samples, fetches

+   def _all_tower_reduce(self, path, *tower_data):
+       """Reduces stats across towers based on their stats-dict paths."""
+       if len(path) == 1 and path[0] == "td_error":
+           return np.concatenate(tower_data, axis=0)
+       elif path[-1].startswith("min_"):
+           return np.nanmin(tower_data)
+       elif path[-1].startswith("max_"):
+           return np.nanmax(tower_data)
+       return np.nanmean(tower_data)
+

class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.
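`_all_tower_reduce` is applied leaf-by-leaf across the per-tower fetch dicts via `tree.map_structure_with_path`, so `td_error` arrays are concatenated while scalar stats are min-, max- or mean-reduced according to their names. A small self-contained illustration with made-up tower stats:

    # Illustration of the tower-reduction rule used above (values are made up).
    import numpy as np
    import tree  # dm-tree

    def all_tower_reduce(path, *tower_data):
        if len(path) == 1 and path[0] == "td_error":
            return np.concatenate(tower_data, axis=0)  # keep per-sample errors
        if path[-1].startswith("min_"):
            return np.nanmin(tower_data)
        if path[-1].startswith("max_"):
            return np.nanmax(tower_data)
        return np.nanmean(tower_data)                  # default: average towers

    tower_0 = {"td_error": np.array([0.1, 0.2]), "stats": {"max_q": 3.0, "loss": 0.5}}
    tower_1 = {"td_error": np.array([0.3]), "stats": {"max_q": 4.0, "loss": 0.7}}

    reduced = tree.map_structure_with_path(all_tower_reduce, tower_0, tower_1)
    # reduced["td_error"] -> array([0.1, 0.2, 0.3])
    # reduced["stats"]    -> {"max_q": 4.0, "loss": 0.6}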
|
|
@ -160,7 +160,11 @@ class DynamicTFPolicy(TFPolicy):
|
|||
|
||||
# Setup self.model.
|
||||
if existing_model:
|
||||
self.model = existing_model
|
||||
if isinstance(existing_model, list):
|
||||
self.model = existing_model[0]
|
||||
# TODO: (sven) hack, but works for `target_[q_]?model`.
|
||||
for i in range(1, len(existing_model)):
|
||||
setattr(self, existing_model[i][0], existing_model[i][1])
|
||||
elif make_model:
|
||||
self.model = make_model(self, obs_space, action_space, config)
|
||||
else:
|
||||
|
@ -389,7 +393,11 @@ class DynamicTFPolicy(TFPolicy):
|
|||
self.action_space,
|
||||
self.config,
|
||||
existing_inputs=input_dict,
|
||||
existing_model=self.model)
|
||||
existing_model=[
|
||||
self.model,
|
||||
("target_q_model", getattr(self, "target_q_model", None)),
|
||||
("target_model", getattr(self, "target_model", None)),
|
||||
])
|
||||
|
||||
instance._loss_input_dict = input_dict
|
||||
loss = instance._do_loss_init(input_dict)
|
||||
|
|
|
@ -292,6 +292,12 @@ def build_tf_policy(
|
|||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
if extra_learn_fetches_fn:
|
||||
# TODO: (sven) in torch, extra_learn_fetches do not exist.
|
||||
# Hence, things like td_error are returned by the stats_fn
|
||||
# and end up under the LEARNER_STATS_KEY. We should
|
||||
# change tf to do this as well. However, this will confilct
|
||||
# the handling of LEARNER_STATS_KEY inside the multi-GPU
|
||||
# train op.
|
||||
# Auto-add empty learner stats dict if needed.
|
||||
return dict({
|
||||
LEARNER_STATS_KEY: {}
|
||||
|
|
|
@ -120,7 +120,6 @@ class TestRNNSequencing(unittest.TestCase):
|
|||
"rollout_fragment_length": 10,
|
||||
"train_batch_size": 10,
|
||||
"sgd_minibatch_size": 10,
|
||||
"simple_optimizer": True,
|
||||
"num_sgd_iter": 1,
|
||||
"model": {
|
||||
"custom_model": "rnn",
|
||||
|
@ -178,7 +177,6 @@ class TestRNNSequencing(unittest.TestCase):
|
|||
"rollout_fragment_length": 20,
|
||||
"train_batch_size": 20,
|
||||
"sgd_minibatch_size": 10,
|
||||
"simple_optimizer": False,
|
||||
"num_sgd_iter": 1,
|
||||
"model": {
|
||||
"custom_model": "rnn",
|
||||
|
@ -198,15 +196,17 @@ class TestRNNSequencing(unittest.TestCase):
|
|||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
|
||||
if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]:
|
||||
batch0, batch1 = batch1, batch0 # sort minibatches
|
||||
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4])
|
||||
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3])
|
||||
self.assertEqual(batch0["seq_lens"].tolist(), [4, 4, 2])
|
||||
self.assertEqual(batch1["seq_lens"].tolist(), [4, 3, 3])
|
||||
self.assertEqual(batch0["sequences"].tolist(), [
|
||||
[[0], [1], [2], [3]],
|
||||
[[4], [5], [6], [7]],
|
||||
[[8], [9], [0], [0]],
|
||||
])
|
||||
self.assertEqual(batch1["sequences"].tolist(), [
|
||||
[[8], [9], [10], [11]],
|
||||
[[12], [13], [14], [0]],
|
||||
[[0], [1], [2], [0]],
|
||||
])
|
||||
|
||||
# second epoch: 20 observations get split into 2 minibatches of 8
|
||||
|
@ -217,15 +217,17 @@ class TestRNNSequencing(unittest.TestCase):
|
|||
ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3"))
|
||||
if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]:
|
||||
batch2, batch3 = batch3, batch2
|
||||
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4])
|
||||
self.assertEqual(batch3["seq_lens"].tolist(), [2, 4])
|
||||
self.assertEqual(batch2["seq_lens"].tolist(), [4, 4, 2])
|
||||
self.assertEqual(batch3["seq_lens"].tolist(), [4, 4, 2])
|
||||
self.assertEqual(batch2["sequences"].tolist(), [
|
||||
[[5], [6], [7], [8]],
|
||||
[[9], [10], [11], [12]],
|
||||
[[0], [1], [2], [3]],
|
||||
[[4], [5], [6], [7]],
|
||||
[[8], [9], [0], [0]],
|
||||
])
|
||||
self.assertEqual(batch3["sequences"].tolist(), [
|
||||
[[5], [6], [7], [8]],
|
||||
[[9], [10], [11], [12]],
|
||||
[[13], [14], [0], [0]],
|
||||
[[0], [1], [2], [3]],
|
||||
])
|
||||
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
    if isinstance(samples, SampleBatch):
        samples = MultiAgentBatch({DEFAULT_POLICY_ID: samples}, samples.count)

-   fetches = {}
+   fetches = defaultdict(dict)
    for policy_id in policies.keys():
        if policy_id not in samples.policy_batches:
            continue

@ -106,14 +106,13 @@ def do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter,
                batch[field] = standardized(batch[field])

        for i in range(num_sgd_iter):
-           iter_extra_fetches = defaultdict(list)
+           learner_stats = defaultdict(list)
            for minibatch in minibatches(batch, sgd_minibatch_size):
                batch_fetches = (local_worker.learn_on_batch(
                    MultiAgentBatch({
                        policy_id: minibatch
                    }, minibatch.count)))[policy_id]
                for k, v in batch_fetches.get(LEARNER_STATS_KEY, {}).items():
-                   iter_extra_fetches[k].append(v)
-           logger.debug("{} {}".format(i, averaged(iter_extra_fetches)))
-       fetches[policy_id] = averaged(iter_extra_fetches)
+                   learner_stats[k].append(v)
+       fetches[policy_id][LEARNER_STATS_KEY] = averaged(learner_stats)
    return fetches
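With this change, the simple and the multi-GPU code paths return learner stats in the same nested shape, `{policy_id: {"learner_stats": {...}}}`, which is exactly what the UpdateKL and metrics changes earlier in this diff rely on. A tiny illustration with made-up numbers:

    # Illustration only: the common fetches layout now produced by both
    # do_minibatch_sgd() and TrainTFMultiGPU (values are made up).
    fetches = {
        "default_policy": {
            "learner_stats": {"policy_loss": 0.12, "vf_loss": 0.8, "kl": 0.01},
        },
        "opponent_policy": {
            "learner_stats": {"policy_loss": 0.30},
        },
    }
    kl = fetches["default_policy"]["learner_stats"].get("kl")  # -> 0.01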