2020-03-02 15:16:37 -08:00
|
|
|
import logging
|
2021-08-31 12:21:49 +02:00
|
|
|
from typing import Callable, Iterable, List, Optional, Type, Union
|
2019-06-03 06:49:24 +08:00
|
|
|
|
2019-05-27 14:17:32 -07:00
|
|
|
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
2020-10-06 20:28:16 +02:00
|
|
|
from ray.rllib.env.env_context import EnvContext
|
2020-06-19 13:09:05 -07:00
|
|
|
from ray.rllib.evaluation.worker_set import WorkerSet
|
|
|
|
from ray.rllib.policy import Policy
|
2019-06-07 16:45:36 -07:00
|
|
|
from ray.rllib.utils import add_mixins
|
2019-05-18 00:23:11 -07:00
|
|
|
from ray.rllib.utils.annotations import override, DeveloperAPI
|
2021-06-19 22:42:00 +02:00
|
|
|
from ray.rllib.utils.typing import EnvConfigDict, EnvType, \
|
|
|
|
PartialTrainerConfigDict, ResultDict, TrainerConfigDict
|
2021-11-16 20:52:42 +00:00
|
|
|
from ray.tune.logger import Logger
|
2019-05-18 00:23:11 -07:00
|
|
|
|
2020-03-02 15:16:37 -08:00
|
|
|
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
|
|
|
|
|
2019-05-18 00:23:11 -07:00
|
|
|
|
2021-11-16 11:26:47 +00:00
|
|
|
# TODO: Deprecate the Trainer template generated by this utility function.
#  Instead, users should sub-class Trainer directly and override some of its
#  methods, e.g. `Trainer.setup()`.
@DeveloperAPI
def build_trainer(
        name: str,
        *,
        default_config: Optional[TrainerConfigDict] = None,
        validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
        default_policy: Optional[Type[Policy]] = None,
        get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
            Policy]]]] = None,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        before_init: Optional[Callable[[Trainer], None]] = None,
        after_init: Optional[Callable[[Trainer], None]] = None,
        before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
        mixins: Optional[List[type]] = None,
        execution_plan: Optional[Union[Callable[
            [WorkerSet, TrainerConfigDict], Iterable[ResultDict]], Callable[[
                Trainer, WorkerSet, TrainerConfigDict
            ], Iterable[ResultDict]]]] = None,
        allow_unknown_configs: bool = False,
        allow_unknown_subkeys: Optional[List[str]] = None,
        override_all_subkeys_if_type_changes: Optional[List[str]] = None,
) -> Type[Trainer]:
    """Helper function for defining a custom Trainer class.

    Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy_class.
        2. Worker setup: before_init, execution_plan.
        3. Post setup: after_init.

    Args:
        name: Name of the trainer (e.g., "PPO"). Also used as the generated
            class' `__name__`, `__qualname__` and `repr()`.
        default_config: The default config dict of the algorithm,
            otherwise uses the Trainer default config (`COMMON_CONFIG`).
        validate_config: Optional callable that takes the config to check
            for correctness. It may mutate the config as needed. Runs after
            the base Trainer's own config validation.
        default_policy: The default Policy class to use if `get_policy_class`
            is not given or returns None.
        get_policy_class: Optional callable that takes a config and returns
            the policy class or None. If None is returned, will use
            `default_policy` (which must be provided then).
        validate_env: Optional callable to validate the generated environment
            (only on worker=0).
        before_init: Optional callable to run before anything is constructed
            inside Trainer (Workers with Policies, execution plan, etc..).
            Takes the Trainer instance as argument.
        after_init: Optional callable to run at the end of trainer init
            (after all Workers and the exec. plan have been constructed).
            Takes the Trainer instance as argument.
        before_evaluate_fn: Callback to run before evaluation. This takes
            the trainer instance as argument.
        mixins: List of any class mixins for the returned trainer class.
            These mixins will be applied in order and will have higher
            precedence than the Trainer class.
        execution_plan: Optional callable that sets up the
            distributed execution workflow. It is called as
            `execution_plan(workers, config, **kwargs)`. If None, the
            Trainer's own `execution_plan()` is used.
        allow_unknown_configs: Whether to allow unknown top-level config keys.
        allow_unknown_subkeys: List of top-level keys
            with value=dict, for which new sub-keys are allowed to be added to
            the value dict. Appends to Trainer class defaults.
        override_all_subkeys_if_type_changes: List of top level keys with
            value=dict, for which we always override the entire value (dict),
            iff the "type" key in that value dict changes. Appends to Trainer
            class defaults.

    Returns:
        A Trainer sub-class configured by the specified args.
    """
    # Snapshot of every argument passed to this call; `with_updates()` uses
    # it to rebuild the class. This must remain the first statement so that
    # `locals()` contains only the function's arguments.
    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        # The generated Trainer sub-class; all customization comes from the
        # enclosing `build_trainer()` arguments via closure.
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self,
                     config: Optional[TrainerConfigDict] = None,
                     env: Union[str, EnvType, None] = None,
                     logger_creator: Optional[Callable[[], Logger]] = None,
                     remote_checkpoint_dir: Optional[str] = None,
                     sync_function_tpl: Optional[str] = None):
            """Forwards all arguments unchanged to `Trainer.__init__`."""
            Trainer.__init__(self, config, env, logger_creator,
                             remote_checkpoint_dir, sync_function_tpl)

        @override(base)
        def setup(self, config: PartialTrainerConfigDict):
            # Build *new* lists and store them on the instance. The defaults
            # read here are the Trainer class defaults; if those are list
            # attributes shared at class level, an augmented assignment
            # (`+=`) would extend the shared list in place. That list would
            # then grow on every instantiation and leak these extra keys
            # into every other Trainer class.
            if allow_unknown_subkeys is not None:
                self._allow_unknown_subkeys = \
                    list(self._allow_unknown_subkeys) + \
                    list(allow_unknown_subkeys)
            self._allow_unknown_configs = allow_unknown_configs
            if override_all_subkeys_if_type_changes is not None:
                self._override_all_subkeys_if_type_changes = \
                    list(self._override_all_subkeys_if_type_changes) + \
                    list(override_all_subkeys_if_type_changes)
            super().setup(config)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # Resolve the policy class to use.
            if get_policy_class is None:
                # No `get_policy_class` function: `default_policy` must be
                # provided (unless in multi-agent mode, where each policy
                # can have its own default policy class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None, \
                        "`default_policy` must be provided when neither " \
                        "`get_policy_class` nor multiagent policies are set!"
            else:
                # Query the function for a class to use.
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None, \
                        "`get_policy_class` returned None and no " \
                        "`default_policy` was provided!"
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])

            # If execution plan is not provided (None), the Trainer will use
            # its already existing default `execution_plan()` static method
            # instead.
            if execution_plan is not None:
                self.execution_plan = execution_plan
            self.train_exec_impl = self.execution_plan(
                self.workers, config, **self._kwargs_for_execution_plan())

            if after_init:
                after_init(self)

        @staticmethod
        @override(Trainer)
        def _validate_config(config: PartialTrainerConfigDict,
                             trainer_obj_or_none: Optional["Trainer"] = None):
            # Call super (Trainer) validation method first.
            Trainer._validate_config(config, trainer_obj_or_none)
            # Then call user defined one, if any (it may mutate `config`).
            if validate_config is not None:
                validate_config(config)

        @override(Trainer)
        def _before_evaluate(self):
            # Delegate to the user-provided hook, if any.
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.

            Returns:
                Type[Trainer]: A Trainer sub-class built from
                    `original_kwargs` updated with `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
|