[RLlib] First attempt at cleaning up algo code in RLlib: PG. (#10115)
This commit is contained in:
parent
538cb802d5
commit
d14b501692
17 changed files with 396 additions and 196 deletions
|
@ -35,7 +35,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
|
|||
.. _`+LSTM auto-wrapping`: rllib-models.html#built-in-models
|
||||
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
|
||||
.. _`+RNN`: rllib-models.html#recurrent-models
|
||||
.. _`+Transformer`: rllib-models.html#attention-networks
|
||||
.. _`+Transformer`: rllib-models.html#attention-networks-transformers
|
||||
.. _`A2C, A3C`: rllib-algorithms.html#a3c
|
||||
.. _`APEX-DQN`: rllib-algorithms.html#apex
|
||||
.. _`APEX-DDPG`: rllib-algorithms.html#apex
|
||||
|
@ -304,16 +304,22 @@ SpaceInvaders 650 1001 1025
|
|||
|
||||
Policy Gradients
|
||||
----------------
|
||||
|pytorch| |tensorflow|
|
||||
`[paper] <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py>`__ We include a vanilla policy gradients implementation as an example algorithm.
|
||||
|pytorch| |tensorflow| An `implementation <https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py>`__ of a vanilla policy gradient algorithm for TensorFlow and PyTorch.
|
||||
|
||||
**Papers**:
|
||||
`[1] - Policy Gradient Methods for Reinforcement Learning with Function Approximation. <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`__
|
||||
and
|
||||
`[2] - Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning. <http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf>`__
|
||||
|
||||
|
||||
.. figure:: a2c-arch.svg
|
||||
|
||||
Policy gradients architecture (same as A2C)
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pg/cartpole-pg.yaml>`__
|
||||
**Tuned examples**: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pg/cartpole-pg.yaml>`__
|
||||
|
||||
**PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
**PG-specific configs**: The following updates overwrite or are added to the
|
||||
(base) Trainer config in `rllib/agents/trainer.py <rllib-training.html#common-parameters>`__ (*COMMON_CONFIG* dict):
|
||||
|
||||
.. literalinclude:: ../../rllib/agents/pg/pg.py
|
||||
:language: python
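
A minimal usage sketch (not one of the tuned examples above; it simply plugs in the PG defaults shown in the snippet and assumes a standard Gym environment):

.. code-block:: python

    import ray
    from ray.rllib.agents.pg import DEFAULT_CONFIG, PGTrainer

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["framework"] = "torch"  # or "tf"
    trainer = PGTrainer(config=config, env="CartPole-v0")
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])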
|
||||
|
|
|
@ -10,7 +10,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.tune.registry import register_trainable
|
||||
|
||||
|
||||
|
@ -60,6 +60,7 @@ _register_all()
|
|||
__all__ = [
|
||||
"Policy",
|
||||
"TFPolicy",
|
||||
"TorchPolicy",
|
||||
"RolloutWorker",
|
||||
"SampleBatch",
|
||||
"BaseEnv",
|
||||
|
|
8
rllib/agents/pg/README.md
Normal file
8
rllib/agents/pg/README.md
Normal file
|
@ -0,0 +1,8 @@
|
|||
Policy Gradient (PG)
|
||||
====================
|
||||
|
||||
An implementation of a vanilla policy gradient algorithm for TensorFlow and PyTorch.
|
||||
|
||||
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#pg)**
|
||||
|
||||
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py)**
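
A quick usage sketch (illustrative only; `"PG"` is the name this algorithm is registered under with Tune):

```python
from ray import tune

tune.run(
    "PG",
    stop={"episode_reward_mean": 150},
    config={"env": "CartPole-v0", "framework": "torch", "num_workers": 0},
)
```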
|
|
@ -1,29 +1,57 @@
|
|||
"""
|
||||
Policy Gradient (PG)
|
||||
====================
|
||||
|
||||
This file defines the distributed Trainer class for policy gradients.
|
||||
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
|
||||
|
||||
Detailed documentation: https://docs.ray.io/en/latest/rllib-algorithms.html#pg
|
||||
"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
|
||||
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils.typing import TrainerConfigDict
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
# Adds the following updates to the (base) `Trainer` config in
|
||||
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# No remote workers by default.
|
||||
"num_workers": 0,
|
||||
# Learning rate.
|
||||
"lr": 0.0004,
|
||||
})
|
||||
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
|
||||
"""Policy class picker function. Class is chosen based on DL-framework.
|
||||
|
||||
Args:
|
||||
config (TrainerConfigDict): The trainer's configuration dict.
|
||||
|
||||
Returns:
|
||||
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
|
||||
If None, use `default_policy` provided in build_trainer().
|
||||
"""
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
|
||||
return PGTorchPolicy
|
||||
else:
|
||||
return PGTFPolicy
|
||||
|
||||
|
||||
# Build a child class of `Trainer`, which uses the framework specific Policy
|
||||
# determined in `get_policy_class()` above.
|
||||
PGTrainer = build_trainer(
|
||||
name="PG",
|
||||
default_config=DEFAULT_CONFIG,
|
||||
default_policy=PGTFPolicy,
|
||||
get_policy_class=get_policy_class)
|
||||
get_policy_class=get_policy_class,
|
||||
)
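
The same `build_trainer()` pattern can be reused to assemble a variant with different defaults. The following is an illustrative sketch only; `MyPGTrainer` and the changed learning rate are made up and not part of this commit:

MY_PG_CONFIG = DEFAULT_CONFIG.copy()
MY_PG_CONFIG["lr"] = 0.001  # hypothetical override for illustration

MyPGTrainer = build_trainer(
    name="MyPG",
    default_config=MY_PG_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=get_policy_class,
)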
|
||||
|
|
|
@ -1,35 +1,54 @@
|
|||
"""
|
||||
TensorFlow policy class used for PG.
|
||||
"""
|
||||
|
||||
from typing import List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing, \
|
||||
compute_advantages
|
||||
from ray.rllib.agents.pg.utils import post_process_advantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.action_dist import ActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
|
||||
def post_process_advantages(policy,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
"""This adds the "advantages" column to the sample train_batch."""
|
||||
return compute_advantages(
|
||||
sample_batch,
|
||||
0.0,
|
||||
policy.config["gamma"],
|
||||
use_gae=False,
|
||||
use_critic=False)
|
||||
def pg_tf_loss(
|
||||
policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution],
|
||||
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
|
||||
"""The basic policy gradients loss function.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[ActionDistribution]): The action distribution class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
def pg_tf_loss(policy, model, dist_class, train_batch):
|
||||
"""The basic policy gradients loss."""
|
||||
logits, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(logits, model)
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
# Pass the training data through our model to get distribution parameters.
|
||||
dist_inputs, _ = model.from_batch(train_batch)
|
||||
|
||||
# Create an action distribution object.
|
||||
action_dist = dist_class(dist_inputs, model)
|
||||
|
||||
# Calculate the vanilla PG loss based on:
|
||||
# L = -E[ log(pi(a|s)) * A]
|
||||
return -tf.reduce_mean(
|
||||
action_dist.logp(train_batch[SampleBatch.ACTIONS]) * tf.cast(
|
||||
train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32))
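
To make the loss term concrete, here is an illustrative, NumPy-only sketch (not part of this commit) of L = -E[ log(pi(a|s)) * A ] on hand-picked numbers:

import numpy as np

log_probs = np.log(np.array([0.7, 0.2, 0.9]))  # log pi(a_t|s_t) per timestep
advantages = np.array([2.9701, 1.99, 1.0])     # e.g. discounted returns
loss = -np.mean(log_probs * advantages)        # scalar the optimizer minimizes
print(loss)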
|
||||
|
||||
|
||||
# Build a child class of `TFPolicy`, given the extra options:
|
||||
# - trajectory post-processing function (to calculate advantages)
|
||||
# - PG loss function
|
||||
PGTFPolicy = build_tf_policy(
|
||||
name="PGTFPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
|
||||
|
|
|
@ -1,31 +1,77 @@
|
|||
"""
|
||||
PyTorch policy class used for PG.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg_tf_policy import post_process_advantages
|
||||
from ray.rllib.agents.pg.utils import post_process_advantages
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.typing import TensorType
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
def pg_torch_loss(policy, model, dist_class, train_batch):
|
||||
"""The basic policy gradients loss."""
|
||||
logits, _ = model.from_batch(train_batch)
|
||||
action_dist = dist_class(logits, model)
|
||||
def pg_torch_loss(
|
||||
policy: Policy, model: ModelV2,
|
||||
dist_class: Type[TorchDistributionWrapper],
|
||||
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
|
||||
"""The basic policy gradients loss function.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy to calculate the loss for.
|
||||
model (ModelV2): The Model to calculate the loss for.
|
||||
dist_class (Type[TorchDistributionWrapper]): The action distribution class.
|
||||
train_batch (SampleBatch): The training data.
|
||||
|
||||
Returns:
|
||||
Union[TensorType, List[TensorType]]: A single loss tensor or a list
|
||||
of loss tensors.
|
||||
"""
|
||||
# Pass the training data through our model to get distribution parameters.
|
||||
dist_inputs, _ = model.from_batch(train_batch)
|
||||
|
||||
# Create an action distribution object.
|
||||
action_dist = dist_class(dist_inputs, model)
|
||||
|
||||
# Calculate the vanilla PG loss based on:
|
||||
# L = -E[ log(pi(a|s)) * A]
|
||||
log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
|
||||
# Save the error in the policy object.
|
||||
# policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot(
|
||||
# log_probs.reshape(-1)) / len(log_probs)
|
||||
|
||||
# Save the loss in the policy object for the stats_fn below.
|
||||
policy.pi_err = -torch.mean(
|
||||
log_probs * train_batch[Postprocessing.ADVANTAGES])
|
||||
|
||||
return policy.pi_err
|
||||
|
||||
|
||||
def pg_loss_stats(policy, train_batch):
|
||||
""" The error is recorded when computing the loss."""
|
||||
return {"policy_loss": policy.pi_err.item()}
|
||||
def pg_loss_stats(policy: Policy,
|
||||
train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
"""Returns the calculated loss in a stats dict.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object.
|
||||
train_batch (SampleBatch): The data used for training.
|
||||
|
||||
Returns:
|
||||
Dict[str, TensorType]: The stats dict.
|
||||
"""
|
||||
|
||||
return {
|
||||
# `pi_err` (the loss) is stored inside `pg_torch_loss()`.
|
||||
"policy_loss": policy.pi_err.item(),
|
||||
}
|
||||
|
||||
|
||||
# Build a child class of `TorchPolicy`, given the extra options:
|
||||
# - trajectory post-processing function (to calculate advantages)
|
||||
# - PG loss function
|
||||
PGTorchPolicy = build_torch_policy(
|
||||
name="PGTorchPolicy",
|
||||
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
|
||||
|
|
36
rllib/agents/pg/utils.py
Normal file
36
rllib/agents/pg/utils.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
||||
|
||||
def post_process_advantages(
|
||||
policy: Policy,
|
||||
sample_batch: SampleBatch,
|
||||
other_agent_batches: Optional[List[SampleBatch]] = None,
|
||||
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
|
||||
"""Adds the "advantages" column to `sample_batch`.
|
||||
|
||||
Args:
|
||||
policy (Policy): The Policy object to do post-processing for.
|
||||
sample_batch (SampleBatch): The actual sample batch to post-process.
|
||||
other_agent_batches (Optional[List[SampleBatch]]): Optional list of
|
||||
other agents' SampleBatch objects.
|
||||
episode (MultiAgentEpisode): The multi-agent episode object, from which
|
||||
`sample_batch` was generated.
|
||||
|
||||
Returns:
|
||||
SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field.
|
||||
"""
|
||||
|
||||
# Calculates advantage values based on the rewards in the sample batch.
|
||||
# The value of the last observation is assumed to be 0.0 (no value function
|
||||
# estimation at the end of the sampled chunk).
|
||||
return compute_advantages(
|
||||
rollout=sample_batch,
|
||||
last_r=0.0,
|
||||
gamma=policy.config["gamma"],
|
||||
use_gae=False,
|
||||
use_critic=False)
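
For intuition, the following standalone sketch (plain Python, not the RLlib helper itself) reproduces what the ADVANTAGES column contains in this setting: with use_gae=False and use_critic=False, each entry is the discounted return from that timestep, with the value of the last observation assumed to be 0.0:

def discounted_returns(rewards, gamma, last_r=0.0):
    """Pure-Python illustration of the advantage values computed above."""
    out = [0.0] * len(rewards)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        out[t] = running
    return out

print(discounted_returns([1.0, 1.0, 1.0], gamma=0.99))  # [2.9701, 1.99, 1.0]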
|
|
@ -1,62 +1,65 @@
|
|||
from ray.rllib.agents.impala.impala import validate_config
|
||||
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo import UpdateKL
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, _get_shared_metrics
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
|
||||
# Whether to use V-trace weighted advantages. If false, PPO GAE advantages
|
||||
# will be used instead.
|
||||
"vtrace": False,
|
||||
DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
|
||||
impala.DEFAULT_CONFIG, # See keys in impala.py, which are also supported.
|
||||
{
|
||||
# Whether to use V-trace weighted advantages. If false, PPO GAE
|
||||
# advantages will be used instead.
|
||||
"vtrace": False,
|
||||
|
||||
# == These two options only apply if vtrace: False ==
|
||||
# Should use a critic as a baseline (otherwise don't use value baseline;
|
||||
# required for using GAE).
|
||||
"use_critic": True,
|
||||
# If true, use the Generalized Advantage Estimator (GAE)
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
"use_gae": True,
|
||||
# GAE(lambda) parameter
|
||||
"lambda": 1.0,
|
||||
# == These two options only apply if vtrace: False ==
|
||||
# Should use a critic as a baseline (otherwise don't use value
|
||||
# baseline; required for using GAE).
|
||||
"use_critic": True,
|
||||
# If true, use the Generalized Advantage Estimator (GAE)
|
||||
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
|
||||
"use_gae": True,
|
||||
# GAE(lambda) parameter
|
||||
"lambda": 1.0,
|
||||
|
||||
# == PPO surrogate loss options ==
|
||||
"clip_param": 0.4,
|
||||
# == PPO surrogate loss options ==
|
||||
"clip_param": 0.4,
|
||||
|
||||
# == PPO KL Loss options ==
|
||||
"use_kl_loss": False,
|
||||
"kl_coeff": 1.0,
|
||||
"kl_target": 0.01,
|
||||
# == PPO KL Loss options ==
|
||||
"use_kl_loss": False,
|
||||
"kl_coeff": 1.0,
|
||||
"kl_target": 0.01,
|
||||
|
||||
# == IMPALA optimizer params (see documentation in impala.py) ==
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 500,
|
||||
"min_iter_time_s": 10,
|
||||
"num_workers": 2,
|
||||
"num_gpus": 0,
|
||||
"num_data_loader_buffers": 1,
|
||||
"minibatch_buffer_size": 1,
|
||||
"num_sgd_iter": 1,
|
||||
"replay_proportion": 0.0,
|
||||
"replay_buffer_num_slots": 100,
|
||||
"learner_queue_size": 16,
|
||||
"learner_queue_timeout": 300,
|
||||
"max_sample_requests_in_flight_per_worker": 2,
|
||||
"broadcast_interval": 1,
|
||||
"grad_clip": 40.0,
|
||||
"opt_type": "adam",
|
||||
"lr": 0.0005,
|
||||
"lr_schedule": None,
|
||||
"decay": 0.99,
|
||||
"momentum": 0.0,
|
||||
"epsilon": 0.1,
|
||||
"vf_loss_coeff": 0.5,
|
||||
"entropy_coeff": 0.01,
|
||||
"entropy_coeff_schedule": None,
|
||||
})
|
||||
# == IMPALA optimizer params (see documentation in impala.py) ==
|
||||
"rollout_fragment_length": 50,
|
||||
"train_batch_size": 500,
|
||||
"min_iter_time_s": 10,
|
||||
"num_workers": 2,
|
||||
"num_gpus": 0,
|
||||
"num_data_loader_buffers": 1,
|
||||
"minibatch_buffer_size": 1,
|
||||
"num_sgd_iter": 1,
|
||||
"replay_proportion": 0.0,
|
||||
"replay_buffer_num_slots": 100,
|
||||
"learner_queue_size": 16,
|
||||
"learner_queue_timeout": 300,
|
||||
"max_sample_requests_in_flight_per_worker": 2,
|
||||
"broadcast_interval": 1,
|
||||
"grad_clip": 40.0,
|
||||
"opt_type": "adam",
|
||||
"lr": 0.0005,
|
||||
"lr_schedule": None,
|
||||
"decay": 0.99,
|
||||
"momentum": 0.0,
|
||||
"epsilon": 0.1,
|
||||
"vf_loss_coeff": 0.5,
|
||||
"entropy_coeff": 0.01,
|
||||
"entropy_coeff_schedule": None,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ import time
|
|||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import ppo
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.execution.rollout_ops import ParallelRollouts
|
||||
from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
|
@ -32,31 +31,42 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
|
||||
# During the sampling phase, each rollout worker will collect a batch
|
||||
# `rollout_fragment_length * num_envs_per_worker` steps in size.
|
||||
"rollout_fragment_length": 100,
|
||||
# Vectorize the env (should enable by default since each worker has a GPU).
|
||||
"num_envs_per_worker": 5,
|
||||
# During the SGD phase, workers iterate over minibatches of this size.
|
||||
# The effective minibatch size will be `sgd_minibatch_size * num_workers`.
|
||||
"sgd_minibatch_size": 50,
|
||||
# Number of SGD epochs per optimization round.
|
||||
"num_sgd_iter": 10,
|
||||
# Download weights between each training step. This adds a bit of overhead
|
||||
# but allows the user to access the weights from the trainer.
|
||||
"keep_local_weights_in_sync": True,
|
||||
DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
|
||||
ppo.DEFAULT_CONFIG,
|
||||
{
|
||||
# During the sampling phase, each rollout worker will collect a batch
|
||||
# `rollout_fragment_length * num_envs_per_worker` steps in size.
|
||||
"rollout_fragment_length": 100,
|
||||
# Vectorize the env (should enable by default since each worker has
|
||||
# a GPU).
|
||||
"num_envs_per_worker": 5,
|
||||
# During the SGD phase, workers iterate over minibatches of this size.
|
||||
# The effective minibatch size will be:
|
||||
# `sgd_minibatch_size * num_workers`.
|
||||
"sgd_minibatch_size": 50,
|
||||
# Number of SGD epochs per optimization round.
|
||||
"num_sgd_iter": 10,
|
||||
# Download weights between each training step. This adds a bit of
|
||||
# overhead but allows the user to access the weights from the trainer.
|
||||
"keep_local_weights_in_sync": True,
|
||||
|
||||
# *** WARNING: configs below are DDPPO overrides over PPO; you
|
||||
# shouldn't need to adjust them. ***
|
||||
"framework": "torch", # DDPPO requires PyTorch distributed.
|
||||
"num_gpus": 0, # Learning is no longer done on the driver process, so
|
||||
# giving GPUs to the driver does not make sense!
|
||||
"num_gpus_per_worker": 1, # Each rollout worker gets a GPU.
|
||||
"truncate_episodes": True, # Require evenly sized batches. Otherwise,
|
||||
# collective allreduce could fail.
|
||||
"train_batch_size": -1, # This is auto set based on sample batch size.
|
||||
})
|
||||
# *** WARNING: configs below are DDPPO overrides over PPO; you
|
||||
# shouldn't need to adjust them. ***
|
||||
# DDPPO requires PyTorch distributed.
|
||||
"framework": "torch",
|
||||
# Learning is no longer done on the driver process, so
|
||||
# giving GPUs to the driver does not make sense!
|
||||
"num_gpus": 0,
|
||||
# Each rollout worker gets a GPU.
|
||||
"num_gpus_per_worker": 1,
|
||||
# Require evenly sized batches. Otherwise,
|
||||
# collective allreduce could fail.
|
||||
"truncate_episodes": True,
|
||||
# This is auto set based on sample batch size.
|
||||
"train_batch_size": -1,
|
||||
},
|
||||
_allow_unknown_configs=True,
|
||||
)
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
|
|
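Worked example of the batch-size arithmetic in the DDPPO defaults above (these lines just plug in the values shown; they are not additional config settings):

rollout_fragment_length = 100
num_envs_per_worker = 5
sgd_minibatch_size = 50
num_workers = 2

per_worker_sample_batch = rollout_fragment_length * num_envs_per_worker  # 500 steps
effective_sgd_minibatch = sgd_minibatch_size * num_workers               # 100 steps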
@ -7,7 +7,7 @@ import os
|
|||
import pickle
|
||||
import time
|
||||
import tempfile
|
||||
from typing import Callable, List, Dict, Union
|
||||
from typing import Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayError
|
||||
|
@ -390,19 +390,18 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
@DeveloperAPI
|
||||
def with_common_config(
|
||||
extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
|
||||
"""Returns the given config dict merged with common agent confs."""
|
||||
"""Returns the given config dict merged with common agent confs.
|
||||
|
||||
return with_base_config(COMMON_CONFIG, extra_config)
|
||||
Args:
|
||||
extra_config (PartialTrainerConfigDict): A user defined partial config
|
||||
which will get merged with COMMON_CONFIG and returned.
|
||||
|
||||
|
||||
def with_base_config(
|
||||
base_config: TrainerConfigDict,
|
||||
extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
|
||||
"""Returns the given config dict merged with a base agent conf."""
|
||||
|
||||
config = copy.deepcopy(base_config)
|
||||
config.update(extra_config)
|
||||
return config
|
||||
Returns:
|
||||
TrainerConfigDict: The merged config dict resulting from COMMON_CONFIG
|
||||
plus `extra_config`.
|
||||
"""
|
||||
return Trainer.merge_trainer_configs(
|
||||
COMMON_CONFIG, extra_config, _allow_unknown_configs=True)
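
A small usage sketch of the behavior documented above (the specific keys are only for illustration; keys already present in COMMON_CONFIG are overridden, and unknown keys pass through because `_allow_unknown_configs=True`):

from ray.rllib.agents.trainer import COMMON_CONFIG, with_common_config

cfg = with_common_config({"lr": 0.0004, "num_workers": 0})
assert cfg["lr"] == 0.0004                     # override applied
assert cfg["gamma"] == COMMON_CONFIG["gamma"]  # untouched keys inherited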
|
||||
|
||||
|
||||
@PublicAPI
|
||||
|
@ -664,7 +663,7 @@ class Trainer(Trainable):
|
|||
|
||||
self.evaluation_workers = self._make_workers(
|
||||
self.env_creator,
|
||||
self._policy,
|
||||
self._policy_class,
|
||||
merge_dicts(self.config, extra_config),
|
||||
num_workers=self.config["evaluation_num_workers"])
|
||||
self.evaluation_metrics = {}
|
||||
|
@ -691,7 +690,7 @@ class Trainer(Trainable):
|
|||
|
||||
@DeveloperAPI
|
||||
def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
|
||||
policy: type, config: TrainerConfigDict,
|
||||
policy_class: Type[Policy], config: TrainerConfigDict,
|
||||
num_workers: int) -> WorkerSet:
|
||||
"""Default factory method for a WorkerSet running under this Trainer.
|
||||
|
||||
|
@ -701,9 +700,9 @@ class Trainer(Trainable):
|
|||
Args:
|
||||
env_creator (callable): A function that returns an Env given an env
|
||||
config.
|
||||
policy (class): The Policy class to use for creating the policies
|
||||
of the workers.
|
||||
config (dict): The Trainer's config.
|
||||
policy (Type[Policy]): The Policy class to use for creating the
|
||||
policies of the workers.
|
||||
config (TrainerConfigDict): The Trainer's config.
|
||||
num_workers (int): Number of remote rollout workers to create.
|
||||
0 for local only.
|
||||
|
||||
|
@ -711,9 +710,9 @@ class Trainer(Trainable):
|
|||
WorkerSet: The created WorkerSet.
|
||||
"""
|
||||
return WorkerSet(
|
||||
env_creator,
|
||||
policy,
|
||||
config,
|
||||
env_creator=env_creator,
|
||||
policy_class=policy_class,
|
||||
trainer_config=config,
|
||||
num_workers=num_workers,
|
||||
logdir=self.logdir)
|
||||
|
||||
|
@ -1044,8 +1043,11 @@ class Trainer(Trainable):
|
|||
"The config of this agent is: {}".format(config))
|
||||
|
||||
@classmethod
|
||||
def merge_trainer_configs(cls, config1: TrainerConfigDict,
|
||||
config2: PartialTrainerConfigDict) -> dict:
|
||||
def merge_trainer_configs(cls,
|
||||
config1: TrainerConfigDict,
|
||||
config2: PartialTrainerConfigDict,
|
||||
_allow_unknown_configs: Optional[bool] = None
|
||||
) -> TrainerConfigDict:
|
||||
config1 = copy.deepcopy(config1)
|
||||
# Error if trainer default has deprecated value.
|
||||
if config1["sample_batch_size"] != DEPRECATED_VALUE:
|
||||
|
@ -1067,7 +1069,9 @@ class Trainer(Trainable):
|
|||
legacy_callbacks_dict=legacy_callbacks_dict)
|
||||
|
||||
config2["callbacks"] = make_callbacks
|
||||
return deep_update(config1, config2, cls._allow_unknown_configs,
|
||||
if _allow_unknown_configs is None:
|
||||
_allow_unknown_configs = cls._allow_unknown_configs
|
||||
return deep_update(config1, config2, _allow_unknown_configs,
|
||||
cls._allow_unknown_subkeys,
|
||||
cls._override_all_subkeys_if_type_changes)
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Callable, Optional, List, Iterable
|
||||
from typing import Callable, Iterable, List, Optional, Type
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
|
@ -9,7 +9,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
|
|||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils import add_mixins
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
|
||||
from ray.rllib.utils.typing import EnvConfigDict, EnvType, ResultDict, \
|
||||
TrainerConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -33,17 +34,19 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
|
|||
@DeveloperAPI
|
||||
def build_trainer(
|
||||
name: str,
|
||||
default_policy: Optional[Policy],
|
||||
*,
|
||||
default_config: TrainerConfigDict = None,
|
||||
validate_config: Callable[[TrainerConfigDict], None] = None,
|
||||
get_policy_class: Callable[[TrainerConfigDict], Policy] = None,
|
||||
before_init: Callable[[Trainer], None] = None,
|
||||
after_init: Callable[[Trainer], None] = None,
|
||||
before_evaluate_fn: Callable[[Trainer], None] = None,
|
||||
mixins: List[type] = None,
|
||||
execution_plan: Callable[[WorkerSet, TrainerConfigDict], Iterable[
|
||||
ResultDict]] = default_execution_plan):
|
||||
default_policy: Optional[Type[Policy]] = None,
|
||||
get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
|
||||
Policy]]]] = None,
|
||||
before_init: Optional[Callable[[Trainer], None]] = None,
|
||||
after_init: Optional[Callable[[Trainer], None]] = None,
|
||||
before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
|
||||
mixins: Optional[List[type]] = None,
|
||||
execution_plan: Optional[Callable[[
|
||||
WorkerSet, TrainerConfigDict
|
||||
], Iterable[ResultDict]]] = default_execution_plan):
|
||||
"""Helper function for defining a custom trainer.
|
||||
|
||||
Functions will be run in this order to initialize the trainer:
|
||||
|
@ -51,22 +54,30 @@ def build_trainer(
|
|||
2. Worker setup: before_init, execution_plan
|
||||
3. Post setup: after_init
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
name (str): name of the trainer (e.g., "PPO")
|
||||
default_policy (cls): the default Policy class to use
|
||||
default_config (dict): The default config dict of the algorithm,
|
||||
otherwise uses the Trainer default config.
|
||||
default_config (TrainerConfigDict): The default config dict
|
||||
of the algorithm, otherwise uses the Trainer default config.
|
||||
validate_config (Optional[callable]): Optional callable that takes the
|
||||
config to check for correctness. It may mutate the config as
|
||||
needed.
|
||||
get_policy_class (Optional[callable]): Optional callable that takes a
|
||||
config and returns the policy class to override the default with.
|
||||
before_init (Optional[callable]): Optional callable to run at the start
|
||||
of trainer init that takes the trainer instance as argument.
|
||||
after_init (Optional[callable]): Optional callable to run at the end of
|
||||
trainer init that takes the trainer instance as argument.
|
||||
before_evaluate_fn (Optional[callable]): callback to run before
|
||||
evaluation. This takes the trainer instance as argument.
|
||||
default_policy (Optional[Type[Policy]]): The default Policy class to
|
||||
use.
|
||||
get_policy_class (Optional[Callable[
|
||||
TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
|
||||
that takes a config and returns the policy class or None. If None
|
||||
is returned, will use `default_policy` (which must be provided
|
||||
then).
|
||||
before_init (Optional[Callable[[Trainer], None]]): Optional callable to
|
||||
run before anything is constructed inside Trainer (Workers with
|
||||
Policies, execution plan, etc..). Takes the Trainer instance as
|
||||
argument.
|
||||
after_init (Optional[Callable[[Trainer], None]]): Optional callable to
|
||||
run at the end of trainer init (after all Workers and the exec.
|
||||
plan have been constructed). Takes the Trainer instance as
|
||||
argument.
|
||||
before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
|
||||
run before evaluation. This takes the trainer instance as argument.
|
||||
mixins (list): list of any class mixins for the returned trainer class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the Trainer class.
|
||||
|
@ -82,26 +93,37 @@ def build_trainer(
|
|||
class trainer_cls(base):
|
||||
_name = name
|
||||
_default_config = default_config or COMMON_CONFIG
|
||||
_policy = default_policy
|
||||
_policy_class = default_policy
|
||||
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
Trainer.__init__(self, config, env, logger_creator)
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
def _init(self, config: TrainerConfigDict,
|
||||
env_creator: Callable[[EnvConfigDict], EnvType]):
|
||||
# Validate config via custom validation function.
|
||||
if validate_config:
|
||||
validate_config(config)
|
||||
|
||||
if get_policy_class is None:
|
||||
self._policy = default_policy
|
||||
if not config["multiagent"]["policies"]:
|
||||
assert default_policy is not None
|
||||
self._policy_class = default_policy
|
||||
else:
|
||||
self._policy = get_policy_class(config)
|
||||
self._policy_class = get_policy_class(config)
|
||||
if self._policy_class is None:
|
||||
assert default_policy is not None
|
||||
self._policy_class = default_policy
|
||||
|
||||
if before_init:
|
||||
before_init(self)
|
||||
|
||||
# Creating all workers (excluding evaluation workers).
|
||||
self.workers = self._make_workers(
|
||||
env_creator, self._policy, config, self.config["num_workers"])
|
||||
self.workers = self._make_workers(env_creator, self._policy_class,
|
||||
config,
|
||||
self.config["num_workers"])
|
||||
self.execution_plan = execution_plan
|
||||
self.train_exec_impl = execution_plan(self.workers, config)
|
||||
|
||||
if after_init:
|
||||
after_init(self)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch
|
|||
from ray.rllib.utils import force_list
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import summarize
|
||||
from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
|
||||
TensorType
|
||||
from ray.util.debug import log_once
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode
|
|||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, TensorType
|
||||
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
|
|
@ -23,16 +23,16 @@ def compute_advantages(rollout: SampleBatch,
|
|||
use_gae: bool = True,
|
||||
use_critic: bool = True):
|
||||
"""
|
||||
Given a rollout, compute its value targets and the advantage.
|
||||
Given a rollout, compute its value targets and the advantages.
|
||||
|
||||
Args:
|
||||
rollout (SampleBatch): SampleBatch of a single trajectory
|
||||
last_r (float): Value estimation for last observation
|
||||
rollout (SampleBatch): SampleBatch of a single trajectory.
|
||||
last_r (float): Value estimation for last observation.
|
||||
gamma (float): Discount factor.
|
||||
lambda_ (float): Parameter for GAE
|
||||
use_gae (bool): Using Generalized Advantage Estimation
|
||||
lambda_ (float): Parameter for GAE.
|
||||
use_gae (bool): Using Generalized Advantage Estimation.
|
||||
use_critic (bool): Whether to use critic (value estimates). Setting
|
||||
this to False will use 0 as baseline.
|
||||
this to False will use 0 as baseline.
|
||||
|
||||
Returns:
|
||||
SampleBatch (SampleBatch): Object with experience from rollout and
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
from types import FunctionType
|
||||
from typing import TypeVar, Callable, List, Union
|
||||
from typing import Callable, List, Optional, Type, TypeVar, Union
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
@ -30,21 +30,23 @@ class WorkerSet:
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_creator: Callable[[EnvContext], EnvType],
|
||||
policy: type,
|
||||
trainer_config: TrainerConfigDict = None,
|
||||
*,
|
||||
env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
|
||||
policy_class: Optional[Type[Policy]] = None,
|
||||
trainer_config: Optional[TrainerConfigDict] = None,
|
||||
num_workers: int = 0,
|
||||
logdir: str = None,
|
||||
logdir: Optional[str] = None,
|
||||
_setup: bool = True):
|
||||
"""Create a new WorkerSet and initialize its workers.
|
||||
|
||||
Arguments:
|
||||
env_creator (func): Function that returns env given env config.
|
||||
policy (cls): rllib.policy.Policy class.
|
||||
trainer_config (dict): Optional dict that extends the common
|
||||
config of the Trainer class.
|
||||
env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
|
||||
that returns env given env config.
|
||||
policy_class (Optional[Type[Policy]]): An rllib.policy.Policy class.
|
||||
trainer_config (Optional[TrainerConfigDict]): Optional dict that
|
||||
extends the common config of the Trainer class.
|
||||
num_workers (int): Number of remote rollout workers to create.
|
||||
logdir (str): Optional logging directory for workers.
|
||||
logdir (Optional[str]): Optional logging directory for workers.
|
||||
_setup (bool): Whether to setup workers. This is only for testing.
|
||||
"""
|
||||
|
||||
|
@ -53,7 +55,7 @@ class WorkerSet:
|
|||
trainer_config = COMMON_CONFIG
|
||||
|
||||
self._env_creator = env_creator
|
||||
self._policy = policy
|
||||
self._policy_class = policy_class
|
||||
self._remote_config = trainer_config
|
||||
self._logdir = logdir
|
||||
|
||||
|
@ -63,8 +65,9 @@ class WorkerSet:
|
|||
{"tf_session_args": trainer_config["local_tf_session_args"]})
|
||||
|
||||
# Always create a local worker
|
||||
self._local_worker = self._make_worker(
|
||||
RolloutWorker, env_creator, policy, 0, self._local_config)
|
||||
self._local_worker = self._make_worker(RolloutWorker, env_creator,
|
||||
self._policy_class, 0,
|
||||
self._local_config)
|
||||
|
||||
# Create a number of remote workers
|
||||
self._remote_workers = []
|
||||
|
@ -102,8 +105,9 @@ class WorkerSet:
|
|||
}
|
||||
cls = RolloutWorker.as_remote(**remote_args).remote
|
||||
self._remote_workers.extend([
|
||||
self._make_worker(cls, self._env_creator, self._policy, i + 1,
|
||||
self._remote_config) for i in range(num_workers)
|
||||
self._make_worker(cls, self._env_creator, self._policy_class,
|
||||
i + 1, self._remote_config)
|
||||
for i in range(num_workers)
|
||||
])
|
||||
|
||||
def reset(self, new_remote_workers: List["ActorHandle"]) -> None:
|
||||
|
@ -190,14 +194,18 @@ class WorkerSet:
|
|||
@staticmethod
|
||||
def _from_existing(local_worker: RolloutWorker,
|
||||
remote_workers: List["ActorHandle"] = None):
|
||||
workers = WorkerSet(None, None, {}, _setup=False)
|
||||
workers = WorkerSet(
|
||||
env_creator=None,
|
||||
policy_class=None,
|
||||
trainer_config={},
|
||||
_setup=False)
|
||||
workers._local_worker = local_worker
|
||||
workers._remote_workers = remote_workers or []
|
||||
return workers
|
||||
|
||||
def _make_worker(
|
||||
self, cls: Callable, env_creator: Callable[[EnvContext], EnvType],
|
||||
policy: Policy, worker_index: int,
|
||||
policy: Type[Policy], worker_index: int,
|
||||
config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]:
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import gym
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy import eager_tf_policy
|
||||
|
@ -17,12 +18,15 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \
|
|||
def build_tf_policy(
|
||||
name: str,
|
||||
*,
|
||||
loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
|
||||
loss_fn: Callable[[
|
||||
Policy, ModelV2, Type[TFActionDistribution], SampleBatch
|
||||
], Union[TensorType, List[TensorType]]],
|
||||
get_default_config: Optional[Callable[[None],
|
||||
TrainerConfigDict]] = None,
|
||||
postprocess_fn: Optional[Callable[[
|
||||
Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"
|
||||
], None]] = None,
|
||||
Policy, SampleBatch, Optional[List[SampleBatch]], Optional[
|
||||
"MultiAgentEpisode"]
|
||||
], SampleBatch]] = None,
|
||||
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
|
||||
str, TensorType]]] = None,
|
||||
optimizer_fn: Optional[Callable[[
|
||||
|
@ -81,8 +85,10 @@ def build_tf_policy(
|
|||
|
||||
Args:
|
||||
name (str): Name of the policy (e.g., "PPOTFPolicy").
|
||||
loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
|
||||
Callable for calculating a loss tensor.
|
||||
loss_fn (Callable[[
|
||||
Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
|
||||
Union[TensorType, List[TensorType]]]): Callable for calculating a
|
||||
loss tensor.
|
||||
get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
|
||||
Optional callable that returns the default config to merge with any
|
||||
overrides. If None, uses only(!) the user-provided
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import gym
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import ModelV2
|
||||
|
@ -22,7 +22,9 @@ torch, _ = try_import_torch()
|
|||
def build_torch_policy(
|
||||
name: str,
|
||||
*,
|
||||
loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
|
||||
loss_fn: Callable[[
|
||||
Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
|
||||
], Union[TensorType, List[TensorType]]],
|
||||
get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
|
||||
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
|
||||
str, TensorType]]] = None,
|
||||
|
@ -80,8 +82,9 @@ def build_torch_policy(
|
|||
super's `postprocess_trajectory` method).
|
||||
stats_fn (Optional[Callable[[Policy, SampleBatch],
|
||||
Dict[str, TensorType]]]): Optional callable that returns a dict of
|
||||
values given the policy and batch input tensors. If None,
|
||||
will use `TorchPolicy.extra_grad_info()` instead.
|
||||
values given the policy and training batch. If None,
|
||||
will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
|
||||
used for logging (e.g. in TensorBoard).
|
||||
extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType,
|
||||
List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
|
||||
TensorType]]]): Optional callable that returns a dict of extra
|
||||
|
|