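"""Template for building custom RLlib Trainer classes.

Provides `build_trainer()`, which assembles a `Trainer` sub-class from a
default (or derived) Policy class, optional setup hooks, and an execution
plan describing the distributed training workflow (see
`default_execution_plan` below).
"""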
import concurrent.futures
from functools import partial
import logging
from typing import Callable, Iterable, List, Optional, Type
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy import Policy
from ray.rllib.utils import add_mixins
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import EnvConfigDict, EnvType, ResultDict, \
    TrainerConfigDict

logger = logging.getLogger(__name__)


def default_execution_plan(workers: WorkerSet, config: TrainerConfigDict):
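    """Default execution plan: synchronous rollouts, then one training step.

    Collects experiences in parallel from the rollout workers, concatenates
    them into batches of (at least) `train_batch_size`, trains the policy on
    them (via `TrainOneStep` or `TrainTFMultiGPU`), and wraps the resulting
    iterator with the standard metrics reporting.

    Args:
        workers (WorkerSet): The rollout workers to collect samples from.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        LocalIterator[dict]: An iterator yielding one result dict per
            training iteration.
    """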
    # Collects experiences in parallel from multiple RolloutWorker actors.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Combine batches of experiences until we hit `train_batch_size` in size.
    # Then, train the policy on those experiences and update the workers.
    train_op = rollouts.combine(
        ConcatBatches(
            min_batch_size=config["train_batch_size"],
            count_steps_by=config["multiagent"]["count_steps_by"],
        ))

    if config.get("simple_optimizer") is True:
        train_op = train_op.for_each(TrainOneStep(workers))
    else:
        train_op = train_op.for_each(
            TrainTFMultiGPU(
                workers=workers,
                sgd_minibatch_size=config.get("sgd_minibatch_size",
                                              config["train_batch_size"]),
                num_sgd_iter=config.get("num_sgd_iter", 1),
                num_gpus=config["num_gpus"],
                shuffle_sequences=config.get("shuffle_sequences", False),
                _fake_gpus=config["_fake_gpus"],
                framework=config["framework"]))

    # Add on the standard episode reward, etc. metrics reporting. This returns
    # a LocalIterator[metrics_dict] representing metrics for each train step.
    return StandardMetricsReporting(train_op, workers, config)


@DeveloperAPI
def build_trainer(
        name: str,
        *,
        default_config: Optional[TrainerConfigDict] = None,
        validate_config: Optional[Callable[[TrainerConfigDict], None]] = None,
        default_policy: Optional[Type[Policy]] = None,
        get_policy_class: Optional[Callable[[TrainerConfigDict], Optional[Type[
            Policy]]]] = None,
        validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None,
        before_init: Optional[Callable[[Trainer], None]] = None,
        after_init: Optional[Callable[[Trainer], None]] = None,
        before_evaluate_fn: Optional[Callable[[Trainer], None]] = None,
        mixins: Optional[List[type]] = None,
        execution_plan: Optional[Callable[[
            WorkerSet, TrainerConfigDict
        ], Iterable[ResultDict]]] = default_execution_plan) -> Type[Trainer]:
"""Helper function for defining a custom trainer.
|
|
|
|
|
2019-06-07 16:45:36 -07:00
|
|
|
Functions will be run in this order to initialize the trainer:
|
2020-05-21 10:16:18 -07:00
|
|
|
1. Config setup: validate_config, get_policy
|
|
|
|
2. Worker setup: before_init, execution_plan
|
2019-06-07 16:45:36 -07:00
|
|
|
3. Post setup: after_init
|
|
|
|
|
2020-08-20 17:05:57 +02:00
|
|
|
Args:
|
2019-05-18 00:23:11 -07:00
|
|
|
name (str): name of the trainer (e.g., "PPO")
|
2020-09-02 14:03:01 +02:00
|
|
|
default_config (Optional[TrainerConfigDict]): The default config dict
|
2020-08-20 17:05:57 +02:00
|
|
|
of the algorithm, otherwise uses the Trainer default config.
|
2020-09-02 14:03:01 +02:00
|
|
|
validate_config (Optional[Callable[[TrainerConfigDict], None]]):
|
|
|
|
Optional callable that takes the config to check for correctness.
|
|
|
|
It may mutate the config as needed.
|
2020-08-20 17:05:57 +02:00
|
|
|
default_policy (Optional[Type[Policy]]): The default Policy class to
|
2020-12-09 20:49:21 +01:00
|
|
|
use if `get_policy_class` returns None.
|
2020-08-20 17:05:57 +02:00
|
|
|
get_policy_class (Optional[Callable[
|
|
|
|
TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable
|
|
|
|
that takes a config and returns the policy class or None. If None
|
|
|
|
is returned, will use `default_policy` (which must be provided
|
|
|
|
then).
|
2020-10-06 20:28:16 +02:00
|
|
|
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
|
|
|
|
Optional callable to validate the generated environment (only
|
|
|
|
on worker=0).
|
2020-08-20 17:05:57 +02:00
|
|
|
before_init (Optional[Callable[[Trainer], None]]): Optional callable to
|
|
|
|
run before anything is constructed inside Trainer (Workers with
|
|
|
|
Policies, execution plan, etc..). Takes the Trainer instance as
|
|
|
|
argument.
|
|
|
|
after_init (Optional[Callable[[Trainer], None]]): Optional callable to
|
|
|
|
run at the end of trainer init (after all Workers and the exec.
|
|
|
|
plan have been constructed). Takes the Trainer instance as
|
|
|
|
argument.
|
|
|
|
before_evaluate_fn (Optional[Callable[[Trainer], None]]): Callback to
|
|
|
|
run before evaluation. This takes the trainer instance as argument.
|
2020-03-26 18:03:20 +01:00
|
|
|
mixins (list): list of any class mixins for the returned trainer class.
|
|
|
|
These mixins will be applied in order and will have higher
|
2020-06-25 19:01:32 +02:00
|
|
|
precedence than the Trainer class.
|
2020-09-02 14:03:01 +02:00
|
|
|
execution_plan (Optional[Callable[[WorkerSet, TrainerConfigDict],
|
|
|
|
Iterable[ResultDict]]]): Optional callable that sets up the
|
|
|
|
distributed execution workflow.
|
2019-05-18 00:23:11 -07:00
|
|
|
|
|
|
|
Returns:
|
2020-09-02 14:03:01 +02:00
|
|
|
Type[Trainer]: A Trainer sub-class configured by the specified args.
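
    Examples:
        A minimal usage sketch (the policy class and config values below are
        placeholders, not defined in this module):

        >>> MyTrainer = build_trainer(name="MyAlgo",
        ...                           default_policy=MyTFPolicy)
        >>> trainer = MyTrainer(env="CartPole-v0", config={"num_workers": 2})
        >>> results = trainer.train()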
    """

    original_kwargs = locals().copy()
    base = add_mixins(Trainer, mixins)

    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config: TrainerConfigDict,
                  env_creator: Callable[[EnvConfigDict], EnvType]):
            # Validate config via custom validation function.
            if validate_config:
                validate_config(config)

            # No `get_policy_class` function provided.
            if get_policy_class is None:
                # `default_policy` must be provided (unless in multi-agent
                # mode, where each policy can have its own default policy
                # class).
                if not config["multiagent"]["policies"]:
                    assert default_policy is not None
                    self._policy_class = default_policy
            # Query the function for a class to use.
            else:
                self._policy_class = get_policy_class(config)
                # If None returned, use default policy (must be provided).
                if self._policy_class is None:
                    assert default_policy is not None
                    self._policy_class = default_policy

            if before_init:
                before_init(self)

            # Creating all workers (excluding evaluation workers).
            self.workers = self._make_workers(
                env_creator=env_creator,
                validate_env=validate_env,
                policy_class=self._policy_class,
                config=config,
                num_workers=self.config["num_workers"])
            self.execution_plan = execution_plan
            self.train_exec_impl = execution_plan(self.workers, config)

            if after_init:
                after_init(self)

        @override(Trainer)
        def step(self):
            # self._iteration gets incremented after this function returns,
            # meaning that e.g. the first time this function is called,
            # self._iteration will be 0.
            evaluate_this_iter = \
                self.config["evaluation_interval"] and \
                (self._iteration + 1) % self.config["evaluation_interval"] == 0

            # No evaluation necessary.
            if not evaluate_this_iter:
                res = next(self.train_exec_impl)
            # We have to evaluate in this training iteration.
            else:
                # No parallelism.
                if not self.config["evaluation_parallel_to_training"]:
                    res = next(self.train_exec_impl)
                # Kick off evaluation-loop (and parallel train() call,
                # if requested).
                with concurrent.futures.ThreadPoolExecutor() as executor:
                    eval_future = executor.submit(self.evaluate)
                    # Parallelism.
                    if self.config["evaluation_parallel_to_training"]:
                        res = next(self.train_exec_impl)
                    evaluation_metrics = eval_future.result()
                    assert isinstance(evaluation_metrics, dict), \
                        "Trainer.evaluate() needs to return a dict."
                    res.update(evaluation_metrics)

            # Check `env_task_fn` for possible update of the env's task.
            if self.config["env_task_fn"] is not None:
                if not callable(self.config["env_task_fn"]):
                    raise ValueError(
                        "`env_task_fn` must be None or a callable taking "
                        "[train_results, env, env_ctx] as args!")

                def fn(env, env_context, task_fn):
                    new_task = task_fn(res, env, env_context)
                    cur_task = env.get_task()
                    if cur_task != new_task:
                        env.set_task(new_task)

                fn = partial(fn, task_fn=self.config["env_task_fn"])
                self.workers.foreach_env_with_context(fn)

            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        @override(Trainer)
        def __getstate__(self):
            state = Trainer.__getstate__(self)
            # Also save the execution plan's shared metrics, so they can be
            # restored together with the rest of the Trainer state.
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
            return state

        @override(Trainer)
        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            # Restore the execution plan's shared metrics.
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])

        @staticmethod
        @override(Trainer)
        def with_updates(**overrides) -> Type[Trainer]:
            """Build a copy of this trainer class with the specified overrides.

            Keyword Args:
                overrides (dict): Use this to override any of the arguments
                    originally passed to build_trainer() for this trainer
                    class.

            Returns:
                Type[Trainer]: The Trainer sub-class built from
                    `original_kwargs` and `overrides`.

            Examples:
                >>> MyClass = SomeOtherClass.with_updates(name="Mine")
                >>> issubclass(MyClass, SomeOtherClass)
                False
                >>> issubclass(MyClass, Trainer)
                True
            """
            return build_trainer(**dict(original_kwargs, **overrides))

        def __repr__(self):
            return self._name

    trainer_cls.__name__ = name
    trainer_cls.__qualname__ = name
    return trainer_cls
|