[RLlib] trainer_template.py: hard deprecation (error when used). (#23488)
This commit is contained in:
parent: f78404da4a
commit: 7cb86acce2

11 changed files with 155 additions and 357 deletions
@@ -180,11 +180,12 @@ of a sequence of repeating steps, or *dataflow*, of:
 2. ``ConcatBatches``: The experiences are concatenated into one batch for training.
 3. ``TrainOneStep``: Take a gradient step with respect to the policy loss, and update the worker weights.

-In code, this dataflow can be expressed as the following execution plan, which is a simple function that can be passed to ``build_trainer`` to define a new algorithm.
+In code, this dataflow can be expressed as the following execution plan, which is a static method that can be overridden in your custom Trainer sub-classes to define new algorithms.
 It takes in a ``WorkerSet`` and config, and returns an iterator over training results:

 .. code-block:: python

+    @staticmethod
     def execution_plan(workers: WorkerSet, config: TrainerConfigDict):
         # type: LocalIterator[SampleBatchType]
         rollouts = ParallelRollouts(workers, mode="bulk_sync")
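For context, here is a minimal sketch of how that dataflow fits into a complete custom Trainer sub-class, assuming the RLlib execution-operator APIs at this commit. ``MyAlgoTrainer`` is an illustrative name, not something defined in this patch:

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.evaluation.worker_set import WorkerSet
    from ray.rllib.execution.metric_ops import StandardMetricsReporting
    from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts
    from ray.rllib.execution.train_ops import TrainOneStep
    from ray.rllib.utils.typing import TrainerConfigDict


    class MyAlgoTrainer(Trainer):
        @staticmethod
        def execution_plan(workers: WorkerSet, config: TrainerConfigDict, **kwargs):
            # 1. ParallelRollouts: collect experiences from all rollout workers.
            rollouts = ParallelRollouts(workers, mode="bulk_sync")

            # 2. ConcatBatches: concatenate them into one batch for training.
            train_batches = rollouts.combine(
                ConcatBatches(min_batch_size=config["train_batch_size"])
            )

            # 3. TrainOneStep: one gradient step, then broadcast the new weights.
            train_op = train_batches.for_each(TrainOneStep(workers))

            # Report standard training metrics as the result iterator.
            return StandardMetricsReporting(train_op, workers, config)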
@@ -152,12 +152,11 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env

     import ray
     from ray import tune
-    from ray.rllib.agents.trainer_template import build_trainer
+    from ray.rllib.agents.trainer import Trainer

-    # <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
-    MyTrainer = build_trainer(
-        name="MyCustomTrainer",
-        default_policy=MyTFPolicy)
+    class MyTrainer(Trainer):
+        def get_default_policy_class(self, config):
+            return MyTFPolicy

     ray.init()
     tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2})
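To make that snippet concrete end-to-end, here is a hedged, self-contained variant (assuming the APIs at this commit) that builds a simple torch policy-gradient policy, loosely following the ``custom_torch_policy.py`` example touched later in this commit, and trains it directly without Tune:

.. code-block:: python

    import ray
    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.policy.policy_template import build_policy_class
    from ray.rllib.policy.sample_batch import SampleBatch


    def policy_gradient_loss(policy, model, dist_class, train_batch):
        # Plain REINFORCE-style loss, as in RLlib's custom_torch_policy example.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -train_batch[SampleBatch.REWARDS].dot(log_probs)


    MyTorchPolicy = build_policy_class(
        name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss
    )


    class MyTrainer(Trainer):
        def get_default_policy_class(self, config):
            return MyTorchPolicy


    ray.init()
    trainer = MyTrainer(
        config={"env": "CartPole-v0", "num_workers": 2, "framework": "torch"}
    )
    print(trainer.train())  # one training iteration; returns a result dict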
@@ -209,20 +208,36 @@ You might be wondering how RLlib makes the advantages placeholder automatically

 **Example 1: Proximal Policy Optimization**

-In the above section you saw how to compose a simple policy gradient algorithm with RLlib. In this example, we'll dive into how PPO was built with RLlib and how you can modify it. First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__:
+In the above section you saw how to compose a simple policy gradient algorithm with RLlib.
+In this example, we'll dive into how PPO is defined within RLlib and how you can modify it.
+First, check out the `PPO trainer definition <https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo.py>`__:

 .. code-block:: python

-    PPOTrainer = build_trainer(
-        name="PPOTrainer",
-        default_config=DEFAULT_CONFIG,
-        default_policy=PPOTFPolicy,
-        validate_config=validate_config,
-        execution_plan=execution_plan)
+    class PPOTrainer(Trainer):
+        @classmethod
+        @override(Trainer)
+        def get_default_config(cls) -> TrainerConfigDict:
+            return DEFAULT_CONFIG
+
+        @override(Trainer)
+        def validate_config(self, config: TrainerConfigDict) -> None:
+            ...
+
+        @override(Trainer)
+        def get_default_policy_class(self, config):
+            return PPOTFPolicy
+
+        @staticmethod
+        @override(Trainer)
+        def execution_plan(workers, config, **kwargs):
+            ...

-Besides some boilerplate for defining the PPO configuration and some warnings, the most important argument to take note of is the ``execution_plan``.
+Besides some boilerplate for defining the PPO configuration and some warnings, the most important method to take note of is the ``execution_plan``.

-The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow. Depending on the ``simple_optimizer`` trainer config, PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):
+The trainer's `execution plan <#execution-plans>`__ defines the distributed training workflow.
+Depending on the ``simple_optimizer`` trainer config,
+PPO can switch between a simple synchronous plan, or a multi-GPU plan that implements minibatch SGD (the default):

 .. code-block:: python

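Following that pattern, customizing PPO no longer goes through ``build_trainer``: you sub-class ``PPOTrainer`` and override only the hook you need. A hedged sketch assuming this commit's module layout; ``MyPPOTrainer`` is an illustrative name:

.. code-block:: python

    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
    from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.utils.annotations import override


    class MyPPOTrainer(PPOTrainer):
        @override(Trainer)
        def get_default_policy_class(self, config):
            # Pick the policy per framework; the default config and the
            # execution plan are inherited from PPOTrainer unchanged.
            if config["framework"] == "torch":
                return PPOTorchPolicy
            return PPOTFPolicy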
@@ -204,7 +204,7 @@ class DDPPOTrainer(PPOTrainer):

         Returns:
             LocalIterator[dict]: The Policy class to use with PGTrainer.
-                If None, use `default_policy` provided in build_trainer().
+                If None, use `get_default_policy_class()` provided by Trainer.
         """
         assert (
             len(kwargs) == 0
@@ -831,7 +831,6 @@ class Trainer(Trainable):
             config, logger_creator, remote_checkpoint_dir, sync_function_tpl
         )

-    @ExperimentalAPI
     @classmethod
     def get_default_config(cls) -> TrainerConfigDict:
         return cls._default_config or COMMON_CONFIG
@@ -907,7 +906,7 @@ class Trainer(Trainable):
             # - Run the execution plan to create the local iterator to `next()`
             #   in each training iteration.
             # This matches the behavior of using `build_trainer()`, which
-            # should no longer be used.
+            # has been deprecated.
             self.workers = WorkerSet(
                 env_creator=self.env_creator,
                 validate_env=self.validate_env,
@@ -1034,7 +1033,6 @@ class Trainer(Trainable):
     def _init(self, config: TrainerConfigDict, env_creator: EnvCreator) -> None:
         raise NotImplementedError

-    @ExperimentalAPI
     def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
         """Returns a default Policy class to use, given a config.

@@ -1107,7 +1105,6 @@ class Trainer(Trainable):

         return result

-    @ExperimentalAPI
     def step_attempt(self) -> ResultDict:
         """Attempts a single training step, including evaluation, if required.

@@ -1389,7 +1386,7 @@ class Trainer(Trainable):
         # Also return the results here for convenience.
         return self.evaluation_metrics

-    @ExperimentalAPI
+    @DeveloperAPI
     def training_iteration(self) -> ResultDict:
         """Default single iteration logic of an algorithm.

@@ -2308,7 +2305,7 @@ class Trainer(Trainable):
         check_if_correct_nn_framework_installed()
         resolve_tf_settings()

-    @ExperimentalAPI
+    @DeveloperAPI
     def validate_config(self, config: TrainerConfigDict) -> None:
         """Validates a given config dict for this Trainer.

@@ -2709,14 +2706,10 @@ class Trainer(Trainable):
         if self.train_exec_impl is not None:
             self.train_exec_impl.shared_metrics.get().restore(state["train_exec_impl"])

-    # TODO: Deprecate this method (`build_trainer` should no longer be used).
     @staticmethod
-    def with_updates(**overrides) -> Type["Trainer"]:
-        raise NotImplementedError(
-            "`with_updates` may only be called on Trainer sub-classes "
-            "that were generated via the `ray.rllib.agents.trainer_template."
-            "build_trainer()` function (which has been deprecated)!"
-        )
+    @Deprecated(error=True)
+    def with_updates(*args, **kwargs):
+        pass

     @DeveloperAPI
     def _create_local_replay_buffer_if_necessary(
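Since ``with_updates()`` is now a hard-deprecated stub, calling it raises an error instead of returning a patched trainer class; the replacement is an ordinary sub-class. A hedged sketch with illustrative names and config values:

.. code-block:: python

    from ray.rllib.agents.ppo import PPOTrainer

    # Old style -- now raises a deprecation error instead of returning a class:
    #     MyPPO = PPOTrainer.with_updates(default_config={...})

    # New style: an ordinary sub-class (name and config tweak are illustrative).
    class MyPPO(PPOTrainer):
        @classmethod
        def get_default_config(cls):
            config = PPOTrainer.get_default_config().copy()
            config["train_batch_size"] = 8000
            return config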
@@ -1,226 +1,10 @@
-import logging
-from typing import Callable, Iterable, List, Optional, Type, Union
-
-from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
-from ray.rllib.env.env_context import EnvContext
-from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.policy import Policy
-from ray.rllib.utils import add_mixins
-from ray.rllib.utils.annotations import override
 from ray.rllib.utils.deprecation import Deprecated
-from ray.rllib.utils.typing import (
-    EnvCreator,
-    EnvType,
-    PartialTrainerConfigDict,
-    ResultDict,
-    TrainerConfigDict,
-)
-from ray.tune.logger import Logger
-
-logger = logging.getLogger(__name__)


 @Deprecated(
     new="Sub-class from Trainer (or another Trainer sub-class) directly! "
     "See e.g. ray.rllib.agents.dqn.dqn.py for an example.",
-    error=False,
+    error=True,
 )
-def build_trainer(
-    name: str,
-    *,
-    default_config: Optional[TrainerConfigDict] = None,
-    validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
-    default_policy: Optional[Type[Policy]] = None,
-    get_policy_class: Optional[
-        Callable[[TrainerConfigDict], Optional[Type[Policy]]]
-    ] = None,
-    validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
-    before_init: Optional[Callable[[Trainer], None]] = None,
-    after_init: Optional[Callable[[Trainer], None]] = None,
-    before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
-    mixins: Optional[List[type]] = None,
-    execution_plan: Optional[
-        Union[
-            Callable[[WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
-            Callable[[Trainer, WorkerSet, TrainerConfigDict], Iterable[ResultDict]],
-        ]
-    ] = None,
-    allow_unknown_configs: bool = False,
-    allow_unknown_subkeys: Optional[List[str]] = None,
-    override_all_subkeys_if_type_changes: Optional[List[str]] = None,
-) -> Type[Trainer]:
-    """Helper function for defining a custom Trainer class.
-
-    Functions will be run in this order to initialize the trainer:
-        1. Config setup: validate_config, get_policy.
-        2. Worker setup: before_init, execution_plan.
-        3. Post setup: after_init.
-
-    Args:
-        name: name of the trainer (e.g., "PPO")
-        default_config: The default config dict of the algorithm,
-            otherwise uses the Trainer default config.
-        validate_config: Optional callable that takes the config to check
-            for correctness. It may mutate the config as needed.
-        default_policy: The default Policy class to use if `get_policy_class`
-            returns None.
-        get_policy_class: Optional callable that takes a config and returns
-            the policy class or None. If None is returned, will use
-            `default_policy` (which must be provided then).
-        validate_env: Optional callable to validate the generated environment
-            (only on worker=0).
-        before_init: Optional callable to run before anything is constructed
-            inside Trainer (Workers with Policies, execution plan, etc..).
-            Takes the Trainer instance as argument.
-        after_init: Optional callable to run at the end of trainer init
-            (after all Workers and the exec. plan have been constructed).
-            Takes the Trainer instance as argument.
-        before_evaluate_fn: Callback to run before evaluation. This takes
-            the trainer instance as argument.
-        mixins: List of any class mixins for the returned trainer class.
-            These mixins will be applied in order and will have higher
-            precedence than the Trainer class.
-        execution_plan: Optional callable that sets up the
-            distributed execution workflow.
-        allow_unknown_configs: Whether to allow unknown top-level config keys.
-        allow_unknown_subkeys: List of top-level keys
-            with value=dict, for which new sub-keys are allowed to be added to
-            the value dict. Appends to Trainer class defaults.
-        override_all_subkeys_if_type_changes: List of top level keys with
-            value=dict, for which we always override the entire value (dict),
-            iff the "type" key in that value dict changes. Appends to Trainer
-            class defaults.
-
-    Returns:
-        A Trainer sub-class configured by the specified args.
-    """
-
-    original_kwargs = locals().copy()
-    base = add_mixins(Trainer, mixins)
-
-    class trainer_cls(base):
-        _name = name
-        _default_config = default_config or COMMON_CONFIG
-        _policy_class = default_policy
-
-        def __init__(
-            self,
-            config: TrainerConfigDict = None,
-            env: Union[str, EnvType, None] = None,
-            logger_creator: Callable[[], Logger] = None,
-            remote_checkpoint_dir: Optional[str] = None,
-            sync_function_tpl: Optional[str] = None,
-        ):
-            Trainer.__init__(
-                self,
-                config,
-                env,
-                logger_creator,
-                remote_checkpoint_dir,
-                sync_function_tpl,
-            )
-
-        @override(base)
-        def setup(self, config: PartialTrainerConfigDict):
-            if allow_unknown_subkeys is not None:
-                self._allow_unknown_subkeys += allow_unknown_subkeys
-            self._allow_unknown_configs = allow_unknown_configs
-            if override_all_subkeys_if_type_changes is not None:
-                self._override_all_subkeys_if_type_changes += (
-                    override_all_subkeys_if_type_changes
-                )
-            Trainer.setup(self, config)
-
-        def _init(self, config: TrainerConfigDict, env_creator: EnvCreator):
-
-            # No `get_policy_class` function.
-            if get_policy_class is None:
-                # Default_policy must be provided (unless in multi-agent mode,
-                # where each policy can have its own default policy class).
-                if not config["multiagent"]["policies"]:
-                    assert default_policy is not None
-            # Query the function for a class to use.
-            else:
-                self._policy_class = get_policy_class(config)
-                # If None returned, use default policy (must be provided).
-                if self._policy_class is None:
-                    assert default_policy is not None
-                    self._policy_class = default_policy
-
-            if before_init:
-                before_init(self)
-
-            # Creating all workers (excluding evaluation workers).
-            self.workers = WorkerSet(
-                env_creator=env_creator,
-                validate_env=validate_env,
-                policy_class=self._policy_class,
-                trainer_config=config,
-                num_workers=self.config["num_workers"],
-            )
-
-            self.train_exec_impl = self.execution_plan(
-                self.workers, config, **self._kwargs_for_execution_plan()
-            )
-
-            if after_init:
-                after_init(self)
-
-        @override(Trainer)
-        def validate_config(self, config: PartialTrainerConfigDict):
-            # Call super's validation method.
-            Trainer.validate_config(self, config)
-            # Then call user defined one, if any.
-            if validate_config is not None:
-                validate_config(config)
-
-        @staticmethod
-        @override(Trainer)
-        def execution_plan(workers, config, **kwargs):
-            # `execution_plan` is provided, use it inside
-            # `self.execution_plan()`.
-            if execution_plan is not None:
-                return execution_plan(workers, config, **kwargs)
-            # If `execution_plan` is not provided (None), the Trainer will use
-            # it's already existing default `execution_plan()` static method
-            # instead.
-            else:
-                return Trainer.execution_plan(workers, config, **kwargs)
-
-        @override(Trainer)
-        def _before_evaluate(self):
-            if before_evaluate_fn:
-                before_evaluate_fn(self)
-
-        @staticmethod
-        @override(Trainer)
-        def with_updates(**overrides) -> Type[Trainer]:
-            """Build a copy of this trainer class with the specified overrides.
-
-            Keyword Args:
-                overrides (dict): use this to override any of the arguments
-                    originally passed to build_trainer() for this policy.
-
-            Returns:
-                Type[Trainer]: A the Trainer sub-class using `original_kwargs`
-                    and `overrides`.
-
-            Examples:
-                >>> from ray.rllib.agents.ppo import PPOTrainer
-                >>> MyPPOClass = PPOTrainer.with_updates({"name": "MyPPO"})
-                >>> issubclass(MyPPOClass, PPOTrainer)
-                False
-                >>> issubclass(MyPPOClass, Trainer)
-                True
-                >>> trainer = MyPPOClass()
-                >>> print(trainer)
-                MyPPO
-            """
-            return build_trainer(**dict(original_kwargs, **overrides))
-
-        def __repr__(self):
-            return self._name
-
-    trainer_cls.__name__ = name
-    trainer_cls.__qualname__ = name
-    return trainer_cls
+def build_trainer(*args, **kwargs):
+    pass  # deprecated w/ error
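For migration reference, a hedged sketch of how the removed ``build_trainer()`` arguments map onto ``Trainer`` overrides; the hook names come from the ``Trainer`` class touched above, while ``MyAlgoTrainer`` and the ``...`` bodies are placeholders:

.. code-block:: python

    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.utils.annotations import override


    class MyAlgoTrainer(Trainer):  # replaces: build_trainer(name="MyAlgo", ...)
        @classmethod
        @override(Trainer)
        def get_default_config(cls):
            ...  # replaces: default_config=...

        @override(Trainer)
        def validate_config(self, config):
            super().validate_config(config)  # Trainer's own checks still run.
            ...  # replaces: validate_config=...

        @override(Trainer)
        def get_default_policy_class(self, config):
            ...  # replaces: default_policy=... / get_policy_class=...

        @staticmethod
        @override(Trainer)
        def execution_plan(workers, config, **kwargs):
            ...  # replaces: execution_plan=...

    # Hooks such as before_init/after_init have no one-to-one replacement;
    # overriding `setup()` (and calling `super().setup(config)`) is the
    # closest equivalent under the new API.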
@@ -2,7 +2,7 @@ import argparse
 import os

 import ray
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.examples.policy.bare_metal_policy_with_custom_view_reqs import (
     BareMetalPolicyWithCustomViewReqs,
 )
@@ -50,9 +50,9 @@ if __name__ == "__main__":
     ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

     # Create q custom Trainer class using our custom Policy.
-    BareMetalPolicyTrainer = build_trainer(
-        name="MyPolicy", default_policy=BareMetalPolicyWithCustomViewReqs
-    )
+    class BareMetalPolicyTrainer(Trainer):
+        def get_default_policy_class(self, config):
+            return BareMetalPolicyWithCustomViewReqs

     config = {
         "env": "CartPole-v0",
@@ -3,7 +3,7 @@ import os

 import ray
 from ray import tune
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.evaluation.postprocessing import discount_cumsum
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils.framework import try_import_tf
@@ -35,11 +35,12 @@ MyTFPolicy = build_tf_policy(
     postprocess_fn=calculate_advantages,
 )

-# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
-MyTrainer = build_trainer(
-    name="MyCustomTrainer",
-    default_policy=MyTFPolicy,
-)
+# Create a new Trainer using the Policy defined above.
+class MyTrainer(Trainer):
+    def get_default_policy_class(self, config):
+        return MyTFPolicy


 if __name__ == "__main__":
     args = parser.parse_args()
@@ -3,7 +3,7 @@ import os

 import ray
 from ray import tune
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch

@@ -24,11 +24,12 @@ MyTorchPolicy = build_policy_class(
     name="MyTorchPolicy", framework="torch", loss_fn=policy_gradient_loss
 )

-# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
-MyTrainer = build_trainer(
-    name="MyCustomTrainer",
-    default_policy=MyTorchPolicy,
-)
+# Create a new Trainer using the Policy defined above.
+class MyTrainer(Trainer):
+    def get_default_policy_class(self, config):
+        return MyTorchPolicy


 if __name__ == "__main__":
     args = parser.parse_args()
@@ -3,7 +3,7 @@ import os
 import random

 import ray
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.examples.models.eager_model import EagerModel
 from ray.rllib.models import ModelCatalog
 from ray.rllib.policy.sample_batch import SampleBatch
@@ -91,11 +91,12 @@ MyTFPolicy = build_tf_policy(
     loss_fn=policy_gradient_loss,
 )

-# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
-MyTrainer = build_trainer(
-    name="MyCustomTrainer",
-    default_policy=MyTFPolicy,
-)
+# Create a new Trainer using the Policy defined above.
+class MyTrainer(Trainer):
+    def get_default_policy_class(self, config):
+        return MyTFPolicy


 if __name__ == "__main__":
     ray.init()
@@ -6,7 +6,7 @@ import numpy as np

 from ray.rllib import Policy
 from ray.rllib.agents import with_common_config
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.rollout_ops import ParallelRollouts, SelectExperiences
@@ -64,24 +64,29 @@ class RandomParametriclPolicy(Policy, ABC):
         pass


-def execution_plan(
-    workers: WorkerSet, config: TrainerConfigDict, **kwargs
-) -> LocalIterator[dict]:
-    rollouts = ParallelRollouts(workers, mode="async")
-
-    # Collect batches for the trainable policies.
-    rollouts = rollouts.for_each(SelectExperiences(local_worker=workers.local_worker()))
-
-    # Return training metrics.
-    return StandardMetricsReporting(rollouts, workers, config)
-
-
-RandomParametricTrainer = build_trainer(
-    name="RandomParametric",
-    default_config=DEFAULT_CONFIG,
-    default_policy=RandomParametriclPolicy,
-    execution_plan=execution_plan,
-)
+# Create a new Trainer using the Policy and config defined above and a new
+# execution plan.
+class RandomParametricTrainer(Trainer):
+    @classmethod
+    def get_default_config(cls):
+        return DEFAULT_CONFIG
+
+    def get_default_policy_class(self, config):
+        return RandomParametriclPolicy
+
+    @staticmethod
+    def execution_plan(
+        workers: WorkerSet, config: TrainerConfigDict, **kwargs
+    ) -> LocalIterator[dict]:
+        rollouts = ParallelRollouts(workers, mode="async")
+
+        # Collect batches for the trainable policies.
+        rollouts = rollouts.for_each(
+            SelectExperiences(local_worker=workers.local_worker())
+        )
+
+        # Return training metrics.
+        return StandardMetricsReporting(rollouts, workers, config)


 def main():
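As a hedged usage sketch (assuming this commit's APIs), the converted trainer is then driven like any other ``Trainer`` sub-class; ``"pa_cartpole"`` below stands in for whatever parametric-actions env this example script registers:

.. code-block:: python

    import ray

    ray.init()
    # "pa_cartpole" is a placeholder for the parametric-actions env registered
    # elsewhere in this example script.
    trainer = RandomParametricTrainer(config={"env": "pa_cartpole", "num_workers": 2})
    print(trainer.train())  # one pass through the custom execution_plan above
    ray.shutdown()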
@@ -10,7 +10,7 @@ import os

 import ray
 from ray import tune
-from ray.rllib.agents.trainer_template import build_trainer
+from ray.rllib.agents.trainer import Trainer
 from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
 from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def custom_training_workflow(workers: WorkerSet, config: dict):
|
# Define new Trainer with custom execution_plan/workflow.
|
||||||
local_replay_buffer = MultiAgentReplayBuffer(
|
class MyTrainer(Trainer):
|
||||||
num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
|
@staticmethod
|
||||||
)
|
def execution_plan(workers: WorkerSet, config: dict, **kwargs):
|
||||||
|
local_replay_buffer = MultiAgentReplayBuffer(
|
||||||
def add_ppo_metrics(batch):
|
num_shards=1, learning_starts=1000, capacity=50000, replay_batch_size=64
|
||||||
print(
|
|
||||||
"PPO policy learning on samples from",
|
|
||||||
batch.policy_batches.keys(),
|
|
||||||
"env steps",
|
|
||||||
batch.env_steps(),
|
|
||||||
"agent steps",
|
|
||||||
batch.env_steps(),
|
|
||||||
)
|
)
|
||||||
metrics = _get_shared_metrics()
|
|
||||||
metrics.counters["agent_steps_trained_PPO"] += batch.env_steps()
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def add_dqn_metrics(batch):
|
def add_ppo_metrics(batch):
|
||||||
print(
|
print(
|
||||||
"DQN policy learning on samples from",
|
"PPO policy learning on samples from",
|
||||||
batch.policy_batches.keys(),
|
batch.policy_batches.keys(),
|
||||||
"env steps",
|
"env steps",
|
||||||
batch.env_steps(),
|
batch.env_steps(),
|
||||||
"agent steps",
|
"agent steps",
|
||||||
batch.env_steps(),
|
batch.env_steps(),
|
||||||
|
)
|
||||||
|
metrics = _get_shared_metrics()
|
||||||
|
metrics.counters["agent_steps_trained_PPO"] += batch.env_steps()
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def add_dqn_metrics(batch):
|
||||||
|
print(
|
||||||
|
"DQN policy learning on samples from",
|
||||||
|
batch.policy_batches.keys(),
|
||||||
|
"env steps",
|
||||||
|
batch.env_steps(),
|
||||||
|
"agent steps",
|
||||||
|
batch.env_steps(),
|
||||||
|
)
|
||||||
|
metrics = _get_shared_metrics()
|
||||||
|
metrics.counters["agent_steps_trained_DQN"] += batch.env_steps()
|
||||||
|
return batch
|
||||||
|
|
||||||
|
# Generate common experiences.
|
||||||
|
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
||||||
|
r1, r2 = rollouts.duplicate(n=2)
|
||||||
|
|
||||||
|
# DQN sub-flow.
|
||||||
|
dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])).for_each(
|
||||||
|
StoreToReplayBuffer(local_buffer=local_replay_buffer)
|
||||||
)
|
)
|
||||||
metrics = _get_shared_metrics()
|
dqn_replay_op = (
|
||||||
metrics.counters["agent_steps_trained_DQN"] += batch.env_steps()
|
Replay(local_buffer=local_replay_buffer)
|
||||||
return batch
|
.for_each(add_dqn_metrics)
|
||||||
|
.for_each(TrainOneStep(workers, policies=["dqn_policy"]))
|
||||||
# Generate common experiences.
|
.for_each(
|
||||||
rollouts = ParallelRollouts(workers, mode="bulk_sync")
|
UpdateTargetNetwork(
|
||||||
r1, r2 = rollouts.duplicate(n=2)
|
workers, target_update_freq=500, policies=["dqn_policy"]
|
||||||
|
)
|
||||||
# DQN sub-flow.
|
|
||||||
dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])).for_each(
|
|
||||||
StoreToReplayBuffer(local_buffer=local_replay_buffer)
|
|
||||||
)
|
|
||||||
dqn_replay_op = (
|
|
||||||
Replay(local_buffer=local_replay_buffer)
|
|
||||||
.for_each(add_dqn_metrics)
|
|
||||||
.for_each(TrainOneStep(workers, policies=["dqn_policy"]))
|
|
||||||
.for_each(
|
|
||||||
UpdateTargetNetwork(
|
|
||||||
workers, target_update_freq=500, policies=["dqn_policy"]
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
dqn_train_op = Concurrently(
|
||||||
dqn_train_op = Concurrently(
|
[dqn_store_op, dqn_replay_op], mode="round_robin", output_indexes=[1]
|
||||||
[dqn_store_op, dqn_replay_op], mode="round_robin", output_indexes=[1]
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# PPO sub-flow.
|
# PPO sub-flow.
|
||||||
ppo_train_op = (
|
ppo_train_op = (
|
||||||
r2.for_each(SelectExperiences(["ppo_policy"]))
|
r2.for_each(SelectExperiences(["ppo_policy"]))
|
||||||
.combine(ConcatBatches(min_batch_size=200, count_steps_by="env_steps"))
|
.combine(ConcatBatches(min_batch_size=200, count_steps_by="env_steps"))
|
||||||
.for_each(add_ppo_metrics)
|
.for_each(add_ppo_metrics)
|
||||||
.for_each(StandardizeFields(["advantages"]))
|
.for_each(StandardizeFields(["advantages"]))
|
||||||
.for_each(
|
.for_each(
|
||||||
TrainOneStep(
|
TrainOneStep(
|
||||||
workers,
|
workers,
|
||||||
policies=["ppo_policy"],
|
policies=["ppo_policy"],
|
||||||
num_sgd_iter=10,
|
num_sgd_iter=10,
|
||||||
sgd_minibatch_size=128,
|
sgd_minibatch_size=128,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Combined training flow
|
# Combined training flow
|
||||||
train_op = Concurrently(
|
train_op = Concurrently(
|
||||||
[ppo_train_op, dqn_train_op], mode="async", output_indexes=[1]
|
[ppo_train_op, dqn_train_op], mode="async", output_indexes=[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
return StandardMetricsReporting(train_op, workers, config)
|
return StandardMetricsReporting(train_op, workers, config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@@ -167,12 +170,6 @@ if __name__ == "__main__":
         else:
             return "dqn_policy"

-    MyTrainer = build_trainer(
-        name="PPO_DQN_MultiAgent",
-        default_policy=None,
-        execution_plan=custom_training_workflow,
-    )
-
     config = {
         "rollout_fragment_length": 50,
         "num_workers": 0,
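As a final hedged sketch, launching the converted ``MyTrainer`` through Tune stays the same as with the old ``build_trainer`` product; ``config`` is the dict assembled in this example script and the stopping criterion is illustrative:

.. code-block:: python

    from ray import tune

    # `MyTrainer` now carries the two-policy execution plan itself, so the Tune
    # launch no longer references build_trainer or a standalone workflow function.
    tune.run(MyTrainer, config=config, stop={"training_iteration": 10})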