[RLlib] trainer_template.py: hard deprecation (error when used). (#23488)

Sven Mika 2022-03-25 18:25:51 +01:00 committed by GitHub
parent f78404da4a
commit 7cb86acce2
11 changed files with 155 additions and 357 deletions

View file

@ -180,11 +180,12 @@ of a sequence of repeating steps, or *dataflow*, of:
2. ``ConcatBatches``: The experiences are concatenated into one batch for training.
3. ``TrainOneStep``: Take a gradient step with respect to the policy loss, and update the worker weights.
In code, this dataflow can be expressed as the following execution plan, which is a simple function that can be passed to ``build_trainer`` to define a new algorithm.
In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Trainer sub-classes to define new algorithms.
It takes in a ``WorkerSet`` and config, and returns an iterator over training results:
.. code-block:: python
@staticmethod
def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
# type: LocalIterator[SampleBatchType]
rollouts = ParallelRollouts(workers, mode="bulk_sync")
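For reference, a minimal self-contained sketch of a complete plan of this form, assuming the standard operators from ``ray.rllib.execution`` (``MyAlgo`` is a placeholder class name and keyword arguments such as ``min_batch_size`` may differ slightly between RLlib versions):

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep

    class MyAlgo(Trainer):
        @staticmethod
        def execution_plan(workers, config, **kwargs):
            # 1. Collect experiences from all rollout workers in parallel.
            rollouts = ParallelRollouts(workers, mode="bulk_sync")
            # 2. Concatenate the collected batches into one training batch.
            train_op = rollouts.combine(
                ConcatBatches(min_batch_size=config["train_batch_size"])
            )
            # 3. Take a gradient step on the local worker and sync weights.
            train_op = train_op.for_each(TrainOneStep(workers))
            # Emit training results in the standard RLlib result format.
            return StandardMetricsReporting(train_op, workers, config)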

View file

@ -152,12 +152,11 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
name="MyCustomTrainer",
default_policy=MyTFPolicy)
class MyTrainer(Trainer):
def get_default_policy_class(self, config):
return MyTFPolicy
ray.init()
tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})
@ -209,20 +208,36 @@ You might be wondering how RLlib makes the advantages placeholder automatically
**Example 1: Proximal Policy Optimization**
In the above section you saw how to compose a simple policy gradient algorithm with RLlib. In this example, we'll dive into how PPO was built with RLlib and how you can modify it. First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__:
In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__:
.. code-block:: python
PPOTrainer = build_trainer(
name="PPOTrainer",
default_config=DEFAULT_CONFIG,
default_policy=PPOTFPolicy,
validate_config=validate_config,
execution_plan=execution_plan)
class PPOTrainer(Trainer):
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
Besides some boilerplate for defining the PPO configuration and some warnings, the most important argument to take note of is the ``execution_plan``.
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
...
The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow. Depending on the ``simple_optimizer`` trainer config, PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):
@override(Trainer)
def get_default_policy_class(self, config):
return PPOTFPolicy
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
...
Besides some boilerplate for defining the PPO configuration and some warnings, the most important method to take note of is the ``execution_plan``.
The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow.
Depending on the ``simple_optimizer`` trainer config,
PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):
.. code-block:: python
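    # Illustrative sketch only (not the verbatim PPO code from this commit):
    # a branching plan of this kind, where `MultiGPUTrainOneStep` and the
    # exact keyword arguments are assumptions to verify against your RLlib
    # version.
    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
        # Collect one train batch worth of experiences per iteration.
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        rollouts = rollouts.combine(
            ConcatBatches(min_batch_size=config["train_batch_size"])
        )
        if config["simple_optimizer"]:
            # Simple synchronous plan: plain minibatch SGD on the local worker.
            train_op = rollouts.for_each(
                TrainOneStep(
                    workers,
                    num_sgd_iter=config["num_sgd_iter"],
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                )
            )
        else:
            # Default plan: minibatch SGD from pre-loaded (multi-)GPU buffers.
            train_op = rollouts.for_each(
                MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                    num_sgd_iter=config["num_sgd_iter"],
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )
            )
        return StandardMetricsReporting(train_op, workers, config)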

View file

@ -204,7 +204,7 @@ class DDPPOTrainer(PPOTrainer):
Returns:
LocalIterator[dict]: The Policy class to use with PGTrainer.
If None, use `default_policy` provided in build_trainer().
If None, use `get_default_policy_class()` provided by Trainer.
"""
assert (
len(kwargs) == 0

View file

@ -831,7 +831,6 @@ class Trainer(Trainable):
config, logger_creator, remote_checkpoint_dir, sync_function_tpl
)
@ExperimentalAPI
@classmethod
def get_default_config(cls) -> TrainerConfigDict:
return cls._default_config or COMMON_CONFIG
@ -907,7 +906,7 @@ class Trainer(Trainable):
# - Run the execution plan to create the local iterator to `next()`
# in each training iteration.
# This matches the behavior of using `build_trainer()`, which
# should no longer be used.
# has been deprecated.
self.workers = WorkerSet(
env_creator=self.env_creator,
validate_env=self.validate_env,
@ -1034,7 +1033,6 @@ class Trainer(Trainable):
def _init(self, config: TrainerConfigDict, env_creator: EnvCreator) -> None:
raise NotImplementedError
@ExperimentalAPI
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
"""Returns a default Policy class to use, given a config.
@ -1107,7 +1105,6 @@ class Trainer(Trainable):
return result
@ExperimentalAPI
def step_attempt(self) -> ResultDict:
"""Attempts a single training step, including evaluation, if required.
@ -1389,7 +1386,7 @@ class Trainer(Trainable):
# Also return the results here for convenience.
return self.evaluation_metrics
@ExperimentalAPI
@DeveloperAPI
def training_iteration(self) -> ResultDict:
"""Default single iteration logic of an algorithm.
@ -2308,7 +2305,7 @@ class Trainer(Trainable):
check_if_correct_nn_framework_installed()
resolve_tf_settings()
@ExperimentalAPI
@DeveloperAPI
def validate_config(self, config: TrainerConfigDict) -> None:
"""Validates a given config dict for this Trainer.
@ -2709,14 +2706,10 @@ class Trainer(Trainable):
if self.train_exec_impl is not None:
self.train_exec_impl.shared_metrics.get().restore(state["train_exec_impl"])
# TODO: Deprecate this method (`build_trainer` should no longer be used).
@staticmethod
def with_updates(**overrides) -> Type["Trainer"]:
raise NotImplementedError(
"`with_updates` may only be called on Trainer sub-classes "
"that were generated via the `ray.rllib.agents.trainer_template."
"build_trainer()` function (which has been deprecated)!"
)
@Deprecated(error=True)
def with_updates(*args, **kwargs):
pass
@DeveloperAPI
def _create_local_replay_buffer_if_necessary(

View file

@ -1,226 +1,10 @@
import logging
from typing import Callable, Iterable, List, Optional, Type, Union
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.typing import (
EnvCreator,
EnvType,
PartialTrainerConfigDict,
ResultDict,
TrainerConfigDict,
)
from ray.tune.logger import Logger
logger = logging.getLogger(__name__)
@Deprecated(
new="Sub-class from Trainer (or another Trainer sub-class) directly! "
"See e.g. ray.rllib.agents.dqn.dqn.py for an example.",
error=False,
error=True,
)
def build_trainer(
name: str,
*,
default_config: Optional[TrainerConfigDict] = None,
validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
default_policy: Optional[Type[Policy]] = None,
get_policy_class: Optional[
Callable[[TrainerConfigDict], Optional[Type[Policy]]]
] = None,
validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
before_init: Optional[Callable[[Trainer], None]] = None,
after_init: Optional[Callable[[Trainer], None]] = None,
before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
mixins: Optional[List[type]] = None,
execution_plan: Optional[
Union[
Callable[[WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
Callable[[Trainer, WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
]
] = None,
allow_unknown_configs: bool = False,
allow_unknown_subkeys: Optional[List[str]] = None,
override_all_subkeys_if_type_changes: Optional[List[str]] = None,
) -> Type[Trainer]:
"""Helper function for defining a custom Trainer class.
Functions will be run in this order to initialize the trainer:
1. Config setup: validate_config, get_policy.
2. Worker setup: before_init, execution_plan.
3. Post setup: after_init.
Args:
name: name of the trainer (e.g., "PPO")
default_config: The default config dict of the algorithm,
otherwise uses the Trainer default config.
validate_config: Optional callable that takes the config to check
for correctness. It may mutate the config as needed.
default_policy: The default Policy class to use if `get_policy_class`
returns None.
get_policy_class: Optional callable that takes a config and returns
the policy class or None. If None is returned, will use
`default_policy` (which must be provided then).
validate_env: Optional callable to validate the generated environment
(only on worker=0).
before_init: Optional callable to run before anything is constructed
inside Trainer (Workers with Policies, execution plan, etc..).
Takes the Trainer instance as argument.
after_init: Optional callable to run at the end of trainer init
(after all Workers and the exec. plan have been constructed).
Takes the Trainer instance as argument.
before_evaluate_fn: Callback to run before evaluation. This takes
the trainer instance as argument.
mixins: List of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class.
execution_plan: Optional callable that sets up the
distributed execution workflow.
allow_unknown_configs: Whether to allow unknown top-level config keys.
allow_unknown_subkeys: List of top-level keys
with value=dict, for which new sub-keys are allowed to be added to
the value dict. Appends to Trainer class defaults.
override_all_subkeys_if_type_changes: List of top level keys with
value=dict, for which we always override the entire value (dict),
iff the "type" key in that value dict changes. Appends to Trainer
class defaults.
Returns:
A Trainer sub-class configured by the specified args.
"""
original_kwargs = locals().copy()
base = add_mixins(Trainer, mixins)
class trainer_cls(base):
_name = name
_default_config = default_config or COMMON_CONFIG
_policy_class = default_policy
def __init__(
self,
config: TrainerConfigDict = None,
env: Union[str, EnvType, None] = None,
logger_creator: Callable[[], Logger] = None,
remote_checkpoint_dir: Optional[str] = None,
sync_function_tpl: Optional[str] = None,
):
Trainer.__init__(
self,
config,
env,
logger_creator,
remote_checkpoint_dir,
sync_function_tpl,
)
@override(base)
def setup(self, config: PartialTrainerConfigDict):
if allow_unknown_subkeys is not None:
self._allow_unknown_subkeys += allow_unknown_subkeys
self._allow_unknown_configs = allow_unknown_configs
if override_all_subkeys_if_type_changes is not None:
self._override_all_subkeys_if_type_changes += (
override_all_subkeys_if_type_changes
)
Trainer.setup(self, config)
def _init(self, config: TrainerConfigDict, env_creator: EnvCreator):
# No `get_policy_class` function.
if get_policy_class is None:
# Default_policy must be provided (unless in multi-agent mode,
# where each policy can have its own default policy class).
if not config["multiagent"]["policies"]:
assert default_policy is not None
# Query the function for a class to use.
else:
self._policy_class = get_policy_class(config)
# If None returned, use default policy (must be provided).
if self._policy_class is None:
assert default_policy is not None
self._policy_class = default_policy
if before_init:
before_init(self)
# Creating all workers (excluding evaluation workers).
self.workers = WorkerSet(
env_creator=env_creator,
validate_env=validate_env,
policy_class=self._policy_class,
trainer_config=config,
num_workers=self.config["num_workers"],
)
self.train_exec_impl = self.execution_plan(
self.workers, config, **self._kwargs_for_execution_plan()
)
if after_init:
after_init(self)
@override(Trainer)
def validate_config(self, config: PartialTrainerConfigDict):
# Call super's validation method.
Trainer.validate_config(self, config)
# Then call user defined one, if any.
if validate_config is not None:
validate_config(config)
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
# `execution_plan` is provided, use it inside
# `self.execution_plan()`.
if execution_plan is not None:
return execution_plan(workers, config, **kwargs)
# If `execution_plan` is not provided (None), the Trainer will use
# its already existing default `execution_plan()` static method
# instead.
else:
return Trainer.execution_plan(workers, config, **kwargs)
@override(Trainer)
def _before_evaluate(self):
if before_evaluate_fn:
before_evaluate_fn(self)
@staticmethod
@override(Trainer)
def with_updates(**overrides) -> Type[Trainer]:
"""Build a copy of this trainer class with the specified overrides.
Keyword Args:
overrides (dict): use this to override any of the arguments
originally passed to build_trainer() for this policy.
Returns:
Type[Trainer]: The Trainer sub-class using `original_kwargs`
and `overrides`.
Examples:
>>> from ray.rllib.agents.ppo import PPOTrainer
>>> MyPPOClass = PPOTrainer.with_updates({"name": "MyPPO"})
>>> issubclass(MyPPOClass, PPOTrainer)
False
>>> issubclass(MyPPOClass, Trainer)
True
>>> trainer = MyPPOClass()
>>> print(trainer)
MyPPO
"""
return build_trainer(**dict(original_kwargs, **overrides))
def __repr__(self):
return self._name
trainer_cls.__name__ = name
trainer_cls.__qualname__ = name
return trainer_cls
def build_trainer(*args, **kwargs):
pass # deprecated w/ error
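For migration, the pattern applied throughout this commit maps the old ``build_trainer()`` keyword arguments onto ``Trainer`` overrides roughly as sketched below; ``MY_CONFIG``, ``MyPolicy``, ``my_validate_config``, and ``my_execution_plan`` are placeholders:

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.utils.annotations import override

    class MyTrainer(Trainer):
        @classmethod
        @override(Trainer)
        def get_default_config(cls):
            # was: build_trainer(default_config=MY_CONFIG, ...)
            return MY_CONFIG

        @override(Trainer)
        def validate_config(self, config):
            # was: build_trainer(validate_config=my_validate_config, ...)
            super().validate_config(config)
            my_validate_config(config)

        @override(Trainer)
        def get_default_policy_class(self, config):
            # was: build_trainer(default_policy=MyPolicy, ...) or
            # build_trainer(get_policy_class=..., ...)
            return MyPolicy

        @staticmethod
        @override(Trainer)
        def execution_plan(workers, config, **kwargs):
            # was: build_trainer(execution_plan=my_execution_plan, ...)
            return my_execution_plan(workers, config, **kwargs)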

View file

@ -2,7 +2,7 @@ import argparse
import os
import ray
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.examples.policy.bare_metal_policy_with_custom_view_reqs import (
BareMetalPolicyWithCustomViewReqs,
)
@ -50,9 +50,9 @@ if __name__ == "__main__":
ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
# Create a custom Trainer class using our custom Policy.
BareMetalPolicyTrainer = build_trainer(
name="MyPolicy", default_policy=BareMetalPolicyWithCustomViewReqs
)
class BareMetalPolicyTrainer(Trainer):
def get_default_policy_class(self, config):
return BareMetalPolicyWithCustomViewReqs
config = {
"env": "CartPole-v0",

View file

@ -3,7 +3,7 @@ import os
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.postprocessing import discount_cumsum
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils.framework import try_import_tf
@ -35,11 +35,12 @@ MyTFPolicy = build_tf_policy(
postprocess_fn=calculate_advantages,
)
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
name="MyCustomTrainer",
default_policy=MyTFPolicy,
)
# Create a new Trainer using the Policy defined above.
class MyTrainer(Trainer):
def get_default_policy_class(self, config):
return MyTFPolicy
if __name__ == "__main__":
args = parser.parse_args()

View file

@ -3,7 +3,7 @@ import os
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
@ -24,11 +24,12 @@ MyTorchPolicy = build_policy_class(
name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss
)
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
name="MyCustomTrainer",
default_policy=MyTorchPolicy,
)
# Create a new Trainer using the Policy defined above.
class MyTrainer(Trainer):
def get_default_policy_class(self, config):
return MyTorchPolicy
if __name__ == "__main__":
args = parser.parse_args()

View file

@ -3,7 +3,7 @@ import os
import random
import ray
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.examples.models.eager_model import EagerModel
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch
@ -91,11 +91,12 @@ MyTFPolicy = build_tf_policy(
loss_fn=policy_gradient_loss,
)
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
name="MyCustomTrainer",
default_policy=MyTFPolicy,
)
# Create a new Trainer using the Policy defined above.
class MyTrainer(Trainer):
def get_default_policy_class(self, config):
return MyTFPolicy
if __name__ == "__main__":
ray.init()

View file

@ -6,7 +6,7 @@ import numpy as np
from ray.rllib import Policy
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts, SelectExperiences
@ -64,26 +64,31 @@ class RandomParametriclPolicy(Policy, ABC):
pass
# Create a new Trainer using the Policy and config defined above and a new
# execution plan.
class RandomParametricTrainer(Trainer):
@classmethod
def get_default_config(cls):
return DEFAULT_CONFIG
def get_default_policy_class(self, config):
return RandomParametriclPolicy
@staticmethod
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
rollouts = ParallelRollouts(workers, mode="async")
# Collect batches for the trainable policies.
rollouts = rollouts.for_each(SelectExperiences(local_worker=workers.local_worker()))
rollouts = rollouts.for_each(
SelectExperiences(local_worker=workers.local_worker())
)
# Return training metrics.
return StandardMetricsReporting(rollouts, workers, config)
RandomParametricTrainer = build_trainer(
name="RandomParametric",
default_config=DEFAULT_CONFIG,
default_policy=RandomParametriclPolicy,
execution_plan=execution_plan,
)
def main():
register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
trainer = RandomParametricTrainer(env="pa_cartpole")

View file

@ -10,7 +10,7 @@ import os
import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
@ -54,7 +54,10 @@ parser.add_argument(
)
def custom_training_workflow(workers: WorkerSet, config: dict):
# Define new Trainer with custom execution_plan/workflow.
class MyTrainer(Trainer):
@staticmethod
def execution_plan(workers: WorkerSet, config: dict, **kwargs):
local_replay_buffer = MultiAgentReplayBuffer(
num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
)
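As a rough sketch (operator import paths and keyword arguments are assumptions to verify against your RLlib version), a local replay buffer like the one above is typically wired into an execution plan as follows:

.. code-block:: python

    from ray.rllib.execution.concurrency_ops import Concurrently
    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
    from ray.rllib.execution.rollout_ops import ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep

    @staticmethod
    def execution_plan(workers, config, **kwargs):
        # Buffer constructed as in the example above (MultiAgentReplayBuffer).
        local_replay_buffer = MultiAgentReplayBuffer(
            num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
        )
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        # Branch 1: store freshly sampled experiences into the buffer.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer)
        )
        # Branch 2: replay from the buffer and take one training step.
        replay_op = Replay(local_buffer=local_replay_buffer).for_each(
            TrainOneStep(workers)
        )
        # Interleave both branches; only the replay branch emits train results.
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1]
        )
        return StandardMetricsReporting(train_op, workers, config)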
@ -167,12 +170,6 @@ if __name__ == "__main__":
else:
return "dqn_policy"
MyTrainer = build_trainer(
name="PPO_DQN_MultiAgent",
default_policy=None,
execution_plan=custom_training_workflow,
)
config = {
"rollout_fragment_length": 50,
"num_workers": 0,