[RLlib] First attempt at cleaning up algo code in RLlib: PG. (#10115)

Sven Mika 2020-08-20 17:05:57 +02:00 committed by GitHub
parent 538cb802d5
commit d14b501692
17 changed files with 396 additions and 196 deletions

View file

@ -35,7 +35,7 @@ Algorithm Frameworks Discrete Actions Continuous Actions Multi-
.. _`+LSTM auto-wrapping`: rllib-models.html#built-in-models
.. _`+parametric`: rllib-models.html#variable-length-parametric-action-spaces
.. _`+RNN`: rllib-models.html#recurrent-models
.. _`+Transformer`: rllib-models.html#attention-networks
.. _`+Transformer`: rllib-models.html#attention-networks-transformers
.. _`A2C, A3C`: rllib-algorithms.html#a3c
.. _`APEX-DQN`: rllib-algorithms.html#apex
.. _`APEX-DDPG`: rllib-algorithms.html#apex
@ -304,16 +304,22 @@ SpaceInvaders 650 1001 1025
Policy Gradients
----------------
|pytorch| |tensorflow|
`[paper] <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py>`__ We include a vanilla policy gradients implementation as an example algorithm.
|pytorch| |tensorflow| An `implementation <https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py>`__ of a vanilla policy gradient algorithm for TensorFlow and PyTorch.
**Papers**:
`[1] - Policy Gradient Methods for Reinforcement Learning with Function Approximation. <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`__
and
`[2] - Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning. <http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf>`__
.. figure:: a2c-arch.svg
Policy gradients architecture (same as A2C)
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pg/cartpole-pg.yaml>`__
**Tuned examples**: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pg/cartpole-pg.yaml>`__
**PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
**PG-specific configs**: The following updates will overwrite/be added to the
(base) Trainer config in `rllib/agents/trainer.py <rllib-training.html#common-parameters>`__ (*COMMON_CONFIG* dict):
.. literalinclude:: ../../rllib/agents/pg/pg.py
:language: python
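For a quick start, the algorithm can be trained through Ray Tune under its registered name "PG". A minimal sketch, assuming ray and gym are installed; the stopping criterion and config overrides are illustrative only:

import ray
from ray import tune

ray.init()

# Train vanilla policy gradients on CartPole-v0.
# "framework" switches between the TF and the PyTorch policy.
tune.run(
    "PG",
    stop={"episode_reward_mean": 150.0},
    config={
        "env": "CartPole-v0",
        "framework": "torch",  # or "tf"
        "num_workers": 0,
        "lr": 0.0004,
    },
)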

View file

@ -10,7 +10,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.tune.registry import register_trainable
@ -60,6 +60,7 @@ _register_all()
__all__ = [
"Policy",
"TFPolicy",
"TorchPolicy",
"RolloutWorker",
"SampleBatch",
"BaseEnv",

View file

@ -0,0 +1,8 @@
Policy Gradient (PG)
====================
An implementation of a vanilla policy gradient algorithm for TensorFlow and PyTorch.
**[Detailed Documentation](https://docs.ray.io/en/latest/rllib-algorithms.html#pg)**
**[Implementation](https://github.com/ray-project/ray/blob/master/rllib/agents/pg/pg.py)**

View file

@ -1,29 +1,57 @@
"""
Policy Gradient (PG)
====================
This file defines the distributed Trainer class for policy gradients.
See `pg_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation: https://docs.ray.io/en/latest/rllib-algorithms.html#pg
"""
from typing import Optional, Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.typing import TrainerConfigDict
# yapf: disable
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
"num_workers": 0,
# Learning rate.
"lr": 0.0004,
})
# __sphinx_doc_end__
# yapf: enable
def get_policy_class(config):
def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
"""Policy class picker function. Class is chosen based on DL-framework.
Args:
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
Optional[Type[Policy]]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
"""
if config["framework"] == "torch":
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
return PGTorchPolicy
else:
return PGTFPolicy
# Build a child class of `Trainer`, which uses the framework specific Policy
# determined in `get_policy_class()` above.
PGTrainer = build_trainer(
name="PG",
default_config=DEFAULT_CONFIG,
default_policy=PGTFPolicy,
get_policy_class=get_policy_class)
get_policy_class=get_policy_class,
)
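As an illustration of how `get_policy_class` comes into play, constructing the trainer directly with different `framework` settings selects the matching policy class. A minimal sketch, assuming `PGTrainer` and `DEFAULT_CONFIG` are re-exported from `ray.rllib.agents.pg`; the env and overrides are illustrative:

import ray
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG

ray.init()

config = DEFAULT_CONFIG.copy()
config["framework"] = "torch"  # "tf" would select PGTFPolicy instead.

# build_trainer() calls get_policy_class(config) internally, so this
# trainer runs with PGTorchPolicy.
trainer = PGTrainer(config=config, env="CartPole-v0")
for _ in range(3):
    print(trainer.train()["episode_reward_mean"])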

View file

@ -1,35 +1,54 @@
"""
TensorFlow policy class used for PG.
"""
from typing import List, Type, Union
import ray
from ray.rllib.evaluation.postprocessing import Postprocessing, \
compute_advantages
from ray.rllib.agents.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
tf1, tf, tfv = try_import_tf()
def post_process_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
"""This adds the "advantages" column to the sample train_batch."""
return compute_advantages(
sample_batch,
0.0,
policy.config["gamma"],
use_gae=False,
use_critic=False)
def pg_tf_loss(
policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""The basic policy gradients loss function.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[ActionDistribution]): The action distribution class.
train_batch (SampleBatch): The training data.
def pg_tf_loss(policy, model, dist_class, train_batch):
"""The basic policy gradients loss."""
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
# Pass the training data through our model to get distribution parameters.
dist_inputs, _ = model.from_batch(train_batch)
# Create an action distribution object.
action_dist = dist_class(dist_inputs, model)
# Calculate the vanilla PG loss based on:
# L = -E[ log(pi(a|s)) * A]
return -tf.reduce_mean(
action_dist.logp(train_batch[SampleBatch.ACTIONS]) * tf.cast(
train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32))
# Build a child class of `TFPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTFPolicy = build_tf_policy(
name="PGTFPolicy",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
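To make the formula concrete, here is a tiny NumPy sketch of L = -E[log(pi(a|s)) * A], independent of RLlib and with made-up numbers for a batch of three timesteps:

import numpy as np

log_probs = np.log(np.array([0.7, 0.2, 0.5]))  # log pi(a_t|s_t)
advantages = np.array([1.5, -0.5, 2.0])        # A_t (for PG: discounted returns)

# Vanilla policy gradient loss: negative mean of (log-prob * advantage).
loss = -np.mean(log_probs * advantages)
print(loss)  # ~0.372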

View file

@ -1,31 +1,77 @@
"""
PyTorch policy class used for PG.
"""
from typing import Dict, List, Type, Union
import ray
from ray.rllib.agents.pg.pg_tf_policy import post_process_advantages
from ray.rllib.agents.pg.utils import post_process_advantages
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType
torch, _ = try_import_torch()
def pg_torch_loss(policy, model, dist_class, train_batch):
"""The basic policy gradients loss."""
logits, _ = model.from_batch(train_batch)
action_dist = dist_class(logits, model)
def pg_torch_loss(
policy: Policy, model: ModelV2,
dist_class: Type[TorchDistributionWrapper],
train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
"""The basic policy gradients loss function.
Args:
policy (Policy): The Policy to calculate the loss for.
model (ModelV2): The Model to calculate the loss for.
dist_class (Type[TorchDistributionWrapper]): The action distribution class.
train_batch (SampleBatch): The training data.
Returns:
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
# Pass the training data through our model to get distribution parameters.
dist_inputs, _ = model.from_batch(train_batch)
# Create an action distribution object.
action_dist = dist_class(dist_inputs, model)
# Calculate the vanilla PG loss based on:
# L = -E[ log(pi(a|s)) * A]
log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
# Save the error in the policy object.
# policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot(
# log_probs.reshape(-1)) / len(log_probs)
# Save the loss in the policy object for the stats_fn below.
policy.pi_err = -torch.mean(
log_probs * train_batch[Postprocessing.ADVANTAGES])
return policy.pi_err
def pg_loss_stats(policy, train_batch):
""" The error is recorded when computing the loss."""
return {"policy_loss": policy.pi_err.item()}
def pg_loss_stats(policy: Policy,
train_batch: SampleBatch) -> Dict[str, TensorType]:
"""Returns the calculated loss in a stats dict.
Args:
policy (Policy): The Policy object.
train_batch (SampleBatch): The data used for training.
Returns:
Dict[str, TensorType]: The stats dict.
"""
return {
# `pi_err` (the loss) is stored inside `pg_torch_loss()`.
"policy_loss": policy.pi_err.item(),
}
# Build a child class of `TorchPolicy`, given the extra options:
# - trajectory post-processing function (to calculate advantages)
# - PG loss function
PGTorchPolicy = build_torch_policy(
name="PGTorchPolicy",
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
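The same computation with dummy PyTorch tensors, also showing the value that `pg_loss_stats` reports; this is a stand-alone sketch, not RLlib code:

import torch

log_probs = torch.log(torch.tensor([0.7, 0.2, 0.5]))  # log pi(a_t|s_t)
advantages = torch.tensor([1.5, -0.5, 2.0])           # A_t

# Same formula as pg_torch_loss(): L = -E[log pi(a|s) * A].
pi_err = -torch.mean(log_probs * advantages)

# What pg_loss_stats() would return for this loss.
print({"policy_loss": pi_err.item()})  # {'policy_loss': ~0.372}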

rllib/agents/pg/utils.py (new file, 36 lines)
View file

@ -0,0 +1,36 @@
from typing import List, Optional
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
def post_process_advantages(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[List[SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Adds the "advantages" column to `sample_batch`.
Args:
policy (Policy): The Policy object to do post-processing for.
sample_batch (SampleBatch): The actual sample batch to post-process.
other_agent_batches (Optional[List[SampleBatch]]): Optional list of
other agents' SampleBatch objects.
episode (MultiAgentEpisode): The multi-agent episode object, from which
`sample_batch` was generated.
Returns:
SampleBatch: The SampleBatch enhanced by the added ADVANTAGES field.
"""
# Calculates advantage values based on the rewards in the sample batch.
# The value of the last observation is assumed to be 0.0 (no value function
# estimation at the end of the sampled chunk).
return compute_advantages(
rollout=sample_batch,
last_r=0.0,
gamma=policy.config["gamma"],
use_gae=False,
use_critic=False)
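With `use_gae=False` and `use_critic=False`, the advantages reduce to plain discounted returns, A_t = sum_k gamma^k * r_(t+k). A small, RLlib-independent sketch of that computation with toy rewards and gamma=0.99:

gamma = 0.99
rewards = [1.0, 1.0, 1.0, 0.0]  # toy reward sequence

# Walk the trajectory backwards, accumulating the discounted return.
advantages = []
running_return = 0.0  # value of the last observation is assumed to be 0.0
for r in reversed(rewards):
    running_return = r + gamma * running_return
    advantages.insert(0, running_return)

print(advantages)  # [2.9701, 1.99, 1.0, 0.0]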

View file

@ -1,62 +1,65 @@
from ray.rllib.agents.impala.impala import validate_config
from ray.rllib.agents.ppo.appo_tf_policy import AsyncPPOTFPolicy
from ray.rllib.agents.ppo.ppo import UpdateKL
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES, _get_shared_metrics
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_base_config(impala.DEFAULT_CONFIG, {
# Whether to use V-trace weighted advantages. If false, PPO GAE advantages
# will be used instead.
"vtrace": False,
DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
impala.DEFAULT_CONFIG, # See keys in impala.py, which are also supported.
{
# Whether to use V-trace weighted advantages. If false, PPO GAE
# advantages will be used instead.
"vtrace": False,
# == These two options only apply if vtrace: False ==
# Should use a critic as a baseline (otherwise don't use value baseline;
# required for using GAE).
"use_critic": True,
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# GAE(lambda) parameter
"lambda": 1.0,
# == These two options only apply if vtrace: False ==
# Should use a critic as a baseline (otherwise don't use value
# baseline; required for using GAE).
"use_critic": True,
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# GAE(lambda) parameter
"lambda": 1.0,
# == PPO surrogate loss options ==
"clip_param": 0.4,
# == PPO surrogate loss options ==
"clip_param": 0.4,
# == PPO KL Loss options ==
"use_kl_loss": False,
"kl_coeff": 1.0,
"kl_target": 0.01,
# == PPO KL Loss options ==
"use_kl_loss": False,
"kl_coeff": 1.0,
"kl_target": 0.01,
# == IMPALA optimizer params (see documentation in impala.py) ==
"rollout_fragment_length": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"num_workers": 2,
"num_gpus": 0,
"num_data_loader_buffers": 1,
"minibatch_buffer_size": 1,
"num_sgd_iter": 1,
"replay_proportion": 0.0,
"replay_buffer_num_slots": 100,
"learner_queue_size": 16,
"learner_queue_timeout": 300,
"max_sample_requests_in_flight_per_worker": 2,
"broadcast_interval": 1,
"grad_clip": 40.0,
"opt_type": "adam",
"lr": 0.0005,
"lr_schedule": None,
"decay": 0.99,
"momentum": 0.0,
"epsilon": 0.1,
"vf_loss_coeff": 0.5,
"entropy_coeff": 0.01,
"entropy_coeff_schedule": None,
})
# == IMPALA optimizer params (see documentation in impala.py) ==
"rollout_fragment_length": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"num_workers": 2,
"num_gpus": 0,
"num_data_loader_buffers": 1,
"minibatch_buffer_size": 1,
"num_sgd_iter": 1,
"replay_proportion": 0.0,
"replay_buffer_num_slots": 100,
"learner_queue_size": 16,
"learner_queue_timeout": 300,
"max_sample_requests_in_flight_per_worker": 2,
"broadcast_interval": 1,
"grad_clip": 40.0,
"opt_type": "adam",
"lr": 0.0005,
"lr_schedule": None,
"decay": 0.99,
"momentum": 0.0,
"epsilon": 0.1,
"vf_loss_coeff": 0.5,
"entropy_coeff": 0.01,
"entropy_coeff_schedule": None,
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable
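The effect of `merge_trainer_configs` here is a recursive merge of the IMPALA defaults with the APPO-specific overrides; `_allow_unknown_configs=True` lets keys that do not exist in the base config (e.g. `clip_param`) pass through. A simplified stand-alone illustration of the merge semantics, not the actual RLlib implementation:

def deep_merge(base, overrides):
    # Toy deep merge: override/extend `base` with `overrides`, recursing into dicts.
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

impala_defaults = {"vtrace": True, "lr": 0.0005, "num_workers": 2}
appo_overrides = {"vtrace": False, "clip_param": 0.4, "use_kl_loss": False}
print(deep_merge(impala_defaults, appo_overrides))
# {'vtrace': False, 'lr': 0.0005, 'num_workers': 2, 'clip_param': 0.4, 'use_kl_loss': False}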

View file

@ -19,7 +19,6 @@ import time
import ray
from ray.rllib.agents.ppo import ppo
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
@ -32,31 +31,42 @@ logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
# During the sampling phase, each rollout worker will collect a batch
# `rollout_fragment_length * num_envs_per_worker` steps in size.
"rollout_fragment_length": 100,
# Vectorize the env (should enable by default since each worker has a GPU).
"num_envs_per_worker": 5,
# During the SGD phase, workers iterate over minibatches of this size.
# The effective minibatch size will be `sgd_minibatch_size * num_workers`.
"sgd_minibatch_size": 50,
# Number of SGD epochs per optimization round.
"num_sgd_iter": 10,
# Download weights between each training step. This adds a bit of overhead
# but allows the user to access the weights from the trainer.
"keep_local_weights_in_sync": True,
DEFAULT_CONFIG = ppo.PPOTrainer.merge_trainer_configs(
ppo.DEFAULT_CONFIG,
{
# During the sampling phase, each rollout worker will collect a batch
# `rollout_fragment_length * num_envs_per_worker` steps in size.
"rollout_fragment_length": 100,
# Vectorize the env (should enable by default since each worker has
# a GPU).
"num_envs_per_worker": 5,
# During the SGD phase, workers iterate over minibatches of this size.
# The effective minibatch size will be:
# `sgd_minibatch_size * num_workers`.
"sgd_minibatch_size": 50,
# Number of SGD epochs per optimization round.
"num_sgd_iter": 10,
# Download weights between each training step. This adds a bit of
# overhead but allows the user to access the weights from the trainer.
"keep_local_weights_in_sync": True,
# *** WARNING: configs below are DDPPO overrides over PPO; you
# shouldn't need to adjust them. ***
"framework": "torch", # DDPPO requires PyTorch distributed.
"num_gpus": 0, # Learning is no longer done on the driver process, so
# giving GPUs to the driver does not make sense!
"num_gpus_per_worker": 1, # Each rollout worker gets a GPU.
"truncate_episodes": True, # Require evenly sized batches. Otherwise,
# collective allreduce could fail.
"train_batch_size": -1, # This is auto set based on sample batch size.
})
# *** WARNING: configs below are DDPPO overrides over PPO; you
# shouldn't need to adjust them. ***
# DDPPO requires PyTorch distributed.
"framework": "torch",
# Learning is no longer done on the driver process, so
# giving GPUs to the driver does not make sense!
"num_gpus": 0,
# Each rollout worker gets a GPU.
"num_gpus_per_worker": 1,
# Require evenly sized batches. Otherwise,
# collective allreduce could fail.
"truncate_episodes": True,
# This is auto set based on sample batch size.
"train_batch_size": -1,
},
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable
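A quick worked example of the batch sizing implied by these defaults: each rollout worker collects `rollout_fragment_length * num_envs_per_worker` env steps per sampling round, and the effective SGD minibatch is `sgd_minibatch_size * num_workers` (the worker count below is illustrative; in DDPPO it usually equals the number of GPUs):

rollout_fragment_length = 100
num_envs_per_worker = 5
sgd_minibatch_size = 50
num_workers = 2  # illustrative only

steps_per_worker = rollout_fragment_length * num_envs_per_worker  # 500
steps_per_round = steps_per_worker * num_workers                  # 1000
effective_minibatch = sgd_minibatch_size * num_workers            # 100
print(steps_per_worker, steps_per_round, effective_minibatch)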

View file

@ -7,7 +7,7 @@ import os
import pickle
import time
import tempfile
from typing import Callable, List, Dict, Union
from typing import Callable, Dict, List, Optional, Type, Union
import ray
from ray.exceptions import RayError
@ -390,19 +390,18 @@ COMMON_CONFIG: TrainerConfigDict = {
@DeveloperAPI
def with_common_config(
extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
"""Returns the given config dict merged with common agent confs."""
"""Returns the given config dict merged with common agent confs.
return with_base_config(COMMON_CONFIG, extra_config)
Args:
extra_config (PartialTrainerConfigDict): A user defined partial config
which will get merged with COMMON_CONFIG and returned.
def with_base_config(
base_config: TrainerConfigDict,
extra_config: PartialTrainerConfigDict) -> TrainerConfigDict:
"""Returns the given config dict merged with a base agent conf."""
config = copy.deepcopy(base_config)
config.update(extra_config)
return config
Returns:
TrainerConfigDict: The merged config dict resulting from COMMON_CONFIG
plus `extra_config`.
"""
return Trainer.merge_trainer_configs(
COMMON_CONFIG, extra_config, _allow_unknown_configs=True)
@PublicAPI
@ -664,7 +663,7 @@ class Trainer(Trainable):
self.evaluation_workers = self._make_workers(
self.env_creator,
self._policy,
self._policy_class,
merge_dicts(self.config, extra_config),
num_workers=self.config["evaluation_num_workers"])
self.evaluation_metrics = {}
@ -691,7 +690,7 @@ class Trainer(Trainable):
@DeveloperAPI
def _make_workers(self, env_creator: Callable[[EnvContext], EnvType],
policy: type, config: TrainerConfigDict,
policy_class: Type[Policy], config: TrainerConfigDict,
num_workers: int) -> WorkerSet:
"""Default factory method for a WorkerSet running under this Trainer.
@ -701,9 +700,9 @@ class Trainer(Trainable):
Args:
env_creator (callable): A function that returns an Env given an env
config.
policy (class): The Policy class to use for creating the policies
of the workers.
config (dict): The Trainer's config.
policy (Type[Policy]): The Policy class to use for creating the
policies of the workers.
config (TrainerConfigDict): The Trainer's config.
num_workers (int): Number of remote rollout workers to create.
0 for local only.
@ -711,9 +710,9 @@ class Trainer(Trainable):
WorkerSet: The created WorkerSet.
"""
return WorkerSet(
env_creator,
policy,
config,
env_creator=env_creator,
policy_class=policy_class,
trainer_config=config,
num_workers=num_workers,
logdir=self.logdir)
@ -1044,8 +1043,11 @@ class Trainer(Trainable):
"The config of this agent is: {}".format(config))
@classmethod
def merge_trainer_configs(cls, config1: TrainerConfigDict,
config2: PartialTrainerConfigDict) -> dict:
def merge_trainer_configs(cls,
config1: TrainerConfigDict,
config2: PartialTrainerConfigDict,
_allow_unknown_configs: Optional[bool] = None
) -> TrainerConfigDict:
config1 = copy.deepcopy(config1)
# Error if trainer default has deprecated value.
if config1["sample_batch_size"] != DEPRECATED_VALUE:
@ -1067,7 +1069,9 @@ class Trainer(Trainable):
legacy_callbacks_dict=legacy_callbacks_dict)
config2["callbacks"] = make_callbacks
return deep_update(config1, config2, cls._allow_unknown_configs,
if _allow_unknown_configs is None:
_allow_unknown_configs = cls._allow_unknown_configs
return deep_update(config1, config2, _allow_unknown_configs,
cls._allow_unknown_subkeys,
cls._override_all_subkeys_if_type_changes)
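A brief usage sketch of the updated helpers: `with_common_config()` now delegates to `Trainer.merge_trainer_configs()` with `_allow_unknown_configs=True`, so algorithm-specific keys that are not part of `COMMON_CONFIG` are kept. The key `my_special_coeff` below is hypothetical:

from ray.rllib.agents.trainer import with_common_config

MY_DEFAULT_CONFIG = with_common_config({
    "num_workers": 0,
    "lr": 0.0004,
    "my_special_coeff": 0.5,  # hypothetical key, not in COMMON_CONFIG
})

print(MY_DEFAULT_CONFIG["lr"])                # 0.0004 (override)
print(MY_DEFAULT_CONFIG["train_batch_size"])  # inherited from COMMON_CONFIG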

View file

@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, List, Iterable
from typing import Callable, Iterable, List, Optional, Type
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.evaluation.worker_set import WorkerSet
@ -9,7 +9,8 @@ from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
from ray.rllib.utils.typing import EnvConfigDict, EnvType, ResultDict, \
TrainerConfigDict
logger = logging.getLogger(__name__)
@ -33,17 +34,19 @@ def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
@DeveloperAPI
def build_trainer(
name: str,
default_policy: Optional[Policy],
*,
default_config: TrainerConfigDict = None,
validate_config: Callable[[TrainerConfigDict], None] = None,
get_policy_class: Callable[[TrainerConfigDict], Policy] = None,
before_init: Callable[[Trainer], None] = None,
after_init: Callable[[Trainer], None] = None,
before_evaluate_fn: Callable[[Trainer], None] = None,
mixins: List[type] = None,
execution_plan: Callable[[WorkerSet, TrainerConfigDict], Iterable[
ResultDict]] = default_execution_plan):
default_policy: Optional[Type[Policy]] = None,
get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
Policy]]]] = None,
before_init: Optional[Callable[[Trainer], None]] = None,
after_init: Optional[Callable[[Trainer], None]] = None,
before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
mixins: Optional[List[type]] = None,
execution_plan: Optional[Callable[[
WorkerSet, TrainerConfigDict
], Iterable[ResultDict]]] = default_execution_plan):
"""Helper function for defining a custom trainer.
Functions will be run in this order to initialize the trainer:
@ -51,22 +54,30 @@ def build_trainer(
2. Worker setup: before_init, execution_plan
3. Post setup: after_init
Arguments:
Args:
name (str): name of the trainer (e.g., "PPO")
default_policy (cls): the default Policy class to use
default_config (dict): The default config dict of the algorithm,
otherwise uses the Trainer default config.
default_config (TrainerConfigDict): The default config dict
of the algorithm, otherwise uses the Trainer default config.
validate_config (Optional[callable]): Optional callable that takes the
config to check for correctness. It may mutate the config as
needed.
get_policy_class (Optional[callable]): Optional callable that takes a
config and returns the policy class to override the default with.
before_init (Optional[callable]): Optional callable to run at the start
of trainer init that takes the trainer instance as argument.
after_init (Optional[callable]): Optional callable to run at the end of
trainer init that takes the trainer instance as argument.
before_evaluate_fn (Optional[callable]): callback to run before
evaluation. This takes the trainer instance as argument.
default_policy (Optional[Type[Policy]]): The default Policy class to
use.
get_policy_class (Optional[Callable[
TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
that takes a config and returns the policy class or None. If None
is returned, will use `default_policy` (which must be provided
then).
before_init (Optional[Callable[[Trainer], None]]): Optional callable to
run before anything is constructed inside Trainer (Workers with
Policies, execution plan, etc.). Takes the Trainer instance as
argument.
after_init (Optional[Callable[[Trainer], None]]): Optional callable to
run at the end of trainer init (after all Workers and the exec.
plan have been constructed). Takes the Trainer instance as
argument.
before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
run before evaluation. This takes the trainer instance as argument.
mixins (list): list of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class.
@ -82,26 +93,37 @@ def build_trainer(
class trainer_cls(base):
_name = name
_default_config = default_config or COMMON_CONFIG
_policy = default_policy
_policy_class = default_policy
def __init__(self, config=None, env=None, logger_creator=None):
Trainer.__init__(self, config, env, logger_creator)
def _init(self, config, env_creator):
def _init(self, config: TrainerConfigDict,
env_creator: Callable[[EnvConfigDict], EnvType]):
# Validate config via custom validation function.
if validate_config:
validate_config(config)
if get_policy_class is None:
self._policy = default_policy
if not config["multiagent"]["policies"]:
assert default_policy is not None
self._policy_class = default_policy
else:
self._policy = get_policy_class(config)
self._policy_class = get_policy_class(config)
if self._policy_class is None:
assert default_policy is not None
self._policy_class = default_policy
if before_init:
before_init(self)
# Creating all workers (excluding evaluation workers).
self.workers = self._make_workers(
env_creator, self._policy, config, self.config["num_workers"])
self.workers = self._make_workers(env_creator, self._policy_class,
config,
self.config["num_workers"])
self.execution_plan = execution_plan
self.train_exec_impl = execution_plan(self.workers, config)
if after_init:
after_init(self)
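To see the template in action, here is a hedged sketch of assembling a custom trainer with `build_trainer()` after this change; the policy class is resolved into `self._policy_class`, either via `get_policy_class()` or, if that returns None, via `default_policy`. The names `MyPG` and `my_get_policy_class` are hypothetical:

from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy

MY_CONFIG = with_common_config({"lr": 0.001})

def my_get_policy_class(config):
    if config["framework"] == "torch":
        from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
        return PGTorchPolicy
    # Returning None falls back to `default_policy` (PGTFPolicy here).
    return None

MyTrainer = build_trainer(
    name="MyPG",
    default_config=MY_CONFIG,
    default_policy=PGTFPolicy,
    get_policy_class=my_get_policy_class,
)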

View file

@ -12,7 +12,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, PolicyID, \
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
TensorType
from ray.util.debug import log_once

View file

@ -6,7 +6,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.types import AgentID, EnvID, EpisodeID, TensorType
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, TensorType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

View file

@ -23,16 +23,16 @@ def compute_advantages(rollout: SampleBatch,
use_gae: bool = True,
use_critic: bool = True):
"""
Given a rollout, compute its value targets and the advantage.
Given a rollout, compute its value targets and the advantages.
Args:
rollout (SampleBatch): SampleBatch of a single trajectory
last_r (float): Value estimation for last observation
rollout (SampleBatch): SampleBatch of a single trajectory.
last_r (float): Value estimation for last observation.
gamma (float): Discount factor.
lambda_ (float): Parameter for GAE
use_gae (bool): Using Generalized Advantage Estimation
lambda_ (float): Parameter for GAE.
use_gae (bool): Using Generalized Advantage Estimation.
use_critic (bool): Whether to use critic (value estimates). Setting
this to False will use 0 as baseline.
this to False will use 0 as baseline.
Returns:
SampleBatch (SampleBatch): Object with experience from rollout and
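For reference, the GAE path combines per-step TD errors delta_t = r_t + gamma * V(s_(t+1)) - V(s_t) into advantages A_t = sum_k (gamma * lambda)^k * delta_(t+k). A small stand-alone sketch with toy numbers, not RLlib code:

gamma, lambda_ = 0.99, 0.95
rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.6, 0.7]  # V(s_t) per step
last_r = 0.0              # value estimate for the final observation

advantages = []
gae = 0.0
next_value = last_r
for r, v in zip(reversed(rewards), reversed(values)):
    delta = r + gamma * next_value - v
    gae = delta + gamma * lambda_ * gae
    advantages.insert(0, gae)
    next_value = v
print(advantages)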

View file

@ -1,6 +1,6 @@
import logging
from types import FunctionType
from typing import TypeVar, Callable, List, Union
from typing import Callable, List, Optional, Type, TypeVar, Union
import ray
from ray.rllib.utils.annotations import DeveloperAPI
@ -30,21 +30,23 @@ class WorkerSet:
"""
def __init__(self,
env_creator: Callable[[EnvContext], EnvType],
policy: type,
trainer_config: TrainerConfigDict = None,
*,
env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
policy_class: Optional[Type[Policy]] = None,
trainer_config: Optional[TrainerConfigDict] = None,
num_workers: int = 0,
logdir: str = None,
logdir: Optional[str] = None,
_setup: bool = True):
"""Create a new WorkerSet and initialize its workers.
Arguments:
env_creator (func): Function that returns env given env config.
policy (cls): rllib.policy.Policy class.
trainer_config (dict): Optional dict that extends the common
config of the Trainer class.
env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
that returns env given env config.
policy_class (Optional[Type[Policy]]): The rllib.policy.Policy class to use.
trainer_config (Optional[TrainerConfigDict]): Optional dict that
extends the common config of the Trainer class.
num_workers (int): Number of remote rollout workers to create.
logdir (str): Optional logging directory for workers.
logdir (Optional[str]): Optional logging directory for workers.
_setup (bool): Whether to setup workers. This is only for testing.
"""
@ -53,7 +55,7 @@ class WorkerSet:
trainer_config = COMMON_CONFIG
self._env_creator = env_creator
self._policy = policy
self._policy_class = policy_class
self._remote_config = trainer_config
self._logdir = logdir
@ -63,8 +65,9 @@ class WorkerSet:
{"tf_session_args": trainer_config["local_tf_session_args"]})
# Always create a local worker
self._local_worker = self._make_worker(
RolloutWorker, env_creator, policy, 0, self._local_config)
self._local_worker = self._make_worker(RolloutWorker, env_creator,
self._policy_class, 0,
self._local_config)
# Create a number of remote workers
self._remote_workers = []
@ -102,8 +105,9 @@ class WorkerSet:
}
cls = RolloutWorker.as_remote(**remote_args).remote
self._remote_workers.extend([
self._make_worker(cls, self._env_creator, self._policy, i + 1,
self._remote_config) for i in range(num_workers)
self._make_worker(cls, self._env_creator, self._policy_class,
i + 1, self._remote_config)
for i in range(num_workers)
])
def reset(self, new_remote_workers: List["ActorHandle"]) -> None:
@ -190,14 +194,18 @@ class WorkerSet:
@staticmethod
def _from_existing(local_worker: RolloutWorker,
remote_workers: List["ActorHandle"] = None):
workers = WorkerSet(None, None, {}, _setup=False)
workers = WorkerSet(
env_creator=None,
policy_class=None,
trainer_config={},
_setup=False)
workers._local_worker = local_worker
workers._remote_workers = remote_workers or []
return workers
def _make_worker(
self, cls: Callable, env_creator: Callable[[EnvContext], EnvType],
policy: Policy, worker_index: int,
policy: Type[Policy], worker_index: int,
config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]:
def session_creator():
logger.debug("Creating TF session {}".format(
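Since the constructor arguments are now keyword-only and `policy` was renamed to `policy_class`, a WorkerSet would be built roughly as follows; this is a sketch (env creator, config, and worker count are illustrative), assuming `ray.init()` has already been called:

import gym
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.agents.trainer import COMMON_CONFIG
from ray.rllib.evaluation.worker_set import WorkerSet

workers = WorkerSet(
    env_creator=lambda env_ctx: gym.make("CartPole-v0"),
    policy_class=PGTFPolicy,
    trainer_config=COMMON_CONFIG,
    num_workers=2,
)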

View file

@ -1,6 +1,7 @@
import gym
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy import eager_tf_policy
@ -17,12 +18,15 @@ from ray.rllib.utils.typing import ModelGradients, TensorType, \
def build_tf_policy(
name: str,
*,
loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
loss_fn: Callable[[
Policy, ModelV2, Type[TFActionDistribution], SampleBatch
], Union[TensorType, List[TensorType]]],
get_default_config: Optional[Callable[[None],
TrainerConfigDict]] = None,
postprocess_fn: Optional[Callable[[
Policy, SampleBatch, List[SampleBatch], "MultiAgentEpisode"
], None]] = None,
Policy, SampleBatch, Optional[List[SampleBatch]], Optional[
"MultiAgentEpisode"]
], SampleBatch]] = None,
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
str, TensorType]]] = None,
optimizer_fn: Optional[Callable[[
@ -81,8 +85,10 @@ def build_tf_policy(
Args:
name (str): Name of the policy (e.g., "PPOTFPolicy").
loss_fn (Callable[[Policy, ModelV2, type, SampleBatch], TensorType]):
Callable for calculating a loss tensor.
loss_fn (Callable[[
Policy, ModelV2, Type[TFActionDistribution], SampleBatch],
Union[TensorType, List[TensorType]]]): Callable for calculating a
loss tensor.
get_default_config (Optional[Callable[[None], TrainerConfigDict]]):
Optional callable that returns the default config to merge with any
overrides. If None, uses only(!) the user-provided

View file

@ -1,5 +1,5 @@
import gym
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
@ -22,7 +22,9 @@ torch, _ = try_import_torch()
def build_torch_policy(
name: str,
*,
loss_fn: Callable[[Policy, ModelV2, type, SampleBatch], TensorType],
loss_fn: Callable[[
Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch
], Union[TensorType, List[TensorType]]],
get_default_config: Optional[Callable[[], TrainerConfigDict]] = None,
stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[
str, TensorType]]] = None,
@ -80,8 +82,9 @@ def build_torch_policy(
super's `postprocess_trajectory` method).
stats_fn (Optional[Callable[[Policy, SampleBatch],
Dict[str, TensorType]]]): Optional callable that returns a dict of
values given the policy and batch input tensors. If None,
will use `TorchPolicy.extra_grad_info()` instead.
values given the policy and training batch. If None,
will use `TorchPolicy.extra_grad_info()` instead. The stats dict is
used for logging (e.g. in TensorBoard).
extra_action_out_fn (Optional[Callable[[Policy, Dict[str, TensorType,
List[TensorType], ModelV2, TorchDistributionWrapper]], Dict[str,
TensorType]]]): Optional callable that returns a dict of extra
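Finally, a hedged sketch of putting `build_torch_policy()` to use with a custom loss_fn and stats_fn, mirroring the PG policy above; the names `MyTorchPolicy`, `my_loss`, and `my_stats` are illustrative:

import torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy

def my_loss(policy, model, dist_class, train_batch):
    # Maximize the log-likelihood of the actions in the batch (a toy loss).
    dist_inputs, _ = model.from_batch(train_batch)
    action_dist = dist_class(dist_inputs, model)
    policy.my_loss_val = -torch.mean(
        action_dist.logp(train_batch[SampleBatch.ACTIONS]))
    return policy.my_loss_val

def my_stats(policy, train_batch):
    # Values returned here show up in the learner stats (e.g. TensorBoard).
    return {"my_loss": policy.my_loss_val.item()}

MyTorchPolicy = build_torch_policy(
    name="MyTorchPolicy",
    loss_fn=my_loss,
    stats_fn=my_stats,
)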