From 60b2219d72817959f741ec1d7124ea17bee374cc Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Sat, 4 Dec 2021 13:26:33 +0100 Subject: [PATCH] [RLlib] Allow for evaluation to run by `timesteps` (alternative to `episodes`) and add auto-setting to make sure train doesn't ever have to wait for eval (e.g. long episodes) to finish. (#20757) --- doc/source/rllib-training.rst | 83 ++++- rllib/BUILD | 14 +- rllib/agents/cql/tests/test_cql.py | 2 +- rllib/agents/ddpg/ddpg.py | 2 +- rllib/agents/marwil/tests/test_bc.py | 2 +- rllib/agents/marwil/tests/test_marwil.py | 4 +- rllib/agents/qmix/qmix.py | 2 +- rllib/agents/tests/test_trainer.py | 8 +- rllib/agents/trainer.py | 315 ++++++++++++------ rllib/contrib/maddpg/maddpg.py | 2 +- rllib/evaluate.py | 6 +- rllib/evaluation/worker_set.py | 88 +++-- rllib/examples/custom_eval.py | 2 +- rllib/examples/custom_input_api.py | 2 +- rllib/examples/env_rendering_and_recording.py | 2 +- rllib/examples/offline_rl.py | 2 +- .../parallel_evaluation_and_training.py | 85 +++-- rllib/tests/test_rllib_train_and_evaluate.py | 2 +- .../ddpg/pendulum-apex-ddpg.yaml | 2 +- 19 files changed, 437 insertions(+), 188 deletions(-) diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 789902565..fa6a6c2b0 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -729,19 +729,75 @@ Customized Evaluation During Training RLlib will report online training rewards, however in some cases you may want to compute rewards with different settings (e.g., with exploration turned off, or on a specific set -of environment configurations). You can evaluate policies during training by setting -the ``evaluation_interval`` config, and optionally also ``evaluation_num_episodes``, -``evaluation_config``, ``evaluation_num_workers``, and ``custom_eval_function`` -(see `trainer.py `__ for further documentation). +of environment configurations). You can activate evaluating policies during training (``Trainer.train()``) by setting +the ``evaluation_interval`` to an int value (> 0) indicating every how many ``Trainer.train()`` +calls an "evaluation step" is run: -By default, exploration is left as-is within ``evaluation_config``. -However, you can switch off any exploration behavior for the evaluation workers -via: +.. code-block:: python + + # Run one evaluation step on every 3rd `Trainer.train()` call. + { + "evaluation_interval": 3, + } + + +One such evaluation step runs over ``evaluation_duration`` episodes or timesteps, depending +on the ``evaluation_duration_unit`` setting, which can be either "episodes" (default) or "timesteps". + + +.. code-block:: python + + # Every time we do run an evaluation step, run it for exactly 10 episodes. + { + "evaluation_duration": 10, + "evaluation_duration_unit": "episodes", + } + # Every time we do run an evaluation step, run it for close to 200 timesteps. + { + "evaluation_duration": 200, + "evaluation_duration_unit": "timesteps", + } + + +Before each evaluation step, weights from the main model are synchronized to all evaluation workers. + +Normally, the evaluation step is run right after the respective train step. For example, for +``evaluation_interval=2``, the sequence of steps is: ``train, train, eval, train, train, eval, ...``. +For ``evaluation_interval=1``, the sequence is: ``train, eval, train, eval, ...``. + +However, it is possible to run evaluation in parallel to training via the ``evaluation_parallel_to_training=True`` +config setting. In this case, both steps (train and eval) are run at the same time via threading. +This can speed up the evaluation process significantly, but leads to a 1-iteration delay between reported +training results and evaluation results (the evaluation results are behind b/c they use slightly outdated +model weights). + +When running with the ``evaluation_parallel_to_training=True`` setting, a special "auto" value +is supported for ``evaluation_duration``. This can be used to make the evaluation step take +roughly as long as the train step: + +.. code-block:: python + + # Run eval and train at the same time via threading and make sure they roughly + # take the same time, such that the next `Trainer.train()` call can execute + # immediately and not have to wait for a still ongoing (e.g. very long episode) + # evaluation step: + { + "evaluation_interval": 1, + "evaluation_parallel_to_training": True, + "evaluation_duration": "auto", # automatically end evaluation when train step has finished + "evaluation_duration_unit": "timesteps", # <- more fine grained than "episodes" + } + + +The ``evaluation_config`` key allows you to override any config settings for +the evaluation workers. For example, to switch off exploration in the evaluation steps, +do: .. code-block:: python # Switching off exploration behavior for evaluation workers - # (see rllib/agents/trainer.py) + # (see rllib/agents/trainer.py). Use any keys in this sub-dict that are + # also supported in the main Trainer config. "evaluation_config": { "explore": False } @@ -752,6 +808,17 @@ via: policy, even if this is a stochastic one. Setting "explore=False" above will result in the evaluation workers not using this stochastic policy. +Parallelism for the evaluation step is determined via the ``evaluation_num_workers`` +setting. Set this to larger values if you want the desired evaluation episodes or timesteps to +run as much in parallel as possible. For example, if your ``evaluation_duration=10``, +``evaluation_duration_unit=episodes``, and ``evaluation_num_workers=10``, each eval worker +only has to run 1 episode in each eval step. + +In case you would like to entirely customize the evaluation step, set ``custom_eval_function`` in your +config to a callable taking the Trainer object and a WorkerSet object (the evaluation WorkerSet) +and returning a metrics dict. See `trainer.py `__ +for further documentation. + There is an end to end example of how to set up custom online evaluation in `custom_eval.py `__. Note that if you only want to eval your policy at the end of training, you can set ``evaluation_interval: N``, where ``N`` is the number of training iterations before stopping. Below are some examples of how the custom evaluation metrics are reported nested under the ``evaluation`` key of normal training results: diff --git a/rllib/BUILD b/rllib/BUILD index f602d8b85..1e8fa444c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2335,34 +2335,34 @@ py_test( tags = ["team:ml", "examples", "examples_P"], size = "medium", srcs = ["examples/parallel_evaluation_and_training.py"], - args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-num-episodes=13"] + args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=13"] ) py_test( - name = "examples/parallel_evaluation_and_training_auto_num_episodes_tf", + name = "examples/parallel_evaluation_and_training_auto_episodes_tf", main = "examples/parallel_evaluation_and_training.py", tags = ["team:ml", "examples", "examples_P"], size = "medium", srcs = ["examples/parallel_evaluation_and_training.py"], - args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-num-episodes=auto"] + args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto"] ) py_test( - name = "examples/parallel_evaluation_and_training_11_episodes_tf2", + name = "examples/parallel_evaluation_and_training_211_ts_tf2", main = "examples/parallel_evaluation_and_training.py", tags = ["team:ml", "examples", "examples_P"], size = "medium", srcs = ["examples/parallel_evaluation_and_training.py"], - args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-episodes=11"] + args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=211", "--evaluation-duration-unit=timesteps"] ) py_test( - name = "examples/parallel_evaluation_and_training_14_episodes_torch", + name = "examples/parallel_evaluation_and_training_auto_ts_torch", main = "examples/parallel_evaluation_and_training.py", tags = ["team:ml", "examples", "examples_P"], size = "medium", srcs = ["examples/parallel_evaluation_and_training.py"], - args = ["--as-test", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-episodes=14"] + args = ["--as-test", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=auto", "--evaluation-duration-unit=timesteps"] ) py_test( diff --git a/rllib/agents/cql/tests/test_cql.py b/rllib/agents/cql/tests/test_cql.py index 87c744ca7..0d8778415 100644 --- a/rllib/agents/cql/tests/test_cql.py +++ b/rllib/agents/cql/tests/test_cql.py @@ -58,7 +58,7 @@ class TestCQL(unittest.TestCase): config["input_evaluation"] = ["is"] config["evaluation_interval"] = 2 - config["evaluation_num_episodes"] = 10 + config["evaluation_duration"] = 10 config["evaluation_config"]["input"] = "sampler" config["evaluation_parallel_to_training"] = False config["evaluation_num_workers"] = 2 diff --git a/rllib/agents/ddpg/ddpg.py b/rllib/agents/ddpg/ddpg.py index e44dc5b26..3fb713ac1 100644 --- a/rllib/agents/ddpg/ddpg.py +++ b/rllib/agents/ddpg/ddpg.py @@ -39,7 +39,7 @@ DEFAULT_CONFIG = with_common_config({ # metrics are already only reported for the lowest epsilon workers. "evaluation_interval": None, # Number of episodes to run per evaluation period. - "evaluation_num_episodes": 10, + "evaluation_duration": 10, # === Model === # Apply a state preprocessor with spec given by the "model" config option diff --git a/rllib/agents/marwil/tests/test_bc.py b/rllib/agents/marwil/tests/test_bc.py index d6ac23489..acfd4282c 100644 --- a/rllib/agents/marwil/tests/test_bc.py +++ b/rllib/agents/marwil/tests/test_bc.py @@ -37,7 +37,7 @@ class TestBC(unittest.TestCase): config["evaluation_interval"] = 3 config["evaluation_num_workers"] = 1 - config["evaluation_num_episodes"] = 5 + config["evaluation_duration"] = 5 config["evaluation_parallel_to_training"] = True # Evaluate on actual environment. config["evaluation_config"] = {"input": "sampler"} diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py index 63f7f5dc6..cdf57d424 100644 --- a/rllib/agents/marwil/tests/test_marwil.py +++ b/rllib/agents/marwil/tests/test_marwil.py @@ -43,7 +43,7 @@ class TestMARWIL(unittest.TestCase): config["num_workers"] = 2 config["evaluation_num_workers"] = 1 config["evaluation_interval"] = 3 - config["evaluation_num_episodes"] = 5 + config["evaluation_duration"] = 5 config["evaluation_parallel_to_training"] = True # Evaluate on actual environment. config["evaluation_config"] = {"input": "sampler"} @@ -100,7 +100,7 @@ class TestMARWIL(unittest.TestCase): config["num_workers"] = 1 config["evaluation_num_workers"] = 1 config["evaluation_interval"] = 3 - config["evaluation_num_episodes"] = 5 + config["evaluation_duration"] = 5 config["evaluation_parallel_to_training"] = True # Evaluate on actual environment. config["evaluation_config"] = {"input": "sampler"} diff --git a/rllib/agents/qmix/qmix.py b/rllib/agents/qmix/qmix.py index ef7b48772..065709329 100644 --- a/rllib/agents/qmix/qmix.py +++ b/rllib/agents/qmix/qmix.py @@ -49,7 +49,7 @@ DEFAULT_CONFIG = with_common_config({ # metrics are already only reported for the lowest epsilon workers. "evaluation_interval": None, # Number of episodes to run per evaluation period. - "evaluation_num_episodes": 10, + "evaluation_duration": 10, # Switch to greedy actions in evaluation workers. "evaluation_config": { "explore": False, diff --git a/rllib/agents/tests/test_trainer.py b/rllib/agents/tests/test_trainer.py index e15c837f4..479d7cae1 100644 --- a/rllib/agents/tests/test_trainer.py +++ b/rllib/agents/tests/test_trainer.py @@ -12,7 +12,7 @@ import ray.rllib.agents.pg as pg from ray.rllib.agents.trainer import COMMON_CONFIG from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.parallel_evaluation_and_training import \ - AssertNumEvalEpisodesCallback + AssertEvalCallback from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.test_utils import framework_iterator @@ -131,13 +131,13 @@ class TestTrainer(unittest.TestCase): config.update({ "env": "CartPole-v0", "evaluation_interval": 2, - "evaluation_num_episodes": 2, + "evaluation_duration": 2, "evaluation_config": { "gamma": 0.98, }, # Use a custom callback that asserts that we are running the # configured exact number of episodes per evaluation. - "callbacks": AssertNumEvalEpisodesCallback, + "callbacks": AssertEvalCallback, }) for _ in framework_iterator(config, frameworks=("tf", "torch")): @@ -169,7 +169,7 @@ class TestTrainer(unittest.TestCase): "evaluation_interval": None, # Use a custom callback that asserts that we are running the # configured exact number of episodes per evaluation. - "callbacks": AssertNumEvalEpisodesCallback, + "callbacks": AssertEvalCallback, }) for _ in framework_iterator(frameworks=("tf", "torch")): # Setup trainer w/o evaluation worker set and still call diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 819ef66f1..849a36e0a 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -4,6 +4,7 @@ from datetime import datetime import functools import gym import logging +import math import numpy as np import os import pickle @@ -271,17 +272,23 @@ COMMON_CONFIG: TrainerConfigDict = { # === Evaluation Settings === # Evaluate with every `evaluation_interval` training iterations. # The evaluation stats will be reported under the "evaluation" metric key. - # Note that evaluation is currently not parallelized, and that for Ape-X - # metrics are already only reported for the lowest epsilon workers. + # Note that for Ape-X metrics are already only reported for the lowest + # epsilon workers (least random workers). + # Set to None (or 0) for no evaluation. "evaluation_interval": None, - # Number of episodes to run in total per evaluation period. + # Duration for which to run evaluation each `evaluation_interval`. + # The unit for the duration can be set via `evaluation_duration_unit` to + # either "episodes" (default) or "timesteps". # If using multiple evaluation workers (evaluation_num_workers > 1), - # episodes will be split amongst these. - # If "auto": - # - evaluation_parallel_to_training=True: Will run as many episodes as the - # training step takes. - # - evaluation_parallel_to_training=False: Error. - "evaluation_num_episodes": 10, + # the load to run will be split amongst these. + # If the value is "auto": + # - For `evaluation_parallel_to_training=True`: Will run as many + # episodes/timesteps that fit into the (parallel) training step. + # - For `evaluation_parallel_to_training=False`: Error. + "evaluation_duration": 10, + # The unit, with which to count the evaluation duration. Either "episodes" + # (default) or "timesteps". + "evaluation_duration_unit": "episodes", # Whether to run evaluation in parallel to a Trainer.train() call # using threading. Default=False. # E.g. evaluation_interval=2 -> For every other training iteration, @@ -310,9 +317,9 @@ COMMON_CONFIG: TrainerConfigDict = { "evaluation_num_workers": 0, # Customize the evaluation method. This must be a function of signature # (trainer: Trainer, eval_workers: WorkerSet) -> metrics: dict. See the - # Trainer.evaluate() method to see the default implementation. The - # trainer guarantees all eval workers have the latest policy state before - # this function is called. + # Trainer.evaluate() method to see the default implementation. + # The Trainer guarantees all eval workers have the latest policy state + # before this function is called. "custom_eval_function": None, # === Advanced Rollout Settings === @@ -520,6 +527,9 @@ COMMON_CONFIG: TrainerConfigDict = { # Whether to write episode stats and videos to the agent log dir. This is # typically located in ~/ray_results. "monitor": DEPRECATED_VALUE, + # Replaced by `evaluation_duration=10` and + # `evaluation_duration_unit=episodes`. + "evaluation_num_episodes": DEPRECATED_VALUE, } # __sphinx_doc_end__ # yapf: enable @@ -661,6 +671,10 @@ class Trainer(Trainable): logger_creator = default_logger_creator + # Evaluation WorkerSet and metrics last returned by `self.evaluate()`. + self.evaluation_workers = None + self.evaluation_metrics = {} + super().__init__(config, logger_creator, remote_checkpoint_dir, sync_function_tpl) @@ -769,29 +783,60 @@ class Trainer(Trainable): self.workers, self.config, **self._kwargs_for_execution_plan()) # Evaluation WorkerSet setup. - self.evaluation_workers = None - self.evaluation_metrics = {} # User would like to setup a separate evaluation worker set. + + # Update with evaluation settings: + user_eval_config = copy.deepcopy(self.config["evaluation_config"]) + + # Assert that user has not unset "in_evaluation". + assert "in_evaluation" not in user_eval_config or \ + user_eval_config["in_evaluation"] is True + + # Merge user-provided eval config with the base config. This makes sure + # the eval config is always complete, no matter whether we have eval + # workers or perform evaluation on the (non-eval) local worker. + eval_config = merge_dicts(self.config, user_eval_config) + self.config["evaluation_config"] = eval_config + if self.config.get("evaluation_num_workers", 0) > 0 or \ self.config.get("evaluation_interval"): - # Update env_config with evaluation settings: - extra_config = copy.deepcopy(self.config["evaluation_config"]) - # Assert that user has not unset "in_evaluation". - assert "in_evaluation" not in extra_config or \ - extra_config["in_evaluation"] is True - evaluation_config = merge_dicts(self.config, extra_config) + logger.debug(f"Using evaluation_config: {user_eval_config}.") + # Validate evaluation config. - self.validate_config(evaluation_config) - # Switch on complete_episode rollouts (evaluations are - # always done on n complete episodes) and set the - # `in_evaluation` flag. Also, make sure our rollout fragments - # are short so we don't have more than one episode in one rollout. - evaluation_config.update({ - "batch_mode": "complete_episodes", - "rollout_fragment_length": 1, - "in_evaluation": True, - }) - logger.debug("using evaluation_config: {}".format(extra_config)) + self.validate_config(eval_config) + + # Set the `in_evaluation` flag. + eval_config["in_evaluation"] = True + + # Evaluation duration unit: episodes. + # Switch on `complete_episode` rollouts. Also, make sure + # rollout fragments are short so we never have more than one + # episode in one rollout. + if eval_config["evaluation_duration_unit"] == "episodes": + eval_config.update({ + "batch_mode": "complete_episodes", + "rollout_fragment_length": 1, + }) + # Evaluation duration unit: timesteps. + # - Set `batch_mode=truncate_episodes` so we don't perform rollouts + # strictly along episode borders. + # Set `rollout_fragment_length` such that desired steps are divided + # equally amongst workers or - in "auto" duration mode - set it + # to a reasonably small number (10), such that a single `sample()` + # call doesn't take too much time so we can stop evaluation as soon + # as possible after the train step is completed. + else: + eval_config.update({ + "batch_mode": "truncate_episodes", + "rollout_fragment_length": 10 + if self.config["evaluation_duration"] == "auto" else int( + math.ceil( + self.config["evaluation_duration"] / + (self.config["evaluation_num_workers"] or 1))), + }) + + self.config["evaluation_config"] = eval_config + # Create a separate evaluation worker set for evaluation. # If evaluation_num_workers=0, use the evaluation set's local # worker for evaluation, otherwise, use its remote workers @@ -800,8 +845,11 @@ class Trainer(Trainable): env_creator=self.env_creator, validate_env=None, policy_class=self.get_default_policy_class(self.config), - config=evaluation_config, - num_workers=self.config["evaluation_num_workers"]) + config=eval_config, + num_workers=self.config["evaluation_num_workers"], + # Don't even create a local worker if num_workers > 0. + local_worker=False, + ) # TODO: Deprecated: In your sub-classes of Trainer, override `setup()` # directly and call super().setup() from within it if you would like the @@ -883,14 +931,33 @@ class Trainer(Trainable): """Attempts a single training step, including evaluation, if required. Override this method in your Trainer sub-classes if you would like to - keep the n attempts (catch worker failures) or override `step()` - directly if you would like to handle worker failures yourself. + keep the n step-attempts logic (catch worker failures) in place or + override `step()` directly if you would like to handle worker + failures yourself. Returns: The results dict with stats/infos on sampling, training, and - if required - evaluation. """ + def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done): + # Training is done and we already ran at least one + # evaluation -> Nothing left to run. + if num_units_done > 0 and \ + train_future.done(): + return 0 + # Count by episodes. -> Run n more + # (n=num eval workers). + elif unit == "episodes": + return num_eval_workers + # Count by timesteps. -> Run n*m*p more + # (n=num eval workers; m=rollout fragment length; + # p=num-envs-per-worker). + else: + return num_eval_workers * \ + eval_cfg["rollout_fragment_length"] * \ + eval_cfg["num_envs_per_worker"] + # self._iteration gets incremented after this function returns, # meaning that e. g. the first time this function is called, # self._iteration will be 0. @@ -914,19 +981,15 @@ class Trainer(Trainable): with concurrent.futures.ThreadPoolExecutor() as executor: train_future = executor.submit( lambda: next(self.train_exec_impl)) - if self.config["evaluation_num_episodes"] == "auto": - - # Run at least one `evaluate()` (num_episodes_done - # must be > 0), even if the training is very fast. - def episodes_left_fn(num_episodes_done): - if num_episodes_done > 0 and \ - train_future.done(): - return 0 - else: - return self.config["evaluation_num_workers"] + # Automatically determine duration of the evaluation. + if self.config["evaluation_duration"] == "auto": + unit = self.config["evaluation_duration_unit"] evaluation_metrics = self.evaluate( - episodes_left_fn=episodes_left_fn) + duration_fn=functools.partial( + auto_duration_fn, unit, self.config[ + "evaluation_num_workers"], self.config[ + "evaluation_config"])) else: evaluation_metrics = self.evaluate() # Collect the training results from the future. @@ -959,19 +1022,29 @@ class Trainer(Trainable): return step_results @PublicAPI - def evaluate(self, episodes_left_fn: Optional[Callable[[int], int]] = None - ) -> dict: + def evaluate( + self, + episodes_left_fn=None, # deprecated + duration_fn: Optional[Callable[[int], int]] = None, + ) -> dict: """Evaluates current policy under `evaluation_config` settings. Note that this default implementation does not do anything beyond merging evaluation_config with the normal trainer config. Args: - episodes_left_fn: An optional callable taking the already run + duration_fn: An optional callable taking the already run num episodes as only arg and returning the number of episodes left to run. It's used to find out whether evaluation should continue. """ + if episodes_left_fn is not None: + deprecation_warning( + old="Trainer.evaluate(episodes_left_fn)", + new="Trainer.evaluate(duration_fn)", + error=False) + duration_fn = episodes_left_fn + # In case we are evaluating (in a thread) parallel to training, # we may have to re-enable eager mode here (gets disabled in the # thread). @@ -996,80 +1069,102 @@ class Trainer(Trainable): raise ValueError("Custom eval function must return " "dict of metrics, got {}.".format(metrics)) else: - # How many episodes do we need to run? - # In "auto" mode (only for parallel eval + training): Run one - # episode per eval worker. - num_episodes = self.config["evaluation_num_episodes"] if \ - self.config["evaluation_num_episodes"] != "auto" else \ - (self.config["evaluation_num_workers"] or 1) + if self.evaluation_workers is None and \ + self.workers.local_worker().input_reader is None: + raise ValueError( + "Cannot evaluate w/o an evaluation worker set in " + "the Trainer or w/o an env on the local worker!\n" + "Try one of the following:\n1) Set " + "`evaluation_interval` >= 0 to force creating a " + "separate evaluation worker set.\n2) Set " + "`create_env_on_driver=True` to force the local " + "(non-eval) worker to have an environment to " + "evaluate on.") + + # How many episodes/timesteps do we need to run? + # In "auto" mode (only for parallel eval + training): Run as long + # as training lasts. + unit = self.config["evaluation_duration_unit"] + eval_cfg = self.config["evaluation_config"] + rollout = eval_cfg["rollout_fragment_length"] + num_envs = eval_cfg["num_envs_per_worker"] + duration = self.config["evaluation_duration"] if \ + self.config["evaluation_duration"] != "auto" else \ + (self.config["evaluation_num_workers"] or 1) * \ + (1 if unit == "episodes" else rollout) + num_ts_run = 0 # Default done-function returns True, whenever num episodes # have been completed. - if episodes_left_fn is None: + if duration_fn is None: - def episodes_left_fn(num_episodes_done): - return num_episodes - num_episodes_done + def duration_fn(num_units_done): + return duration - num_units_done - logger.info( - f"Evaluating current policy for {num_episodes} episodes.") + logger.info(f"Evaluating current policy for {duration} {unit}.") metrics = None # No evaluation worker set -> # Do evaluation using the local worker. Expect error due to the # local worker not having an env. if self.evaluation_workers is None: - try: - for _ in range(num_episodes): - self.workers.local_worker().sample() - metrics = collect_metrics(self.workers.local_worker()) - except ValueError as e: - if "RolloutWorker has no `input_reader` object" in \ - e.args[0]: - raise ValueError( - "Cannot evaluate w/o an evaluation worker set in " - "the Trainer or w/o an env on the local worker!\n" - "Try one of the following:\n1) Set " - "`evaluation_interval` >= 0 to force creating a " - "separate evaluation worker set.\n2) Set " - "`create_env_on_driver=True` to force the local " - "(non-eval) worker to have an environment to " - "evaluate on.") - else: - raise e + # If unit=episodes -> Run n times `sample()` (each sample + # produces exactly 1 episode). + # If unit=ts -> Run 1 `sample()` b/c the + # `rollout_fragment_length` is exactly the desired ts. + iters = duration if unit == "episodes" else 1 + for _ in range(iters): + num_ts_run += len(self.workers.local_worker().sample()) + metrics = collect_metrics(self.workers.local_worker()) # Evaluation worker set only has local worker. elif self.config["evaluation_num_workers"] == 0: - for _ in range(num_episodes): - self.evaluation_workers.local_worker().sample() + # If unit=episodes -> Run n times `sample()` (each sample + # produces exactly 1 episode). + # If unit=ts -> Run 1 `sample()` b/c the + # `rollout_fragment_length` is exactly the desired ts. + iters = duration if unit == "episodes" else 1 + for _ in range(iters): + num_ts_run += len( + self.evaluation_workers.local_worker().sample()) # Evaluation worker set has n remote workers. else: # How many episodes have we run (across all eval workers)? - num_episodes_done = 0 + num_units_done = 0 round_ = 0 while True: - episodes_left_to_do = episodes_left_fn(num_episodes_done) - if episodes_left_to_do <= 0: + units_left_to_do = duration_fn(num_units_done) + if units_left_to_do <= 0: break round_ += 1 batches = ray.get([ w.sample.remote() for i, w in enumerate( self.evaluation_workers.remote_workers()) - if i < episodes_left_to_do + if i * (1 if unit == "episodes" else rollout * + num_envs) < units_left_to_do ]) - # Per our config for the evaluation workers - # (`rollout_fragment_length=1` and - # `batch_mode=complete_episode`), we know that we'll have - # exactly one episode per returned batch. - num_episodes_done += len(batches) - logger.info( - f"Ran round {round_} of parallel evaluation " - f"({num_episodes_done}/{num_episodes} episodes done)") + # 1 episode per returned batch. + if unit == "episodes": + num_units_done += len(batches) + # n timesteps per returned batch. + else: + ts = sum(len(b) for b in batches) + num_ts_run += ts + num_units_done += ts + + logger.info(f"Ran round {round_} of parallel evaluation " + f"({num_units_done}/{duration} {unit} done)") + if metrics is None: metrics = collect_metrics( self.evaluation_workers.local_worker(), self.evaluation_workers.remote_workers()) + metrics["timesteps_this_iter"] = num_ts_run + + self.evaluation_metrics = metrics + return {"evaluation": metrics} @DeveloperAPI @@ -1684,6 +1779,7 @@ class Trainer(Trainable): policy_class: Type[Policy], config: TrainerConfigDict, num_workers: int, + local_worker: bool = True, ) -> WorkerSet: """Default factory method for a WorkerSet running under this Trainer. @@ -1703,6 +1799,9 @@ class Trainer(Trainable): config: The Trainer's config. num_workers: Number of remote rollout workers to create. 0 for local only. + local_worker: Whether to create a local (non @ray.remote) worker + in the returned set as well (default: True). If `num_workers` + is 0, always create a local worker. Returns: The created WorkerSet. @@ -1713,7 +1812,9 @@ class Trainer(Trainable): policy_class=policy_class, trainer_config=config, num_workers=num_workers, - logdir=self.logdir) + local_worker=local_worker, + logdir=self.logdir, + ) def _sync_filters_if_needed(self, workers: WorkerSet): if self.config.get("observation_filter", "NoFilter") != "NoFilter": @@ -1985,6 +2086,18 @@ class Trainer(Trainable): "Got {}".format(config["multiagent"]["count_steps_by"])) # Evaluation settings. + + # Deprecated setting: `evaluation_num_episodes`. + if config["evaluation_num_episodes"] != DEPRECATED_VALUE: + deprecation_warning( + old="evaluation_num_episodes", + new="`evaluation_duration` and `evaluation_duration_unit=" + "episodes`", + error=False) + config["evaluation_duration"] = config["evaluation_num_episodes"] + config["evaluation_duration_unit"] = "episodes" + config["evaluation_num_episodes"] = DEPRECATED_VALUE + # If `evaluation_num_workers` > 0, warn if `evaluation_interval` is # None (also set `evaluation_interval` to 1). if config["evaluation_num_workers"] > 0 and \ @@ -2008,18 +2121,18 @@ class Trainer(Trainable): "`evaluation_parallel_to_training` to False.") config["evaluation_parallel_to_training"] = False - # If `evaluation_num_episodes=auto`, error if + # If `evaluation_duration=auto`, error if # `evaluation_parallel_to_training=False`. - if config["evaluation_num_episodes"] == "auto": + if config["evaluation_duration"] == "auto": if not config["evaluation_parallel_to_training"]: raise ValueError( - "`evaluation_num_episodes=auto` not supported for " + "`evaluation_duration=auto` not supported for " "`evaluation_parallel_to_training=False`!") # Make sure, it's an int otherwise. - elif not isinstance(config["evaluation_num_episodes"], int): - raise ValueError( - "`evaluation_num_episodes` ({}) must be an int and " - ">0!".format(config["evaluation_num_episodes"])) + elif not isinstance(config["evaluation_duration"], int) or \ + config["evaluation_duration"] <= 0: + raise ValueError("`evaluation_duration` ({}) must be an int and " + ">0!".format(config["evaluation_duration"])) @ExperimentalAPI @staticmethod @@ -2284,7 +2397,7 @@ class Trainer(Trainable): def __repr__(self): return type(self).__name__ - @Deprecated(new="Trainer.evaluate()", error=False) + @Deprecated(new="Trainer.evaluate()", error=True) def _evaluate(self) -> dict: return self.evaluate() diff --git a/rllib/contrib/maddpg/maddpg.py b/rllib/contrib/maddpg/maddpg.py index 50b9936ef..b866fc1c4 100644 --- a/rllib/contrib/maddpg/maddpg.py +++ b/rllib/contrib/maddpg/maddpg.py @@ -37,7 +37,7 @@ DEFAULT_CONFIG = with_common_config({ # Evaluation interval "evaluation_interval": None, # Number of episodes to run per evaluation period. - "evaluation_num_episodes": 10, + "evaluation_duration": 10, # === Model === # Apply a state preprocessor with spec given by the "model" config option diff --git a/rllib/evaluate.py b/rllib/evaluate.py index 884f12648..f1d7c54c8 100755 --- a/rllib/evaluate.py +++ b/rllib/evaluate.py @@ -309,8 +309,8 @@ def run(args, parser): # Make sure we have evaluation workers. if not config.get("evaluation_num_workers"): config["evaluation_num_workers"] = config.get("num_workers", 0) - if not config.get("evaluation_num_episodes"): - config["evaluation_num_episodes"] = 1 + if not config.get("evaluation_duration"): + config["evaluation_duration"] = 1 # Hard-override this as it raises a warning by Trainer otherwise. # Makes no sense anyways, to have it set to None as we don't call # `Trainer.train()` here. @@ -401,7 +401,7 @@ def rollout(agent, saver.begin_rollout() eval_result = agent.evaluate()["evaluation"] # Increase timestep and episode counters. - eps = agent.config["evaluation_num_episodes"] + eps = agent.config["evaluation_duration"] episodes += eps steps += eps * eval_result["episode_len_mean"] # Print out results and continue. diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index b38c17509..4b6a74b6c 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -34,15 +34,18 @@ class WorkerSet: Where n may be 0. """ - def __init__(self, - *, - env_creator: Optional[Callable[[EnvContext], EnvType]] = None, - validate_env: Optional[Callable[[EnvType], None]] = None, - policy_class: Optional[Type[Policy]] = None, - trainer_config: Optional[TrainerConfigDict] = None, - num_workers: int = 0, - logdir: Optional[str] = None, - _setup: bool = True): + def __init__( + self, + *, + env_creator: Optional[Callable[[EnvContext], EnvType]] = None, + validate_env: Optional[Callable[[EnvType], None]] = None, + policy_class: Optional[Type[Policy]] = None, + trainer_config: Optional[TrainerConfigDict] = None, + num_workers: int = 0, + local_worker: bool = True, + logdir: Optional[str] = None, + _setup: bool = True, + ): """Initializes a WorkerSet instance. Args: @@ -55,6 +58,9 @@ class WorkerSet: trainer_config: Optional dict that extends the common config of the Trainer class. num_workers: Number of remote rollout workers to create. + local_worker: Whether to create a local (non @ray.remote) worker + in the returned set as well (default: True). If `num_workers` + is 0, always create a local worker. logdir: Optional logging directory for workers. _setup: Whether to setup workers. This is only for testing. """ @@ -69,18 +75,25 @@ class WorkerSet: self._logdir = logdir if _setup: + # Force a local worker if num_workers == 0 (no remote workers). + # Otherwise, this WorkerSet would be empty. + self._local_worker = None + if num_workers == 0: + local_worker = True + self._local_config = merge_dicts( trainer_config, {"tf_session_args": trainer_config["local_tf_session_args"]}) - # Create a number of remote workers. + # Create a number of @ray.remote workers. self._remote_workers = [] self.add_workers(num_workers) + # Create a local worker, if needed. # If num_workers > 0 and we don't have an env on the local worker, # get the observation- and action spaces for each policy from # the first remote worker (which does have an env). - if self._remote_workers and \ + if local_worker and self._remote_workers and \ not trainer_config.get("create_env_on_driver") and \ (not trainer_config.get("observation_space") or not trainer_config.get("action_space")): @@ -106,17 +119,17 @@ class WorkerSet: else: spaces = None - # Always create a local worker. - self._local_worker = self._make_worker( - cls=RolloutWorker, - env_creator=env_creator, - validate_env=validate_env, - policy_cls=self._policy_class, - worker_index=0, - num_workers=num_workers, - config=self._local_config, - spaces=spaces, - ) + if local_worker: + self._local_worker = self._make_worker( + cls=RolloutWorker, + env_creator=env_creator, + validate_env=validate_env, + policy_cls=self._policy_class, + worker_index=0, + num_workers=num_workers, + config=self._local_config, + spaces=spaces, + ) def local_worker(self) -> RolloutWorker: """Returns the local rollout worker.""" @@ -197,7 +210,9 @@ class WorkerSet: Returns: The list of return values of all calls to `func([worker])`. """ - local_result = [func(self.local_worker())] + local_result = [] + if self._local_worker: + local_result = [func(self.local_worker())] remote_results = ray.get( [w.apply.remote(func) for w in self.remote_workers()]) return local_result + remote_results @@ -219,8 +234,10 @@ class WorkerSet: The first entry in this list are the results of the local worker, followed by all remote workers' results. """ + local_result = [] # Local worker: Index=0. - local_result = [func(self.local_worker(), 0)] + if self._local_worker: + local_result = [func(self.local_worker(), 0)] # Remote workers: Index > 0. remote_results = ray.get([ w.apply.remote(func, i + 1) @@ -247,7 +264,9 @@ class WorkerSet: The local workers' results are first, followed by all remote workers' results """ - results = self.local_worker().foreach_policy(func) + results = [] + if self._local_worker: + results = self.local_worker().foreach_policy(func) ray_gets = [] for worker in self.remote_workers(): ray_gets.append( @@ -260,7 +279,10 @@ class WorkerSet: @DeveloperAPI def trainable_policies(self) -> List[PolicyID]: """Returns the list of trainable policy ids.""" - return self.local_worker().policies_to_train + if self._local_worker: + return self._local_worker.policies_to_train + else: + raise NotImplementedError @DeveloperAPI def foreach_trainable_policy( @@ -275,7 +297,9 @@ class WorkerSet: List[any]: The list of n return values of all `func([trainable policy], [ID])`-calls. """ - results = self.local_worker().foreach_trainable_policy(func) + results = [] + if self._local_worker: + results = self.local_worker().foreach_trainable_policy(func) ray_gets = [] for worker in self.remote_workers(): ray_gets.append( @@ -303,7 +327,9 @@ class WorkerSet: Returns: The list (workers) of lists (sub environments) of results. """ - local_results = [self.local_worker().foreach_env(func)] + local_results = [] + if self._local_worker: + local_results = [self.local_worker().foreach_env(func)] ray_gets = [] for worker in self.remote_workers(): ray_gets.append(worker.foreach_env.remote(func)) @@ -329,7 +355,11 @@ class WorkerSet: The list (1 item per workers) of lists (1 item per sub-environment) of results. """ - local_results = [self.local_worker().foreach_env_with_context(func)] + local_results = [] + if self._local_worker: + local_results = [ + self.local_worker().foreach_env_with_context(func) + ] ray_gets = [] for worker in self.remote_workers(): ray_gets.append(worker.foreach_env_with_context.remote(func)) diff --git a/rllib/examples/custom_eval.py b/rllib/examples/custom_eval.py index 12c191c34..db4e77732 100644 --- a/rllib/examples/custom_eval.py +++ b/rllib/examples/custom_eval.py @@ -181,7 +181,7 @@ if __name__ == "__main__": "evaluation_interval": 1, # Run 10 episodes each time evaluation runs. - "evaluation_num_episodes": 10, + "evaluation_duration": 10, # Override the env config for evaluation. "evaluation_config": { diff --git a/rllib/examples/custom_input_api.py b/rllib/examples/custom_input_api.py index 5ac560277..e7f71ee15 100644 --- a/rllib/examples/custom_input_api.py +++ b/rllib/examples/custom_input_api.py @@ -101,7 +101,7 @@ if __name__ == "__main__": "metrics_smoothing_episodes": 5, "evaluation_interval": 1, "evaluation_num_workers": 2, - "evaluation_num_episodes": 10, + "evaluation_duration": 10, "evaluation_parallel_to_training": True, "evaluation_config": { "input": "sampler", diff --git a/rllib/examples/env_rendering_and_recording.py b/rllib/examples/env_rendering_and_recording.py index 5ac1f08da..23bb9f961 100644 --- a/rllib/examples/env_rendering_and_recording.py +++ b/rllib/examples/env_rendering_and_recording.py @@ -110,7 +110,7 @@ if __name__ == "__main__": # Evaluate once per training iteration. "evaluation_interval": 1, # Run evaluation on (at least) two episodes - "evaluation_num_episodes": 2, + "evaluation_duration": 2, # ... using one evaluation worker (setting this to 0 will cause # evaluation to run on the local evaluation worker, blocking # training until evaluation is done). diff --git a/rllib/examples/offline_rl.py b/rllib/examples/offline_rl.py index 02d3a109a..5fc366d8d 100644 --- a/rllib/examples/offline_rl.py +++ b/rllib/examples/offline_rl.py @@ -69,7 +69,7 @@ if __name__ == "__main__": # Set up evaluation. config["evaluation_num_workers"] = 1 config["evaluation_interval"] = 1 - config["evaluation_num_episodes"] = 10 + config["evaluation_duration"] = 10 # This should be False b/c iterations are very long and this would # cause evaluation to lag one iter behind training. config["evaluation_parallel_to_training"] = False diff --git a/rllib/examples/parallel_evaluation_and_training.py b/rllib/examples/parallel_evaluation_and_training.py index 05e81ee3a..cdf75345a 100644 --- a/rllib/examples/parallel_evaluation_and_training.py +++ b/rllib/examples/parallel_evaluation_and_training.py @@ -7,11 +7,30 @@ from ray.rllib.utils.test_utils import check_learning_achieved parser = argparse.ArgumentParser() parser.add_argument( - "--evaluation-num-episodes", + "--evaluation-duration", type=lambda v: v if v == "auto" else int(v), default=13, - help="Number of evaluation episodes to run each iteration. " + help="Number of evaluation episodes/timesteps to run each iteration. " "If 'auto', will run as many as possible during train pass.") +parser.add_argument( + "--evaluation-duration-unit", + type=str, + default="episodes", + choices=["episodes", "timesteps"], + help="The unit in which to measure the duration (`episodes` or" + "`timesteps`).") +parser.add_argument( + "--evaluation-num-workers", + type=int, + default=2, + help="The number of evaluation workers to setup. " + "0 for a single local evaluation worker. Note that for values >0, no" + "local evaluation worker will be created (b/c not needed).") +parser.add_argument( + "--evaluation-interval", + type=int, + default=2, + help="Every how many train iterations should we run an evaluation loop?") parser.add_argument( "--run", @@ -50,26 +69,44 @@ parser.add_argument( help="Init Ray in local mode for easier debugging.") -class AssertNumEvalEpisodesCallback(DefaultCallbacks): +class AssertEvalCallback(DefaultCallbacks): def on_train_result(self, *, trainer, result, **kwargs): - # Make sure we always run exactly n evaluation episodes, + # Make sure we always run exactly the given evaluation duration, # no matter what the other settings are (such as # `evaluation_num_workers` or `evaluation_parallel_to_training`). if "evaluation" in result: hist_stats = result["evaluation"]["hist_stats"] - num_episodes_done = len(hist_stats["episode_lengths"]) - # Compare number of entries in episode_lengths (this is the - # number of episodes actually run) with desired number of - # episodes from the config. - if isinstance(trainer.config["evaluation_num_episodes"], int): - assert num_episodes_done == \ - trainer.config["evaluation_num_episodes"] + # We count in episodes. + if trainer.config["evaluation_duration_unit"] == "episodes": + num_episodes_done = len(hist_stats["episode_lengths"]) + # Compare number of entries in episode_lengths (this is the + # number of episodes actually run) with desired number of + # episodes from the config. + if isinstance(trainer.config["evaluation_duration"], int): + assert num_episodes_done == \ + trainer.config["evaluation_duration"] + # If auto-episodes: Expect at least as many episode as workers + # (each worker's `sample()` is at least called once). + else: + assert trainer.config["evaluation_duration"] == "auto" + assert num_episodes_done >= \ + trainer.config["evaluation_num_workers"] + print("Number of run evaluation episodes: " + f"{num_episodes_done} (ok)!") + # We count in timesteps. else: - assert trainer.config["evaluation_num_episodes"] == "auto" - assert num_episodes_done >= \ - trainer.config["evaluation_num_workers"] - print("Number of run evaluation episodes: " - f"{num_episodes_done} (ok)!") + num_timesteps_reported = result["evaluation"][ + "timesteps_this_iter"] + num_timesteps_wanted = trainer.config["evaluation_duration"] + if num_timesteps_wanted != "auto": + delta = num_timesteps_wanted - num_timesteps_reported + # Expect roughly the same (desired // num-eval-workers). + assert abs(delta) < 20, \ + (delta, num_timesteps_wanted, num_timesteps_reported) + print("Number of run evaluation timesteps: " + f"{num_timesteps_reported} (ok)!") + + print(f"R={result['evaluation']['episode_reward_mean']}") if __name__ == "__main__": @@ -93,21 +130,23 @@ if __name__ == "__main__": "evaluation_parallel_to_training": True, # Use two evaluation workers. Must be >0, otherwise, # evaluation will run on a local worker and block (no parallelism). - "evaluation_num_workers": 2, + "evaluation_num_workers": args.evaluation_num_workers, # Evaluate every other training iteration (together # with every other call to Trainer.train()). - "evaluation_interval": 2, - # Run for n episodes (properly distribute load amongst all eval - # workers). The longer it takes to evaluate, the more - # sense it makes to use `evaluation_parallel_to_training=True`. + "evaluation_interval": args.evaluation_interval, + # Run for n episodes/timesteps (properly distribute load amongst + # all eval workers). The longer it takes to evaluate, the more sense + # it makes to use `evaluation_parallel_to_training=True`. # Use "auto" to run evaluation for roughly as long as the training # step takes. - "evaluation_num_episodes": args.evaluation_num_episodes, + "evaluation_duration": args.evaluation_duration, + # "episodes" or "timesteps". + "evaluation_duration_unit": args.evaluation_duration_unit, # Use a custom callback that asserts that we are running the # configured exact number of episodes per evaluation OR - in auto # mode - run at least as many episodes as we have eval workers. - "callbacks": AssertNumEvalEpisodesCallback, + "callbacks": AssertEvalCallback, } stop = { diff --git a/rllib/tests/test_rllib_train_and_evaluate.py b/rllib/tests/test_rllib_train_and_evaluate.py index fb28257c1..2f1961714 100644 --- a/rllib/tests/test_rllib_train_and_evaluate.py +++ b/rllib/tests/test_rllib_train_and_evaluate.py @@ -194,7 +194,7 @@ def learn_test_multi_agent_plus_evaluate(algo): # Test rolling out n steps. result = os.popen( - "python {}/rollout.py --run={} " + "python {}/evaluate.py --run={} " "--steps=400 " "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format( rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1] diff --git a/rllib/tuned_examples/ddpg/pendulum-apex-ddpg.yaml b/rllib/tuned_examples/ddpg/pendulum-apex-ddpg.yaml index 1846ed4c7..8bad4a662 100644 --- a/rllib/tuned_examples/ddpg/pendulum-apex-ddpg.yaml +++ b/rllib/tuned_examples/ddpg/pendulum-apex-ddpg.yaml @@ -14,4 +14,4 @@ pendulum-apex-ddpg: target_network_update_freq: 50000 tau: 1.0 evaluation_interval: 5 - evaluation_num_episodes: 10 + evaluation_duration: 10