Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[RLlib] Allow for evaluation to run by timesteps (alternative to episodes) and add auto-setting to make sure train doesn't ever have to wait for eval (e.g. long episodes) to finish. (#20757)
This commit is contained in:
parent d4413299c0
commit 60b2219d72
19 changed files with 437 additions and 188 deletions
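In short, this commit replaces the single ``evaluation_num_episodes`` setting with ``evaluation_duration`` plus ``evaluation_duration_unit`` and adds an "auto" duration mode for evaluation that runs in parallel to training. A hedged sketch of the resulting config surface (values are illustrative only, not defaults):

.. code-block:: python

    config = {
        "evaluation_interval": 1,                 # run an evaluation step every `train()` call
        "evaluation_duration": "auto",            # an int, or "auto" (parallel mode only)
        "evaluation_duration_unit": "timesteps",  # "episodes" (default) or "timesteps"
        "evaluation_parallel_to_training": True,  # required for the "auto" setting
        "evaluation_num_workers": 2,
    }
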
@ -729,19 +729,75 @@ Customized Evaluation During Training

RLlib will report online training rewards, however in some cases you may want to compute
rewards with different settings (e.g., with exploration turned off, or on a specific set
of environment configurations). You can evaluate policies during training by setting
the ``evaluation_interval`` config, and optionally also ``evaluation_num_episodes``,
``evaluation_config``, ``evaluation_num_workers``, and ``custom_eval_function``
(see `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__ for further documentation).
of environment configurations). You can activate policy evaluation during training (``Trainer.train()``) by setting
the ``evaluation_interval`` config to an int value (> 0): an "evaluation step" is then run once
every ``evaluation_interval`` calls to ``Trainer.train()``:

By default, exploration is left as-is within ``evaluation_config``.
However, you can switch off any exploration behavior for the evaluation workers
via:

.. code-block:: python

    # Run one evaluation step on every 3rd `Trainer.train()` call.
    {
        "evaluation_interval": 3,
    }

One such evaluation step runs over ``evaluation_duration`` episodes or timesteps, depending
on the ``evaluation_duration_unit`` setting, which can be either "episodes" (default) or "timesteps".

.. code-block:: python

    # Every time we run an evaluation step, run it for exactly 10 episodes.
    {
        "evaluation_duration": 10,
        "evaluation_duration_unit": "episodes",
    }

    # Every time we run an evaluation step, run it for close to 200 timesteps.
    {
        "evaluation_duration": 200,
        "evaluation_duration_unit": "timesteps",
    }

Before each evaluation step, weights from the main model are synchronized to all evaluation workers.

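Conceptually, this synchronization amounts to something like the following hedged sketch (the Trainer does this for you; ``trainer`` is assumed to be an already built Trainer that has a separate evaluation WorkerSet):

.. code-block:: python

    # Push the latest weights from the main (training) local worker out to
    # all remote evaluation workers (sketch only).
    weights = trainer.workers.local_worker().get_weights()
    for w in trainer.evaluation_workers.remote_workers():
        w.set_weights.remote(weights)
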
Normally, the evaluation step is run right after the respective train step. For example, for
``evaluation_interval=2``, the sequence of steps is: ``train, train, eval, train, train, eval, ...``.
For ``evaluation_interval=1``, the sequence is: ``train, eval, train, eval, ...``.

However, it is possible to run evaluation in parallel to training via the ``evaluation_parallel_to_training=True``
config setting. In this case, both steps (train and eval) are run at the same time via threading.
This can speed up the evaluation process significantly, but leads to a 1-iteration delay between reported
training results and evaluation results (the evaluation results lag behind because they use slightly
outdated model weights).

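The threading pattern behind this is essentially the following simplified sketch (``run_one_train_step`` and ``run_one_eval_step`` are hypothetical stand-ins for the Trainer's internal train and evaluation calls):

.. code-block:: python

    import concurrent.futures

    def run_one_train_step():
        ...  # hypothetical stand-in for the Trainer's training logic

    def run_one_eval_step():
        ...  # hypothetical stand-in for Trainer.evaluate()

    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Kick off the train step in a background thread ...
        train_future = executor.submit(run_one_train_step)
        # ... and run the evaluation step in the main thread meanwhile.
        eval_results = run_one_eval_step()
        # Wait for the train step to finish and collect its results.
        train_results = train_future.result()
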
When running with the ``evaluation_parallel_to_training=True`` setting, a special "auto" value
is supported for ``evaluation_duration``. This can be used to make the evaluation step take
roughly as long as the train step:

.. code-block:: python

    # Run eval and train at the same time via threading and make sure they roughly
    # take the same time, such that the next `Trainer.train()` call can execute
    # immediately and not have to wait for a still ongoing (e.g. very long episode)
    # evaluation step:
    {
        "evaluation_interval": 1,
        "evaluation_parallel_to_training": True,
        "evaluation_duration": "auto",  # automatically end evaluation when train step has finished
        "evaluation_duration_unit": "timesteps",  # <- more fine-grained than "episodes"
    }

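Under the hood, "auto" simply keeps requesting more evaluation work until the parallel train step has finished. A simplified, self-contained sketch of the per-round computation (mirroring the ``auto_duration_fn`` this commit adds to ``trainer.py``; ``train_is_done`` stands in for checking the train thread's future):

.. code-block:: python

    def units_left_this_round(unit, num_eval_workers, rollout_fragment_length,
                              num_envs_per_worker, num_units_done, train_is_done):
        # Train step has finished and at least one eval round ran -> stop.
        if num_units_done > 0 and train_is_done:
            return 0
        # Counting episodes: request one more episode per eval worker.
        if unit == "episodes":
            return num_eval_workers
        # Counting timesteps: request one more `sample()` worth of timesteps
        # per eval worker.
        return num_eval_workers * rollout_fragment_length * num_envs_per_worker
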
The ``evaluation_config`` key allows you to override any config settings for
the evaluation workers. For example, to switch off exploration in the evaluation steps,
do:

.. code-block:: python

    # Switching off exploration behavior for evaluation workers
    # (see rllib/agents/trainer.py)
    # (see rllib/agents/trainer.py). Use any keys in this sub-dict that are
    # also supported in the main Trainer config.
    "evaluation_config": {
        "explore": False
    }

@ -752,6 +808,17 @@ via:

policy, even if this is a stochastic one. Setting "explore=False" above
will result in the evaluation workers not using this stochastic policy.

Parallelism for the evaluation step is determined via the ``evaluation_num_workers``
setting. Set this to larger values if you want the desired evaluation episodes or timesteps to
run as much in parallel as possible. For example, if your ``evaluation_duration=10``,
``evaluation_duration_unit=episodes``, and ``evaluation_num_workers=10``, each eval worker
only has to run 1 episode in each eval step.

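For example, the following (illustrative) settings spread the evaluation load so that each eval worker samples exactly one episode per evaluation step:

.. code-block:: python

    {
        "evaluation_duration": 10,
        "evaluation_duration_unit": "episodes",
        # 10 episodes / 10 workers -> 1 episode per eval worker per eval step.
        "evaluation_num_workers": 10,
    }
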
In case you would like to entirely customize the evaluation step, set ``custom_eval_function`` in your
config to a callable taking the Trainer object and a WorkerSet object (the evaluation WorkerSet)
and returning a metrics dict. See `trainer.py <https://github.com/ray-project/ray/blob/master/rllib/agents/trainer.py>`__
for further documentation.

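A minimal sketch of such a function (a stripped-down variant of what ``custom_eval.py`` does; it assumes remote evaluation workers exist):

.. code-block:: python

    import ray
    from ray.rllib.evaluation.metrics import collect_metrics

    def my_custom_eval(trainer, eval_workers):
        # Run one round of sampling on every remote evaluation worker.
        ray.get([w.sample.remote() for w in eval_workers.remote_workers()])
        # Collect and return the standard episode metrics as a dict.
        return collect_metrics(eval_workers.local_worker(),
                               eval_workers.remote_workers())

    # Then, in your config:
    # "custom_eval_function": my_custom_eval,
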
There is an end to end example of how to set up custom online evaluation in `custom_eval.py <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_eval.py>`__. Note that if you only want to eval your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping.

Below are some examples of how the custom evaluation metrics are reported nested under the ``evaluation`` key of normal training results:

14 rllib/BUILD
|
@ -2335,34 +2335,34 @@ py_test(
|
|||
tags = ["team:ml", "examples", "examples_P"],
|
||||
size = "medium",
|
||||
srcs = ["examples/parallel_evaluation_and_training.py"],
|
||||
args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-num-episodes=13"]
|
||||
args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=13"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/parallel_evaluation_and_training_auto_num_episodes_tf",
|
||||
name = "examples/parallel_evaluation_and_training_auto_episodes_tf",
|
||||
main = "examples/parallel_evaluation_and_training.py",
|
||||
tags = ["team:ml", "examples", "examples_P"],
|
||||
size = "medium",
|
||||
srcs = ["examples/parallel_evaluation_and_training.py"],
|
||||
args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-num-episodes=auto"]
|
||||
args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/parallel_evaluation_and_training_11_episodes_tf2",
|
||||
name = "examples/parallel_evaluation_and_training_211_ts_tf2",
|
||||
main = "examples/parallel_evaluation_and_training.py",
|
||||
tags = ["team:ml", "examples", "examples_P"],
|
||||
size = "medium",
|
||||
srcs = ["examples/parallel_evaluation_and_training.py"],
|
||||
args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-episodes=11"]
|
||||
args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=211", "--evaluation-duration-unit=timesteps"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "examples/parallel_evaluation_and_training_14_episodes_torch",
|
||||
name = "examples/parallel_evaluation_and_training_auto_ts_torch",
|
||||
main = "examples/parallel_evaluation_and_training.py",
|
||||
tags = ["team:ml", "examples", "examples_P"],
|
||||
size = "medium",
|
||||
srcs = ["examples/parallel_evaluation_and_training.py"],
|
||||
args = ["--as-test", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-episodes=14"]
|
||||
args = ["--as-test", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=auto", "--evaluation-duration-unit=timesteps"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
|
|
|
@ -58,7 +58,7 @@ class TestCQL(unittest.TestCase):
|
|||
config["input_evaluation"] = ["is"]
|
||||
|
||||
config["evaluation_interval"] = 2
|
||||
config["evaluation_num_episodes"] = 10
|
||||
config["evaluation_duration"] = 10
|
||||
config["evaluation_config"]["input"] = "sampler"
|
||||
config["evaluation_parallel_to_training"] = False
|
||||
config["evaluation_num_workers"] = 2
|
||||
|
|
|
@ -39,7 +39,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# metrics are already only reported for the lowest epsilon workers.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_num_episodes": 10,
|
||||
"evaluation_duration": 10,
|
||||
|
||||
# === Model ===
|
||||
# Apply a state preprocessor with spec given by the "model" config option
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestBC(unittest.TestCase):
|
|||
|
||||
config["evaluation_interval"] = 3
|
||||
config["evaluation_num_workers"] = 1
|
||||
config["evaluation_num_episodes"] = 5
|
||||
config["evaluation_duration"] = 5
|
||||
config["evaluation_parallel_to_training"] = True
|
||||
# Evaluate on actual environment.
|
||||
config["evaluation_config"] = {"input": "sampler"}
|
||||
|
|
|
@ -43,7 +43,7 @@ class TestMARWIL(unittest.TestCase):
|
|||
config["num_workers"] = 2
|
||||
config["evaluation_num_workers"] = 1
|
||||
config["evaluation_interval"] = 3
|
||||
config["evaluation_num_episodes"] = 5
|
||||
config["evaluation_duration"] = 5
|
||||
config["evaluation_parallel_to_training"] = True
|
||||
# Evaluate on actual environment.
|
||||
config["evaluation_config"] = {"input": "sampler"}
|
||||
|
@ -100,7 +100,7 @@ class TestMARWIL(unittest.TestCase):
|
|||
config["num_workers"] = 1
|
||||
config["evaluation_num_workers"] = 1
|
||||
config["evaluation_interval"] = 3
|
||||
config["evaluation_num_episodes"] = 5
|
||||
config["evaluation_duration"] = 5
|
||||
config["evaluation_parallel_to_training"] = True
|
||||
# Evaluate on actual environment.
|
||||
config["evaluation_config"] = {"input": "sampler"}
|
||||
|
|
|
@ -49,7 +49,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# metrics are already only reported for the lowest epsilon workers.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_num_episodes": 10,
|
||||
"evaluation_duration": 10,
|
||||
# Switch to greedy actions in evaluation workers.
|
||||
"evaluation_config": {
|
||||
"explore": False,
|
||||
|
|
|
@ -12,7 +12,7 @@ import ray.rllib.agents.pg as pg
|
|||
from ray.rllib.agents.trainer import COMMON_CONFIG
|
||||
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
||||
from ray.rllib.examples.parallel_evaluation_and_training import \
|
||||
AssertNumEvalEpisodesCallback
|
||||
AssertEvalCallback
|
||||
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
|
@ -131,13 +131,13 @@ class TestTrainer(unittest.TestCase):
|
|||
config.update({
|
||||
"env": "CartPole-v0",
|
||||
"evaluation_interval": 2,
|
||||
"evaluation_num_episodes": 2,
|
||||
"evaluation_duration": 2,
|
||||
"evaluation_config": {
|
||||
"gamma": 0.98,
|
||||
},
|
||||
# Use a custom callback that asserts that we are running the
|
||||
# configured exact number of episodes per evaluation.
|
||||
"callbacks": AssertNumEvalEpisodesCallback,
|
||||
"callbacks": AssertEvalCallback,
|
||||
})
|
||||
|
||||
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
||||
|
@ -169,7 +169,7 @@ class TestTrainer(unittest.TestCase):
|
|||
"evaluation_interval": None,
|
||||
# Use a custom callback that asserts that we are running the
|
||||
# configured exact number of episodes per evaluation.
|
||||
"callbacks": AssertNumEvalEpisodesCallback,
|
||||
"callbacks": AssertEvalCallback,
|
||||
})
|
||||
for _ in framework_iterator(frameworks=("tf", "torch")):
|
||||
# Setup trainer w/o evaluation worker set and still call
|
||||
|
|
|
@ -4,6 +4,7 @@ from datetime import datetime
|
|||
import functools
|
||||
import gym
|
||||
import logging
|
||||
import math
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
|
@ -271,17 +272,23 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
# === Evaluation Settings ===
|
||||
# Evaluate with every `evaluation_interval` training iterations.
|
||||
# The evaluation stats will be reported under the "evaluation" metric key.
|
||||
# Note that evaluation is currently not parallelized, and that for Ape-X
|
||||
# metrics are already only reported for the lowest epsilon workers.
|
||||
# Note that for Ape-X metrics are already only reported for the lowest
|
||||
# epsilon workers (least random workers).
|
||||
# Set to None (or 0) for no evaluation.
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run in total per evaluation period.
|
||||
# Duration for which to run evaluation each `evaluation_interval`.
|
||||
# The unit for the duration can be set via `evaluation_duration_unit` to
|
||||
# either "episodes" (default) or "timesteps".
|
||||
# If using multiple evaluation workers (evaluation_num_workers > 1),
|
||||
# episodes will be split amongst these.
|
||||
# If "auto":
|
||||
# - evaluation_parallel_to_training=True: Will run as many episodes as the
|
||||
# training step takes.
|
||||
# - evaluation_parallel_to_training=False: Error.
|
||||
"evaluation_num_episodes": 10,
|
||||
# the load to run will be split amongst these.
|
||||
# If the value is "auto":
|
||||
# - For `evaluation_parallel_to_training=True`: Will run as many
|
||||
# episodes/timesteps that fit into the (parallel) training step.
|
||||
# - For `evaluation_parallel_to_training=False`: Error.
|
||||
"evaluation_duration": 10,
|
||||
# The unit, with which to count the evaluation duration. Either "episodes"
|
||||
# (default) or "timesteps".
|
||||
"evaluation_duration_unit": "episodes",
|
||||
# Whether to run evaluation in parallel to a Trainer.train() call
|
||||
# using threading. Default=False.
|
||||
# E.g. evaluation_interval=2 -> For every other training iteration,
|
||||
|
@ -310,9 +317,9 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
"evaluation_num_workers": 0,
|
||||
# Customize the evaluation method. This must be a function of signature
|
||||
# (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the
|
||||
# Trainer.evaluate() method to see the default implementation. The
|
||||
# trainer guarantees all eval workers have the latest policy state before
|
||||
# this function is called.
|
||||
# Trainer.evaluate() method to see the default implementation.
|
||||
# The Trainer guarantees all eval workers have the latest policy state
|
||||
# before this function is called.
|
||||
"custom_eval_function": None,
|
||||
|
||||
# === Advanced Rollout Settings ===
|
||||
|
@ -520,6 +527,9 @@ COMMON_CONFIG: TrainerConfigDict = {
|
|||
# Whether to write episode stats and videos to the agent log dir. This is
|
||||
# typically located in ~/ray_results.
|
||||
"monitor": DEPRECATED_VALUE,
|
||||
# Replaced by `evaluation_duration=10` and
|
||||
# `evaluation_duration_unit=episodes`.
|
||||
"evaluation_num_episodes": DEPRECATED_VALUE,
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
@ -661,6 +671,10 @@ class Trainer(Trainable):
|
|||
|
||||
logger_creator = default_logger_creator
|
||||
|
||||
# Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
|
||||
self.evaluation_workers = None
|
||||
self.evaluation_metrics = {}
|
||||
|
||||
super().__init__(config, logger_creator, remote_checkpoint_dir,
|
||||
sync_function_tpl)
|
||||
|
||||
|
@ -769,29 +783,60 @@ class Trainer(Trainable):
|
|||
self.workers, self.config, **self._kwargs_for_execution_plan())
|
||||
|
||||
# Evaluation WorkerSet setup.
|
||||
self.evaluation_workers = None
|
||||
self.evaluation_metrics = {}
|
||||
# User would like to setup a separate evaluation worker set.
|
||||
|
||||
# Update with evaluation settings:
|
||||
user_eval_config = copy.deepcopy(self.config["evaluation_config"])
|
||||
|
||||
# Assert that user has not unset "in_evaluation".
|
||||
assert "in_evaluation" not in user_eval_config or \
|
||||
user_eval_config["in_evaluation"] is True
|
||||
|
||||
# Merge user-provided eval config with the base config. This makes sure
|
||||
# the eval config is always complete, no matter whether we have eval
|
||||
# workers or perform evaluation on the (non-eval) local worker.
|
||||
eval_config = merge_dicts(self.config, user_eval_config)
|
||||
self.config["evaluation_config"] = eval_config
|
||||
|
||||
if self.config.get("evaluation_num_workers", 0) > 0 or \
|
||||
self.config.get("evaluation_interval"):
|
||||
# Update env_config with evaluation settings:
|
||||
extra_config = copy.deepcopy(self.config["evaluation_config"])
|
||||
# Assert that user has not unset "in_evaluation".
|
||||
assert "in_evaluation" not in extra_config or \
|
||||
extra_config["in_evaluation"] is True
|
||||
evaluation_config = merge_dicts(self.config, extra_config)
|
||||
logger.debug(f"Using evaluation_config: {user_eval_config}.")
|
||||
|
||||
# Validate evaluation config.
|
||||
self.validate_config(evaluation_config)
|
||||
# Switch on complete_episode rollouts (evaluations are
|
||||
# always done on n complete episodes) and set the
|
||||
# `in_evaluation` flag. Also, make sure our rollout fragments
|
||||
# are short so we don't have more than one episode in one rollout.
|
||||
evaluation_config.update({
|
||||
"batch_mode": "complete_episodes",
|
||||
"rollout_fragment_length": 1,
|
||||
"in_evaluation": True,
|
||||
})
|
||||
logger.debug("using evaluation_config: {}".format(extra_config))
|
||||
self.validate_config(eval_config)
|
||||
|
||||
# Set the `in_evaluation` flag.
|
||||
eval_config["in_evaluation"] = True
|
||||
|
||||
# Evaluation duration unit: episodes.
|
||||
# Switch on `complete_episode` rollouts. Also, make sure
|
||||
# rollout fragments are short so we never have more than one
|
||||
# episode in one rollout.
|
||||
if eval_config["evaluation_duration_unit"] == "episodes":
|
||||
eval_config.update({
|
||||
"batch_mode": "complete_episodes",
|
||||
"rollout_fragment_length": 1,
|
||||
})
|
||||
# Evaluation duration unit: timesteps.
|
||||
# - Set `batch_mode=truncate_episodes` so we don't perform rollouts
|
||||
# strictly along episode borders.
|
||||
# Set `rollout_fragment_length` such that desired steps are divided
|
||||
# equally amongst workers or - in "auto" duration mode - set it
|
||||
# to a reasonably small number (10), such that a single `sample()`
|
||||
# call doesn't take too much time so we can stop evaluation as soon
|
||||
# as possible after the train step is completed.
|
||||
else:
|
||||
eval_config.update({
|
||||
"batch_mode": "truncate_episodes",
|
||||
"rollout_fragment_length": 10
|
||||
if self.config["evaluation_duration"] == "auto" else int(
|
||||
math.ceil(
|
||||
self.config["evaluation_duration"] /
|
||||
(self.config["evaluation_num_workers"] or 1))),
|
||||
})
|
||||
|
||||
self.config["evaluation_config"] = eval_config
|
||||
|
||||
# Create a separate evaluation worker set for evaluation.
|
||||
# If evaluation_num_workers=0, use the evaluation set's local
|
||||
# worker for evaluation, otherwise, use its remote workers
|
||||
|
@ -800,8 +845,11 @@ class Trainer(Trainable):
|
|||
env_creator=self.env_creator,
|
||||
validate_env=None,
|
||||
policy_class=self.get_default_policy_class(self.config),
|
||||
config=evaluation_config,
|
||||
num_workers=self.config["evaluation_num_workers"])
|
||||
config=eval_config,
|
||||
num_workers=self.config["evaluation_num_workers"],
|
||||
# Don't even create a local worker if num_workers > 0.
|
||||
local_worker=False,
|
||||
)
|
||||
|
||||
# TODO: Deprecated: In your sub-classes of Trainer, override `setup()`
|
||||
# directly and call super().setup() from within it if you would like the
|
||||
|
@ -883,14 +931,33 @@ class Trainer(Trainable):
|
|||
"""Attempts a single training step, including evaluation, if required.
|
||||
|
||||
Override this method in your Trainer sub-classes if you would like to
|
||||
keep the n attempts (catch worker failures) or override `step()`
|
||||
directly if you would like to handle worker failures yourself.
|
||||
keep the n step-attempts logic (catch worker failures) in place or
|
||||
override `step()` directly if you would like to handle worker
|
||||
failures yourself.
|
||||
|
||||
Returns:
|
||||
The results dict with stats/infos on sampling, training,
|
||||
and - if required - evaluation.
|
||||
"""
|
||||
|
||||
def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done):
|
||||
# Training is done and we already ran at least one
|
||||
# evaluation -> Nothing left to run.
|
||||
if num_units_done > 0 and \
|
||||
train_future.done():
|
||||
return 0
|
||||
# Count by episodes. -> Run n more
|
||||
# (n=num eval workers).
|
||||
elif unit == "episodes":
|
||||
return num_eval_workers
|
||||
# Count by timesteps. -> Run n*m*p more
|
||||
# (n=num eval workers; m=rollout fragment length;
|
||||
# p=num-envs-per-worker).
|
||||
else:
|
||||
return num_eval_workers * \
|
||||
eval_cfg["rollout_fragment_length"] * \
|
||||
eval_cfg["num_envs_per_worker"]
|
||||
|
||||
# self._iteration gets incremented after this function returns,
|
||||
# meaning that e.g. the first time this function is called,
|
||||
# self._iteration will be 0.
|
||||
|
@ -914,19 +981,15 @@ class Trainer(Trainable):
|
|||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
train_future = executor.submit(
|
||||
lambda: next(self.train_exec_impl))
|
||||
if self.config["evaluation_num_episodes"] == "auto":
|
||||
|
||||
# Run at least one `evaluate()` (num_episodes_done
|
||||
# must be > 0), even if the training is very fast.
|
||||
def episodes_left_fn(num_episodes_done):
|
||||
if num_episodes_done > 0 and \
|
||||
train_future.done():
|
||||
return 0
|
||||
else:
|
||||
return self.config["evaluation_num_workers"]
|
||||
# Automatically determine duration of the evaluation.
|
||||
if self.config["evaluation_duration"] == "auto":
|
||||
unit = self.config["evaluation_duration_unit"]
|
||||
|
||||
evaluation_metrics = self.evaluate(
|
||||
episodes_left_fn=episodes_left_fn)
|
||||
duration_fn=functools.partial(
|
||||
auto_duration_fn, unit, self.config[
|
||||
"evaluation_num_workers"], self.config[
|
||||
"evaluation_config"]))
|
||||
else:
|
||||
evaluation_metrics = self.evaluate()
|
||||
# Collect the training results from the future.
|
||||
|
@ -959,19 +1022,29 @@ class Trainer(Trainable):
|
|||
return step_results
|
||||
|
||||
@PublicAPI
|
||||
def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None
|
||||
) -> dict:
|
||||
def evaluate(
|
||||
self,
|
||||
episodes_left_fn=None, # deprecated
|
||||
duration_fn: Optional[Callable[[int], int]] = None,
|
||||
) -> dict:
|
||||
"""Evaluates current policy under `evaluation_config` settings.
|
||||
|
||||
Note that this default implementation does not do anything beyond
|
||||
merging evaluation_config with the normal trainer config.
|
||||
|
||||
Args:
|
||||
episodes_left_fn: An optional callable taking the already run
|
||||
duration_fn: An optional callable taking the already run
|
||||
num episodes as only arg and returning the number of
|
||||
episodes left to run. It's used to find out whether
|
||||
evaluation should continue.
|
||||
"""
|
||||
if episodes_left_fn is not None:
|
||||
deprecation_warning(
|
||||
old="Trainer.evaluate(episodes_left_fn)",
|
||||
new="Trainer.evaluate(duration_fn)",
|
||||
error=False)
|
||||
duration_fn = episodes_left_fn
|
||||
|
||||
# In case we are evaluating (in a thread) parallel to training,
|
||||
# we may have to re-enable eager mode here (gets disabled in the
|
||||
# thread).
|
||||
|
@ -996,80 +1069,102 @@ class Trainer(Trainable):
|
|||
raise ValueError("Custom eval function must return "
|
||||
"dict of metrics, got {}.".format(metrics))
|
||||
else:
|
||||
# How many episodes do we need to run?
|
||||
# In "auto" mode (only for parallel eval + training): Run one
|
||||
# episode per eval worker.
|
||||
num_episodes = self.config["evaluation_num_episodes"] if \
|
||||
self.config["evaluation_num_episodes"] != "auto" else \
|
||||
(self.config["evaluation_num_workers"] or 1)
|
||||
if self.evaluation_workers is None and \
|
||||
self.workers.local_worker().input_reader is None:
|
||||
raise ValueError(
|
||||
"Cannot evaluate w/o an evaluation worker set in "
|
||||
"the Trainer or w/o an env on the local worker!\n"
|
||||
"Try one of the following:\n1) Set "
|
||||
"`evaluation_interval` >= 0 to force creating a "
|
||||
"separate evaluation worker set.\n2) Set "
|
||||
"`create_env_on_driver=True` to force the local "
|
||||
"(non-eval) worker to have an environment to "
|
||||
"evaluate on.")
|
||||
|
||||
# How many episodes/timesteps do we need to run?
|
||||
# In "auto" mode (only for parallel eval + training): Run as long
|
||||
# as training lasts.
|
||||
unit = self.config["evaluation_duration_unit"]
|
||||
eval_cfg = self.config["evaluation_config"]
|
||||
rollout = eval_cfg["rollout_fragment_length"]
|
||||
num_envs = eval_cfg["num_envs_per_worker"]
|
||||
duration = self.config["evaluation_duration"] if \
|
||||
self.config["evaluation_duration"] != "auto" else \
|
||||
(self.config["evaluation_num_workers"] or 1) * \
|
||||
(1 if unit == "episodes" else rollout)
|
||||
num_ts_run = 0
|
||||
|
||||
# Default done-function returns True, whenever num episodes
|
||||
# have been completed.
|
||||
if episodes_left_fn is None:
|
||||
if duration_fn is None:
|
||||
|
||||
def episodes_left_fn(num_episodes_done):
|
||||
return num_episodes - num_episodes_done
|
||||
def duration_fn(num_units_done):
|
||||
return duration - num_units_done
|
||||
|
||||
logger.info(
|
||||
f"Evaluating current policy for {num_episodes} episodes.")
|
||||
logger.info(f"Evaluating current policy for {duration} {unit}.")
|
||||
|
||||
metrics = None
|
||||
# No evaluation worker set ->
|
||||
# Do evaluation using the local worker. Expect error due to the
|
||||
# local worker not having an env.
|
||||
if self.evaluation_workers is None:
|
||||
try:
|
||||
for _ in range(num_episodes):
|
||||
self.workers.local_worker().sample()
|
||||
metrics = collect_metrics(self.workers.local_worker())
|
||||
except ValueError as e:
|
||||
if "RolloutWorker has no `input_reader` object" in \
|
||||
e.args[0]:
|
||||
raise ValueError(
|
||||
"Cannot evaluate w/o an evaluation worker set in "
|
||||
"the Trainer or w/o an env on the local worker!\n"
|
||||
"Try one of the following:\n1) Set "
|
||||
"`evaluation_interval` >= 0 to force creating a "
|
||||
"separate evaluation worker set.\n2) Set "
|
||||
"`create_env_on_driver=True` to force the local "
|
||||
"(non-eval) worker to have an environment to "
|
||||
"evaluate on.")
|
||||
else:
|
||||
raise e
|
||||
# If unit=episodes -> Run n times `sample()` (each sample
|
||||
# produces exactly 1 episode).
|
||||
# If unit=ts -> Run 1 `sample()` b/c the
|
||||
# `rollout_fragment_length` is exactly the desired ts.
|
||||
iters = duration if unit == "episodes" else 1
|
||||
for _ in range(iters):
|
||||
num_ts_run += len(self.workers.local_worker().sample())
|
||||
metrics = collect_metrics(self.workers.local_worker())
|
||||
|
||||
# Evaluation worker set only has local worker.
|
||||
elif self.config["evaluation_num_workers"] == 0:
|
||||
for _ in range(num_episodes):
|
||||
self.evaluation_workers.local_worker().sample()
|
||||
# If unit=episodes -> Run n times `sample()` (each sample
|
||||
# produces exactly 1 episode).
|
||||
# If unit=ts -> Run 1 `sample()` b/c the
|
||||
# `rollout_fragment_length` is exactly the desired ts.
|
||||
iters = duration if unit == "episodes" else 1
|
||||
for _ in range(iters):
|
||||
num_ts_run += len(
|
||||
self.evaluation_workers.local_worker().sample())
|
||||
|
||||
# Evaluation worker set has n remote workers.
|
||||
else:
|
||||
# How many episodes have we run (across all eval workers)?
|
||||
num_episodes_done = 0
|
||||
num_units_done = 0
|
||||
round_ = 0
|
||||
while True:
|
||||
episodes_left_to_do = episodes_left_fn(num_episodes_done)
|
||||
if episodes_left_to_do <= 0:
|
||||
units_left_to_do = duration_fn(num_units_done)
|
||||
if units_left_to_do <= 0:
|
||||
break
|
||||
|
||||
round_ += 1
|
||||
batches = ray.get([
|
||||
w.sample.remote() for i, w in enumerate(
|
||||
self.evaluation_workers.remote_workers())
|
||||
if i < episodes_left_to_do
|
||||
if i * (1 if unit == "episodes" else rollout *
|
||||
num_envs) < units_left_to_do
|
||||
])
|
||||
# Per our config for the evaluation workers
|
||||
# (`rollout_fragment_length=1` and
|
||||
# `batch_mode=complete_episode`), we know that we'll have
|
||||
# exactly one episode per returned batch.
|
||||
num_episodes_done += len(batches)
|
||||
logger.info(
|
||||
f"Ran round {round_} of parallel evaluation "
|
||||
f"({num_episodes_done}/{num_episodes} episodes done)")
|
||||
# 1 episode per returned batch.
|
||||
if unit == "episodes":
|
||||
num_units_done += len(batches)
|
||||
# n timesteps per returned batch.
|
||||
else:
|
||||
ts = sum(len(b) for b in batches)
|
||||
num_ts_run += ts
|
||||
num_units_done += ts
|
||||
|
||||
logger.info(f"Ran round {round_} of parallel evaluation "
|
||||
f"({num_units_done}/{duration} {unit} done)")
|
||||
|
||||
if metrics is None:
|
||||
metrics = collect_metrics(
|
||||
self.evaluation_workers.local_worker(),
|
||||
self.evaluation_workers.remote_workers())
|
||||
metrics["timesteps_this_iter"] = num_ts_run
|
||||
|
||||
self.evaluation_metrics = metrics
|
||||
|
||||
return {"evaluation": metrics}
|
||||
|
||||
@DeveloperAPI
|
||||
|
@ -1684,6 +1779,7 @@ class Trainer(Trainable):
|
|||
policy_class: Type[Policy],
|
||||
config: TrainerConfigDict,
|
||||
num_workers: int,
|
||||
local_worker: bool = True,
|
||||
) -> WorkerSet:
|
||||
"""Default factory method for a WorkerSet running under this Trainer.
|
||||
|
||||
|
@ -1703,6 +1799,9 @@ class Trainer(Trainable):
|
|||
config: The Trainer's config.
|
||||
num_workers: Number of remote rollout workers to create.
|
||||
0 for local only.
|
||||
local_worker: Whether to create a local (non @ray.remote) worker
|
||||
in the returned set as well (default: True). If `num_workers`
|
||||
is 0, always create a local worker.
|
||||
|
||||
Returns:
|
||||
The created WorkerSet.
|
||||
|
@ -1713,7 +1812,9 @@ class Trainer(Trainable):
|
|||
policy_class=policy_class,
|
||||
trainer_config=config,
|
||||
num_workers=num_workers,
|
||||
logdir=self.logdir)
|
||||
local_worker=local_worker,
|
||||
logdir=self.logdir,
|
||||
)
|
||||
|
||||
def _sync_filters_if_needed(self, workers: WorkerSet):
|
||||
if self.config.get("observation_filter", "NoFilter") != "NoFilter":
|
||||
|
@ -1985,6 +2086,18 @@ class Trainer(Trainable):
|
|||
"Got {}".format(config["multiagent"]["count_steps_by"]))
|
||||
|
||||
# Evaluation settings.
|
||||
|
||||
# Deprecated setting: `evaluation_num_episodes`.
|
||||
if config["evaluation_num_episodes"] != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="evaluation_num_episodes",
|
||||
new="`evaluation_duration` and `evaluation_duration_unit="
|
||||
"episodes`",
|
||||
error=False)
|
||||
config["evaluation_duration"] = config["evaluation_num_episodes"]
|
||||
config["evaluation_duration_unit"] = "episodes"
|
||||
config["evaluation_num_episodes"] = DEPRECATED_VALUE
|
||||
|
||||
# If `evaluation_num_workers` > 0, warn if `evaluation_interval` is
|
||||
# None (also set `evaluation_interval` to 1).
|
||||
if config["evaluation_num_workers"] > 0 and \
|
||||
|
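The deprecation handling above keeps old configs working by translating the old key into the new pair. A hedged before/after sketch (illustrative values):

.. code-block:: python

    # Old style (now deprecated; auto-translated by the Trainer with a warning):
    config = {"evaluation_num_episodes": 10}

    # New equivalent:
    config = {
        "evaluation_duration": 10,
        "evaluation_duration_unit": "episodes",
    }
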
@ -2008,18 +2121,18 @@ class Trainer(Trainable):
|
|||
"`evaluation_parallel_to_training` to False.")
|
||||
config["evaluation_parallel_to_training"] = False
|
||||
|
||||
# If `evaluation_num_episodes=auto`, error if
|
||||
# If `evaluation_duration=auto`, error if
|
||||
# `evaluation_parallel_to_training=False`.
|
||||
if config["evaluation_num_episodes"] == "auto":
|
||||
if config["evaluation_duration"] == "auto":
|
||||
if not config["evaluation_parallel_to_training"]:
|
||||
raise ValueError(
|
||||
"`evaluation_num_episodes=auto` not supported for "
|
||||
"`evaluation_duration=auto` not supported for "
|
||||
"`evaluation_parallel_to_training=False`!")
|
||||
# Make sure, it's an int otherwise.
|
||||
elif not isinstance(config["evaluation_num_episodes"], int):
|
||||
raise ValueError(
|
||||
"`evaluation_num_episodes` ({}) must be an int and "
|
||||
">0!".format(config["evaluation_num_episodes"]))
|
||||
elif not isinstance(config["evaluation_duration"], int) or \
|
||||
config["evaluation_duration"] <= 0:
|
||||
raise ValueError("`evaluation_duration` ({}) must be an int and "
|
||||
">0!".format(config["evaluation_duration"]))
|
||||
|
||||
@ExperimentalAPI
|
||||
@staticmethod
|
||||
|
@ -2284,7 +2397,7 @@ class Trainer(Trainable):
|
|||
def __repr__(self):
|
||||
return type(self).__name__
|
||||
|
||||
@Deprecated(new="Trainer.evaluate()", error=False)
|
||||
@Deprecated(new="Trainer.evaluate()", error=True)
|
||||
def _evaluate(self) -> dict:
|
||||
return self.evaluate()
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Evaluation interval
|
||||
"evaluation_interval": None,
|
||||
# Number of episodes to run per evaluation period.
|
||||
"evaluation_num_episodes": 10,
|
||||
"evaluation_duration": 10,
|
||||
|
||||
# === Model ===
|
||||
# Apply a state preprocessor with spec given by the "model" config option
|
||||
|
|
|
@ -309,8 +309,8 @@ def run(args, parser):
|
|||
# Make sure we have evaluation workers.
|
||||
if not config.get("evaluation_num_workers"):
|
||||
config["evaluation_num_workers"] = config.get("num_workers", 0)
|
||||
if not config.get("evaluation_num_episodes"):
|
||||
config["evaluation_num_episodes"] = 1
|
||||
if not config.get("evaluation_duration"):
|
||||
config["evaluation_duration"] = 1
|
||||
# Hard-override this as it raises a warning by Trainer otherwise.
|
||||
# Makes no sense anyways, to have it set to None as we don't call
|
||||
# `Trainer.train()` here.
|
||||
|
@ -401,7 +401,7 @@ def rollout(agent,
|
|||
saver.begin_rollout()
|
||||
eval_result = agent.evaluate()["evaluation"]
|
||||
# Increase timestep and episode counters.
|
||||
eps = agent.config["evaluation_num_episodes"]
|
||||
eps = agent.config["evaluation_duration"]
|
||||
episodes += eps
|
||||
steps += eps * eval_result["episode_len_mean"]
|
||||
# Print out results and continue.
|
||||
|
|
|
@ -34,15 +34,18 @@ class WorkerSet:
|
|||
Where n may be 0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
|
||||
validate_env: Optional[Callable[[EnvType], None]] = None,
|
||||
policy_class: Optional[Type[Policy]] = None,
|
||||
trainer_config: Optional[TrainerConfigDict] = None,
|
||||
num_workers: int = 0,
|
||||
logdir: Optional[str] = None,
|
||||
_setup: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
|
||||
validate_env: Optional[Callable[[EnvType], None]] = None,
|
||||
policy_class: Optional[Type[Policy]] = None,
|
||||
trainer_config: Optional[TrainerConfigDict] = None,
|
||||
num_workers: int = 0,
|
||||
local_worker: bool = True,
|
||||
logdir: Optional[str] = None,
|
||||
_setup: bool = True,
|
||||
):
|
||||
"""Initializes a WorkerSet instance.
|
||||
|
||||
Args:
|
||||
|
@ -55,6 +58,9 @@ class WorkerSet:
|
|||
trainer_config: Optional dict that extends the common config of
|
||||
the Trainer class.
|
||||
num_workers: Number of remote rollout workers to create.
|
||||
local_worker: Whether to create a local (non @ray.remote) worker
|
||||
in the returned set as well (default: True). If `num_workers`
|
||||
is 0, always create a local worker.
|
||||
logdir: Optional logging directory for workers.
|
||||
_setup: Whether to setup workers. This is only for testing.
|
||||
"""
|
||||
|
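For illustration, a hedged sketch of building an evaluation-only WorkerSet with the new ``local_worker`` flag (``my_env_creator``, ``MyPolicy``, and ``eval_config`` are placeholders, not part of this commit):

.. code-block:: python

    from ray.rllib.evaluation.worker_set import WorkerSet

    # Remote evaluation workers only; skip creating the local worker.
    eval_workers = WorkerSet(
        env_creator=my_env_creator,   # placeholder env creator
        policy_class=MyPolicy,        # placeholder policy class
        trainer_config=eval_config,   # a fully merged evaluation config dict
        num_workers=2,
        local_worker=False,           # new in this commit
    )
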
@ -69,18 +75,25 @@ class WorkerSet:
|
|||
self._logdir = logdir
|
||||
|
||||
if _setup:
|
||||
# Force a local worker if num_workers == 0 (no remote workers).
|
||||
# Otherwise, this WorkerSet would be empty.
|
||||
self._local_worker = None
|
||||
if num_workers == 0:
|
||||
local_worker = True
|
||||
|
||||
self._local_config = merge_dicts(
|
||||
trainer_config,
|
||||
{"tf_session_args": trainer_config["local_tf_session_args"]})
|
||||
|
||||
# Create a number of remote workers.
|
||||
# Create a number of @ray.remote workers.
|
||||
self._remote_workers = []
|
||||
self.add_workers(num_workers)
|
||||
|
||||
# Create a local worker, if needed.
|
||||
# If num_workers > 0 and we don't have an env on the local worker,
|
||||
# get the observation- and action spaces for each policy from
|
||||
# the first remote worker (which does have an env).
|
||||
if self._remote_workers and \
|
||||
if local_worker and self._remote_workers and \
|
||||
not trainer_config.get("create_env_on_driver") and \
|
||||
(not trainer_config.get("observation_space") or
|
||||
not trainer_config.get("action_space")):
|
||||
|
@ -106,17 +119,17 @@ class WorkerSet:
|
|||
else:
|
||||
spaces = None
|
||||
|
||||
# Always create a local worker.
|
||||
self._local_worker = self._make_worker(
|
||||
cls=RolloutWorker,
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy_cls=self._policy_class,
|
||||
worker_index=0,
|
||||
num_workers=num_workers,
|
||||
config=self._local_config,
|
||||
spaces=spaces,
|
||||
)
|
||||
if local_worker:
|
||||
self._local_worker = self._make_worker(
|
||||
cls=RolloutWorker,
|
||||
env_creator=env_creator,
|
||||
validate_env=validate_env,
|
||||
policy_cls=self._policy_class,
|
||||
worker_index=0,
|
||||
num_workers=num_workers,
|
||||
config=self._local_config,
|
||||
spaces=spaces,
|
||||
)
|
||||
|
||||
def local_worker(self) -> RolloutWorker:
|
||||
"""Returns the local rollout worker."""
|
||||
|
@ -197,7 +210,9 @@ class WorkerSet:
|
|||
Returns:
|
||||
The list of return values of all calls to `func([worker])`.
|
||||
"""
|
||||
local_result = [func(self.local_worker())]
|
||||
local_result = []
|
||||
if self._local_worker:
|
||||
local_result = [func(self.local_worker())]
|
||||
remote_results = ray.get(
|
||||
[w.apply.remote(func) for w in self.remote_workers()])
|
||||
return local_result + remote_results
|
||||
|
@ -219,8 +234,10 @@ class WorkerSet:
|
|||
The first entry in this list are the results of the local
|
||||
worker, followed by all remote workers' results.
|
||||
"""
|
||||
local_result = []
|
||||
# Local worker: Index=0.
|
||||
local_result = [func(self.local_worker(), 0)]
|
||||
if self._local_worker:
|
||||
local_result = [func(self.local_worker(), 0)]
|
||||
# Remote workers: Index > 0.
|
||||
remote_results = ray.get([
|
||||
w.apply.remote(func, i + 1)
|
||||
|
@ -247,7 +264,9 @@ class WorkerSet:
|
|||
The local workers' results are first, followed by all remote
|
||||
workers' results
|
||||
"""
|
||||
results = self.local_worker().foreach_policy(func)
|
||||
results = []
|
||||
if self._local_worker:
|
||||
results = self.local_worker().foreach_policy(func)
|
||||
ray_gets = []
|
||||
for worker in self.remote_workers():
|
||||
ray_gets.append(
|
||||
|
@ -260,7 +279,10 @@ class WorkerSet:
|
|||
@DeveloperAPI
|
||||
def trainable_policies(self) -> List[PolicyID]:
|
||||
"""Returns the list of trainable policy ids."""
|
||||
return self.local_worker().policies_to_train
|
||||
if self._local_worker:
|
||||
return self._local_worker.policies_to_train
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_trainable_policy(
|
||||
|
@ -275,7 +297,9 @@ class WorkerSet:
|
|||
List[any]: The list of n return values of all
|
||||
`func([trainable policy], [ID])`-calls.
|
||||
"""
|
||||
results = self.local_worker().foreach_trainable_policy(func)
|
||||
results = []
|
||||
if self._local_worker:
|
||||
results = self.local_worker().foreach_trainable_policy(func)
|
||||
ray_gets = []
|
||||
for worker in self.remote_workers():
|
||||
ray_gets.append(
|
||||
|
@ -303,7 +327,9 @@ class WorkerSet:
|
|||
Returns:
|
||||
The list (workers) of lists (sub environments) of results.
|
||||
"""
|
||||
local_results = [self.local_worker().foreach_env(func)]
|
||||
local_results = []
|
||||
if self._local_worker:
|
||||
local_results = [self.local_worker().foreach_env(func)]
|
||||
ray_gets = []
|
||||
for worker in self.remote_workers():
|
||||
ray_gets.append(worker.foreach_env.remote(func))
|
||||
|
@ -329,7 +355,11 @@ class WorkerSet:
|
|||
The list (1 item per workers) of lists (1 item per sub-environment)
|
||||
of results.
|
||||
"""
|
||||
local_results = [self.local_worker().foreach_env_with_context(func)]
|
||||
local_results = []
|
||||
if self._local_worker:
|
||||
local_results = [
|
||||
self.local_worker().foreach_env_with_context(func)
|
||||
]
|
||||
ray_gets = []
|
||||
for worker in self.remote_workers():
|
||||
ray_gets.append(worker.foreach_env_with_context.remote(func))
|
||||
|
|
|
@ -181,7 +181,7 @@ if __name__ == "__main__":
|
|||
"evaluation_interval": 1,
|
||||
|
||||
# Run 10 episodes each time evaluation runs.
|
||||
"evaluation_num_episodes": 10,
|
||||
"evaluation_duration": 10,
|
||||
|
||||
# Override the env config for evaluation.
|
||||
"evaluation_config": {
|
||||
|
|
|
@ -101,7 +101,7 @@ if __name__ == "__main__":
|
|||
"metrics_smoothing_episodes": 5,
|
||||
"evaluation_interval": 1,
|
||||
"evaluation_num_workers": 2,
|
||||
"evaluation_num_episodes": 10,
|
||||
"evaluation_duration": 10,
|
||||
"evaluation_parallel_to_training": True,
|
||||
"evaluation_config": {
|
||||
"input": "sampler",
|
||||
|
|
|
@ -110,7 +110,7 @@ if __name__ == "__main__":
|
|||
# Evaluate once per training iteration.
|
||||
"evaluation_interval": 1,
|
||||
# Run evaluation on (at least) two episodes
|
||||
"evaluation_num_episodes": 2,
|
||||
"evaluation_duration": 2,
|
||||
# ... using one evaluation worker (setting this to 0 will cause
|
||||
# evaluation to run on the local evaluation worker, blocking
|
||||
# training until evaluation is done).
|
||||
|
|
|
@ -69,7 +69,7 @@ if __name__ == "__main__":
|
|||
# Set up evaluation.
|
||||
config["evaluation_num_workers"] = 1
|
||||
config["evaluation_interval"] = 1
|
||||
config["evaluation_num_episodes"] = 10
|
||||
config["evaluation_duration"] = 10
|
||||
# This should be False b/c iterations are very long and this would
|
||||
# cause evaluation to lag one iter behind training.
|
||||
config["evaluation_parallel_to_training"] = False
|
||||
|
|
|
@ -7,11 +7,30 @@ from ray.rllib.utils.test_utils import check_learning_achieved
|
|||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--evaluation-num-episodes",
|
||||
"--evaluation-duration",
|
||||
type=lambda v: v if v == "auto" else int(v),
|
||||
default=13,
|
||||
help="Number of evaluation episodes to run each iteration. "
|
||||
help="Number of evaluation episodes/timesteps to run each iteration. "
|
||||
"If 'auto', will run as many as possible during train pass.")
|
||||
parser.add_argument(
|
||||
"--evaluation-duration-unit",
|
||||
type=str,
|
||||
default="episodes",
|
||||
choices=["episodes", "timesteps"],
|
||||
help="The unit in which to measure the duration (`episodes` or"
|
||||
"`timesteps`).")
|
||||
parser.add_argument(
|
||||
"--evaluation-num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of evaluation workers to setup. "
|
||||
"0 for a single local evaluation worker. Note that for values >0, no"
|
||||
"local evaluation worker will be created (b/c not needed).")
|
||||
parser.add_argument(
|
||||
"--evaluation-interval",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Every how many train iterations should we run an evaluation loop?")
|
||||
|
||||
parser.add_argument(
|
||||
"--run",
|
||||
|
@ -50,26 +69,44 @@ parser.add_argument(
|
|||
help="Init Ray in local mode for easier debugging.")
|
||||
|
||||
|
||||
class AssertNumEvalEpisodesCallback(DefaultCallbacks):
|
||||
class AssertEvalCallback(DefaultCallbacks):
|
||||
def on_train_result(self, *, trainer, result, **kwargs):
|
||||
# Make sure we always run exactly n evaluation episodes,
|
||||
# Make sure we always run exactly the given evaluation duration,
|
||||
# no matter what the other settings are (such as
|
||||
# `evaluation_num_workers` or `evaluation_parallel_to_training`).
|
||||
if "evaluation" in result:
|
||||
hist_stats = result["evaluation"]["hist_stats"]
|
||||
num_episodes_done = len(hist_stats["episode_lengths"])
|
||||
# Compare number of entries in episode_lengths (this is the
|
||||
# number of episodes actually run) with desired number of
|
||||
# episodes from the config.
|
||||
if isinstance(trainer.config["evaluation_num_episodes"], int):
|
||||
assert num_episodes_done == \
|
||||
trainer.config["evaluation_num_episodes"]
|
||||
# We count in episodes.
|
||||
if trainer.config["evaluation_duration_unit"] == "episodes":
|
||||
num_episodes_done = len(hist_stats["episode_lengths"])
|
||||
# Compare number of entries in episode_lengths (this is the
|
||||
# number of episodes actually run) with desired number of
|
||||
# episodes from the config.
|
||||
if isinstance(trainer.config["evaluation_duration"], int):
|
||||
assert num_episodes_done == \
|
||||
trainer.config["evaluation_duration"]
|
||||
# If auto-episodes: Expect at least as many episode as workers
|
||||
# (each worker's `sample()` is at least called once).
|
||||
else:
|
||||
assert trainer.config["evaluation_duration"] == "auto"
|
||||
assert num_episodes_done >= \
|
||||
trainer.config["evaluation_num_workers"]
|
||||
print("Number of run evaluation episodes: "
|
||||
f"{num_episodes_done} (ok)!")
|
||||
# We count in timesteps.
|
||||
else:
|
||||
assert trainer.config["evaluation_num_episodes"] == "auto"
|
||||
assert num_episodes_done >= \
|
||||
trainer.config["evaluation_num_workers"]
|
||||
print("Number of run evaluation episodes: "
|
||||
f"{num_episodes_done} (ok)!")
|
||||
num_timesteps_reported = result["evaluation"][
|
||||
"timesteps_this_iter"]
|
||||
num_timesteps_wanted = trainer.config["evaluation_duration"]
|
||||
if num_timesteps_wanted != "auto":
|
||||
delta = num_timesteps_wanted - num_timesteps_reported
|
||||
# Expect the reported timesteps to be close to the desired total (each
# eval worker samples in `rollout_fragment_length` chunks, so allow a
# small delta).
|
||||
assert abs(delta) < 20, \
|
||||
(delta, num_timesteps_wanted, num_timesteps_reported)
|
||||
print("Number of run evaluation timesteps: "
|
||||
f"{num_timesteps_reported} (ok)!")
|
||||
|
||||
print(f"R={result['evaluation']['episode_reward_mean']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -93,21 +130,23 @@ if __name__ == "__main__":
|
|||
"evaluation_parallel_to_training": True,
|
||||
# Use two evaluation workers. Must be >0, otherwise,
|
||||
# evaluation will run on a local worker and block (no parallelism).
|
||||
"evaluation_num_workers": 2,
|
||||
"evaluation_num_workers": args.evaluation_num_workers,
|
||||
# Evaluate every other training iteration (together
|
||||
# with every other call to Trainer.train()).
|
||||
"evaluation_interval": 2,
|
||||
# Run for n episodes (properly distribute load amongst all eval
|
||||
# workers). The longer it takes to evaluate, the more
|
||||
# sense it makes to use `evaluation_parallel_to_training=True`.
|
||||
"evaluation_interval": args.evaluation_interval,
|
||||
# Run for n episodes/timesteps (properly distribute load amongst
|
||||
# all eval workers). The longer it takes to evaluate, the more sense
|
||||
# it makes to use `evaluation_parallel_to_training=True`.
|
||||
# Use "auto" to run evaluation for roughly as long as the training
|
||||
# step takes.
|
||||
"evaluation_num_episodes": args.evaluation_num_episodes,
|
||||
"evaluation_duration": args.evaluation_duration,
|
||||
# "episodes" or "timesteps".
|
||||
"evaluation_duration_unit": args.evaluation_duration_unit,
|
||||
|
||||
# Use a custom callback that asserts that we are running the
|
||||
# configured exact number of episodes per evaluation OR - in auto
|
||||
# mode - run at least as many episodes as we have eval workers.
|
||||
"callbacks": AssertNumEvalEpisodesCallback,
|
||||
"callbacks": AssertEvalCallback,
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -194,7 +194,7 @@ def learn_test_multi_agent_plus_evaluate(algo):
|
|||
|
||||
# Test rolling out n steps.
|
||||
result = os.popen(
|
||||
"python {}/rollout.py --run={} "
|
||||
"python {}/evaluate.py --run={} "
|
||||
"--steps=400 "
|
||||
"--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
|
||||
rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
|
||||
|
|
|
@ -14,4 +14,4 @@ pendulum-apex-ddpg:
|
|||
target_network_update_freq: 50000
|
||||
tau: 1.0
|
||||
evaluation_interval: 5
|
||||
evaluation_num_episodes: 10
|
||||
evaluation_duration: 10
|
||||
|
|