# ray/rllib/agents/trainer_template.py


import concurrent.futures
from functools import partial
import logging
from typing import Callable, Iterable, List, Optional, Type, Union
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import EnvConfigDict, EnvType, \
PartialTrainerConfigDict, ResultDict, TrainerConfigDict
logger = logging.getLogger(__name__)
def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict,
**kwargs):
assert len(kwargs) == 0, (
"Default execution_plan does NOT take any additional parameters")
# Collects experiences in parallel from multiple RolloutWorker actors.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Combine experience batches until we hit `train_batch_size` in size.
# Then, train the policy on those experiences and update the workers.
train_op = rollouts.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
))
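    # Pick the training op: the simple, single-device `TrainOneStep` when
    # `simple_optimizer` is explicitly enabled; otherwise the (multi-)GPU
    # capable `MultiGPUTrainOneStep`.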
if config.get("simple_optimizer") is True:
train_op = train_op.for_each(TrainOneStep(workers))
else:
train_op = train_op.for_each(
MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config.get("sgd_minibatch_size",
config["train_batch_size"]),
num_sgd_iter=config.get("num_sgd_iter", 1),
num_gpus=config["num_gpus"],
shuffle_sequences=config.get("shuffle_sequences", False),
_fake_gpus=config["_fake_gpus"],
framework=config["framework"]))
# Add on the standard episode reward, etc. metrics reporting. This returns
# a LocalIterator[metrics_dict] representing metrics for each train step.
return StandardMetricsReporting(train_op, workers, config)
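# Note: A custom `execution_plan` passed to `build_trainer()` only needs to
# match the signature of `default_execution_plan()` above. A minimal,
# illustrative sketch (not used anywhere in this module), built from the same
# ops imported at the top of this file:
#
#     def my_execution_plan(workers, config, **kwargs):
#         rollouts = ParallelRollouts(workers, mode="bulk_sync")
#         train_op = rollouts.for_each(TrainOneStep(workers))
#         return StandardMetricsReporting(train_op, workers, config)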
@DeveloperAPI
def build_trainer(
name: str,
*,
default_config: Optional[TrainerConfigDict] = None,
validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
default_policy: Optional[Type[Policy]] = None,
get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
Policy]]]] = None,
validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
before_init: Optional[Callable[[Trainer], None]] = None,
after_init: Optional[Callable[[Trainer], None]] = None,
before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
mixins: Optional[List[type]] = None,
execution_plan: Optional[Union[Callable[
[WorkerSet, TrainerConfigDict], Iterable[ResultDict]], Callable[[
Trainer, WorkerSet, TrainerConfigDict
], Iterable[ResultDict]]]] = default_execution_plan,
allow_unknown_configs: bool = False,
allow_unknown_subkeys: Optional[List[str]] = None,
override_all_subkeys_if_type_changes: Optional[List[str]] = None,
) -> Type[Trainer]:
"""Helper function for defining a custom Trainer class.
Functions will be run in this order to initialize the trainer:
        1. Config setup: validate_config, get_policy_class.
2. Worker setup: before_init, execution_plan.
3. Post setup: after_init.
Args:
        name: Name of the trainer (e.g., "PPO").
default_config: The default config dict of the algorithm,
otherwise uses the Trainer default config.
validate_config: Optional callable that takes the config to check
for correctness. It may mutate the config as needed.
default_policy: The default Policy class to use if `get_policy_class`
returns None.
get_policy_class: Optional callable that takes a config and returns
the policy class or None. If None is returned, will use
`default_policy` (which must be provided then).
validate_env: Optional callable to validate the generated environment
(only on worker=0).
before_init: Optional callable to run before anything is constructed
            inside Trainer (Workers with Policies, execution plan, etc.).
Takes the Trainer instance as argument.
after_init: Optional callable to run at the end of trainer init
(after all Workers and the exec. plan have been constructed).
Takes the Trainer instance as argument.
before_evaluate_fn: Callback to run before evaluation. This takes
the trainer instance as argument.
mixins: List of any class mixins for the returned trainer class.
These mixins will be applied in order and will have higher
precedence than the Trainer class.
execution_plan: Optional callable that sets up the
distributed execution workflow.
allow_unknown_configs: Whether to allow unknown top-level config keys.
allow_unknown_subkeys: List of top-level keys
with value=dict, for which new sub-keys are allowed to be added to
the value dict. Appends to Trainer class defaults.
override_all_subkeys_if_type_changes: List of top level keys with
value=dict, for which we always override the entire value (dict),
iff the "type" key in that value dict changes. Appends to Trainer
class defaults.
Returns:
A Trainer sub-class configured by the specified args.
"""
original_kwargs = locals().copy()
base = add_mixins(Trainer, mixins)
class trainer_cls(base):
_name = name
_default_config = default_config or COMMON_CONFIG
_policy_class = default_policy
def __init__(self, config=None, env=None, logger_creator=None):
Trainer.__init__(self, config, env, logger_creator)
@override(base)
def setup(self, config: PartialTrainerConfigDict):
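            # Extend the base Trainer's config-checking whitelists before
            # `Trainer.setup()` validates the incoming config.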
if allow_unknown_subkeys is not None:
self._allow_unknown_subkeys += allow_unknown_subkeys
self._allow_unknown_configs = allow_unknown_configs
if override_all_subkeys_if_type_changes is not None:
self._override_all_subkeys_if_type_changes += \
override_all_subkeys_if_type_changes
super().setup(config)
def _init(self, config: TrainerConfigDict,
env_creator: Callable[[EnvConfigDict], EnvType]):
# No `get_policy_class` function.
if get_policy_class is None:
                # `default_policy` must be provided (unless in multi-agent
                # mode, where each policy can have its own policy class).
if not config["multiagent"]["policies"]:
assert default_policy is not None
self._policy_class = default_policy
# Query the function for a class to use.
else:
self._policy_class = get_policy_class(config)
# If None returned, use default policy (must be provided).
if self._policy_class is None:
assert default_policy is not None
self._policy_class = default_policy
if before_init:
before_init(self)
# Creating all workers (excluding evaluation workers).
self.workers = self._make_workers(
env_creator=env_creator,
validate_env=validate_env,
policy_class=self._policy_class,
config=config,
num_workers=self.config["num_workers"])
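            # Build the execution plan; `train_exec_impl` is the
            # LocalIterator that `step()` pulls training results from.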
self.execution_plan = execution_plan
self.train_exec_impl = execution_plan(
self.workers, config, **self._kwargs_for_execution_plan())
if after_init:
after_init(self)
@override(Trainer)
def step(self):
# self._iteration gets incremented after this function returns,
            # meaning that, e.g., the first time this function is called,
# self._iteration will be 0.
evaluate_this_iter = \
self.config["evaluation_interval"] and \
(self._iteration + 1) % self.config["evaluation_interval"] == 0
# No evaluation necessary, just run the next training iteration.
if not evaluate_this_iter:
step_results = next(self.train_exec_impl)
# We have to evaluate in this training iteration.
else:
# No parallelism.
if not self.config["evaluation_parallel_to_training"]:
step_results = next(self.train_exec_impl)
# Kick off evaluation-loop (and parallel train() call,
# if requested).
# Parallel eval + training.
if self.config["evaluation_parallel_to_training"]:
with concurrent.futures.ThreadPoolExecutor() as executor:
train_future = executor.submit(
lambda: next(self.train_exec_impl))
if self.config["evaluation_num_episodes"] == "auto":
# Run at least one `evaluate()` (num_episodes_done
# must be > 0), even if the training is very fast.
def episodes_left_fn(num_episodes_done):
if num_episodes_done > 0 and \
train_future.done():
return 0
else:
return self.config[
"evaluation_num_workers"]
evaluation_metrics = self.evaluate(
episodes_left_fn=episodes_left_fn)
else:
evaluation_metrics = self.evaluate()
# Collect the training results from the future.
step_results = train_future.result()
# Sequential: train (already done above), then eval.
else:
evaluation_metrics = self.evaluate()
# Add evaluation results to train results.
assert isinstance(evaluation_metrics, dict), \
"Trainer.evaluate() needs to return a dict."
step_results.update(evaluation_metrics)
# Check `env_task_fn` for possible update of the env's task.
if self.config["env_task_fn"] is not None:
if not callable(self.config["env_task_fn"]):
raise ValueError(
"`env_task_fn` must be None or a callable taking "
"[train_results, env, env_ctx] as args!")
def fn(env, env_context, task_fn):
new_task = task_fn(step_results, env, env_context)
cur_task = env.get_task()
if cur_task != new_task:
env.set_task(new_task)
fn = partial(fn, task_fn=self.config["env_task_fn"])
self.workers.foreach_env_with_context(fn)
return step_results
@staticmethod
@override(Trainer)
def _validate_config(config: PartialTrainerConfigDict,
trainer_obj_or_none: Optional["Trainer"] = None):
# Call super (Trainer) validation method first.
Trainer._validate_config(config, trainer_obj_or_none)
# Then call user defined one, if any.
if validate_config is not None:
validate_config(config)
@override(Trainer)
def _before_evaluate(self):
if before_evaluate_fn:
before_evaluate_fn(self)
@override(Trainer)
def __getstate__(self):
state = Trainer.__getstate__(self)
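            # Also serialize the execution plan's shared metrics so that
            # counters/timers survive a checkpoint -> restore cycle.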
state["train_exec_impl"] = (
self.train_exec_impl.shared_metrics.get().save())
return state
@override(Trainer)
def __setstate__(self, state):
Trainer.__setstate__(self, state)
self.train_exec_impl.shared_metrics.get().restore(
state["train_exec_impl"])
@staticmethod
@override(Trainer)
def with_updates(**overrides) -> Type[Trainer]:
"""Build a copy of this trainer class with the specified overrides.
            Keyword Args:
                overrides (dict): Use this to override any of the arguments
                    originally passed to build_trainer() for this trainer.
Returns:
                Type[Trainer]: The Trainer sub-class using `original_kwargs`
                    and `overrides`.
            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
"""
return build_trainer(**dict(original_kwargs, **overrides))
def __repr__(self):
return self._name
trainer_cls.__name__ = name
trainer_cls.__qualname__ = name
return trainer_cls
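# Illustrative usage sketch (not part of this module): a new algorithm can be
# assembled from a Policy class roughly like this, where `MyPolicy` is a
# hypothetical Policy subclass defined elsewhere:
#
#     MyTrainer = build_trainer(
#         name="MyAlgo",
#         default_policy=MyPolicy,
#     )
#     trainer = MyTrainer(config={"num_workers": 2}, env="CartPole-v0")
#     print(trainer.train())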