From c5d20849aea8bfc842e6bb1f66d066141065756e Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Wed, 15 Sep 2021 08:45:17 +0200
Subject: [PATCH] [RLlib] Rename `rllib rollout` into `rllib evaluate` (backward compatible) to match Trainer API. (#18467)

---
 rllib/BUILD                                   |  54 +-
 rllib/evaluate.py                             | 545 ++++++++++++++++++
 rllib/rollout.py                              | 532 +----------------
 rllib/scripts.py                              |  22 +-
 rllib/tests/test_nested_observation_spaces.py |   2 +-
 ...ut.py => test_rllib_train_and_evaluate.py} |  44 +-
 6 files changed, 618 insertions(+), 581 deletions(-)
 create mode 100755 rllib/evaluate.py
 rename rllib/tests/{test_rollout.py => test_rllib_train_and_evaluate.py} (89%)

diff --git a/rllib/BUILD b/rllib/BUILD
index f7cd291a2..fd0df7b94 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1771,57 +1771,57 @@ py_test(
     srcs = ["tests/test_reproducibility.py"]
 )
 
-# Test train/rollout scripts (w/o confirming rollout performance).
+# Test [train|evaluate].py scripts (w/o confirming evaluation performance).
 py_test(
-    name = "test_rollout_no_learning_1",
-    main = "tests/test_rollout.py",
+    name = "test_rllib_evaluate_1",
+    main = "tests/test_rllib_train_and_evaluate.py",
     tags = ["team:ml", "tests_dir", "tests_dir_R"],
     size = "large",
-    data = ["train.py", "rollout.py"],
-    srcs = ["tests/test_rollout.py"],
-    args = ["TestRolloutSimple1"]
+    data = ["train.py", "evaluate.py"],
+    srcs = ["tests/test_rllib_train_and_evaluate.py"],
+    args = ["TestEvaluate1"]
 )
 
 py_test(
-    name = "test_rollout_no_learning_2",
-    main = "tests/test_rollout.py",
+    name = "test_rllib_evaluate_2",
+    main = "tests/test_rllib_train_and_evaluate.py",
     tags = ["team:ml", "tests_dir", "tests_dir_R"],
     size = "large",
-    data = ["train.py", "rollout.py"],
-    srcs = ["tests/test_rollout.py"],
-    args = ["TestRolloutSimple2"]
+    data = ["train.py", "evaluate.py"],
+    srcs = ["tests/test_rllib_train_and_evaluate.py"],
+    args = ["TestEvaluate2"]
 )
 
 py_test(
-    name = "test_rollout_no_learning_3",
-    main = "tests/test_rollout.py",
+    name = "test_rllib_evaluate_3",
+    main = "tests/test_rllib_train_and_evaluate.py",
    tags = ["team:ml", "tests_dir", "tests_dir_R"],
     size = "large",
-    data = ["train.py", "rollout.py"],
-    srcs = ["tests/test_rollout.py"],
-    args = ["TestRolloutSimple3"]
+    data = ["train.py", "evaluate.py"],
+    srcs = ["tests/test_rllib_train_and_evaluate.py"],
+    args = ["TestEvaluate3"]
 )
 
 py_test(
-    name = "test_rollout_no_learning_4",
-    main = "tests/test_rollout.py",
+    name = "test_rllib_evaluate_4",
+    main = "tests/test_rllib_train_and_evaluate.py",
     tags = ["team:ml", "tests_dir", "tests_dir_R"],
     size = "large",
-    data = ["train.py", "rollout.py"],
-    srcs = ["tests/test_rollout.py"],
-    args = ["TestRolloutSimple4"]
+    data = ["train.py", "evaluate.py"],
+    srcs = ["tests/test_rllib_train_and_evaluate.py"],
+    args = ["TestEvaluate4"]
 )
 
-# Test train/rollout scripts (and confirm `rllib rollout` performance is same
+# Test [train|evaluate].py scripts (and confirm `rllib evaluate` performance is same
 # as the final one from the `rllib train` run).
 py_test(
-    name = "test_rollout_w_learning",
-    main = "tests/test_rollout.py",
+    name = "test_rllib_train_and_evaluate",
+    main = "tests/test_rllib_train_and_evaluate.py",
     tags = ["team:ml", "tests_dir", "tests_dir_R"],
     size = "large",
-    data = ["train.py", "rollout.py"],
-    srcs = ["tests/test_rollout.py"],
-    args = ["TestRolloutLearntPolicy"]
+    data = ["train.py", "evaluate.py"],
+    srcs = ["tests/test_rllib_train_and_evaluate.py"],
+    args = ["TestTrainAndEvaluate"]
 )
 
 py_test(
diff --git a/rllib/evaluate.py b/rllib/evaluate.py
new file mode 100755
index 000000000..a2e94865b
--- /dev/null
+++ b/rllib/evaluate.py
@@ -0,0 +1,545 @@
+#!/usr/bin/env python
+
+import argparse
+import collections
+import copy
+import gym
+from gym import wrappers as gym_wrappers
+import json
+import os
+from pathlib import Path
+import shelve
+
+import ray
+import ray.cloudpickle as cloudpickle
+from ray.rllib.agents.registry import get_trainer_class
+from ray.rllib.env import MultiAgentEnv
+from ray.rllib.env.base_env import _DUMMY_AGENT_ID
+from ray.rllib.env.env_context import EnvContext
+from ray.rllib.evaluation.worker_set import WorkerSet
+from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
+from ray.rllib.utils.deprecation import deprecation_warning
+from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
+from ray.tune.utils import merge_dicts
+from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR
+
+EXAMPLE_USAGE = """
+Example usage via RLlib CLI:
+    rllib evaluate /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
+    --env CartPole-v0 --steps 1000000 --out rollouts.pkl
+
+Example usage via executable:
+    ./evaluate.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
+    --env CartPole-v0 --steps 1000000 --out rollouts.pkl
+
+Example usage w/o checkpoint (for testing purposes):
+    ./evaluate.py --run PPO --env CartPole-v0 --episodes 500
+"""
+
+# Note: if you use any custom models or envs, register them here first, e.g.:
+#
+# from ray.rllib.examples.env.parametric_actions_cartpole import \
+#     ParametricActionsCartPole
+# from ray.rllib.examples.model.parametric_actions_model import \
+#     ParametricActionsModel
+# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
+# register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
+
+
+def create_parser(parser_creator=None):
+    parser_creator = parser_creator or argparse.ArgumentParser
+    parser = parser_creator(
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description="Roll out a reinforcement learning agent "
+        "given a checkpoint.",
+        epilog=EXAMPLE_USAGE)
+
+    parser.add_argument(
+        "checkpoint",
+        type=str,
+        nargs="?",
+        help="(Optional) checkpoint from which to roll out. "
+        "If none given, will use an initial (untrained) Trainer.")
+
+    required_named = parser.add_argument_group("required named arguments")
+    required_named.add_argument(
+        "--run",
+        type=str,
+        required=True,
+        help="The algorithm or model to train. This may refer to the name "
+        "of a built-in algorithm (e.g. RLlib's `DQN` or `PPO`), or a "
+        "user-defined trainable function or class registered in the "
+        "tune registry.")
+    required_named.add_argument(
+        "--env",
+        type=str,
+        help="The environment specifier to use. This could be an OpenAI Gym "
+        "specifier (e.g. `CartPole-v0`) or a full class-path (e.g. 
" + "`ray.rllib.examples.env.simple_corridor.SimpleCorridor`).") + parser.add_argument( + "--local-mode", + action="store_true", + help="Run ray in local mode for easier debugging.") + parser.add_argument( + "--render", + action="store_true", + help="Render the environment while evaluating.") + # Deprecated: Use --render, instead. + parser.add_argument( + "--no-render", + default=False, + action="store_const", + const=True, + help="Deprecated! Rendering is off by default now. " + "Use `--render` to enable.") + parser.add_argument( + "--video-dir", + type=str, + default=None, + help="Specifies the directory into which videos of all episode " + "rollouts will be stored.") + parser.add_argument( + "--steps", + default=10000, + help="Number of timesteps to roll out. Rollout will also stop if " + "`--episodes` limit is reached first. A value of 0 means no " + "limitation on the number of timesteps run.") + parser.add_argument( + "--episodes", + default=0, + help="Number of complete episodes to roll out. Rollout will also stop " + "if `--steps` (timesteps) limit is reached first. A value of 0 means " + "no limitation on the number of episodes run.") + parser.add_argument("--out", default=None, help="Output filename.") + parser.add_argument( + "--config", + default="{}", + type=json.loads, + help="Algorithm-specific configuration (e.g. env, hyperparams). " + "Gets merged with loaded configuration from checkpoint file and " + "`evaluation_config` settings therein.") + parser.add_argument( + "--save-info", + default=False, + action="store_true", + help="Save the info field generated by the step() method, " + "as well as the action, observations, rewards and done fields.") + parser.add_argument( + "--use-shelve", + default=False, + action="store_true", + help="Save rollouts into a python shelf file (will save each episode " + "as it is generated). An output filename must be set using --out.") + parser.add_argument( + "--track-progress", + default=False, + action="store_true", + help="Write progress to a temporary file (updated " + "after each episode). An output filename must be set using --out; " + "the progress file will live in the same folder.") + return parser + + +class RolloutSaver: + """Utility class for storing rollouts. + + Currently supports two behaviours: the original, which + simply dumps everything to a pickle file once complete, + and a mode which stores each rollout as an entry in a Python + shelf db file. The latter mode is more robust to memory problems + or crashes part-way through the rollout generation. Each rollout + is stored with a key based on the episode number (0-indexed), + and the number of episodes is stored with the key "num_episodes", + so to load the shelf file, use something like: + + with shelve.open('rollouts.pkl') as rollouts: + for episode_index in range(rollouts["num_episodes"]): + rollout = rollouts[str(episode_index)] + + If outfile is None, this class does nothing. 
+ """ + + def __init__(self, + outfile=None, + use_shelve=False, + write_update_file=False, + target_steps=None, + target_episodes=None, + save_info=False): + self._outfile = outfile + self._update_file = None + self._use_shelve = use_shelve + self._write_update_file = write_update_file + self._shelf = None + self._num_episodes = 0 + self._rollouts = [] + self._current_rollout = [] + self._total_steps = 0 + self._target_episodes = target_episodes + self._target_steps = target_steps + self._save_info = save_info + + def _get_tmp_progress_filename(self): + outpath = Path(self._outfile) + return outpath.parent / ("__progress_" + outpath.name) + + @property + def outfile(self): + return self._outfile + + def __enter__(self): + if self._outfile: + if self._use_shelve: + # Open a shelf file to store each rollout as they come in + self._shelf = shelve.open(self._outfile) + else: + # Original behaviour - keep all rollouts in memory and save + # them all at the end. + # But check we can actually write to the outfile before going + # through the effort of generating the rollouts: + try: + with open(self._outfile, "wb") as _: + pass + except IOError as x: + print("Can not open {} for writing - cancelling rollouts.". + format(self._outfile)) + raise x + if self._write_update_file: + # Open a file to track rollout progress: + self._update_file = self._get_tmp_progress_filename().open( + mode="w") + return self + + def __exit__(self, type, value, traceback): + if self._shelf: + # Close the shelf file, and store the number of episodes for ease + self._shelf["num_episodes"] = self._num_episodes + self._shelf.close() + elif self._outfile and not self._use_shelve: + # Dump everything as one big pickle: + cloudpickle.dump(self._rollouts, open(self._outfile, "wb")) + if self._update_file: + # Remove the temp progress file: + self._get_tmp_progress_filename().unlink() + self._update_file = None + + def _get_progress(self): + if self._target_episodes: + return "{} / {} episodes completed".format(self._num_episodes, + self._target_episodes) + elif self._target_steps: + return "{} / {} steps completed".format(self._total_steps, + self._target_steps) + else: + return "{} episodes completed".format(self._num_episodes) + + def begin_rollout(self): + self._current_rollout = [] + + def end_rollout(self): + if self._outfile: + if self._use_shelve: + # Save this episode as a new entry in the shelf database, + # using the episode number as the key. + self._shelf[str(self._num_episodes)] = self._current_rollout + else: + # Append this rollout to our list, to save laer. + self._rollouts.append(self._current_rollout) + self._num_episodes += 1 + if self._update_file: + self._update_file.seek(0) + self._update_file.write(self._get_progress() + "\n") + self._update_file.flush() + + def append_step(self, obs, action, next_obs, reward, done, info): + """Add a step to the current rollout, if we are saving them""" + if self._outfile: + if self._save_info: + self._current_rollout.append( + [obs, action, next_obs, reward, done, info]) + else: + self._current_rollout.append( + [obs, action, next_obs, reward, done]) + self._total_steps += 1 + + +def run(args, parser): + # Load configuration from checkpoint file. + config_path = "" + if args.checkpoint: + config_dir = os.path.dirname(args.checkpoint) + config_path = os.path.join(config_dir, "params.pkl") + # Try parent directory. + if not os.path.exists(config_path): + config_path = os.path.join(config_dir, "../params.pkl") + + # Load the config from pickled. 
+ if os.path.exists(config_path): + with open(config_path, "rb") as f: + config = cloudpickle.load(f) + # If no pkl file found, require command line `--config`. + else: + # If no config in given checkpoint -> Error. + if args.checkpoint: + raise ValueError( + "Could not find params.pkl in either the checkpoint dir or " + "its parent directory AND no `--config` given on command " + "line!") + + # Use default config for given agent. + _, config = get_trainer_class(args.run, return_config=True) + + # Make sure worker 0 has an Env. + config["create_env_on_driver"] = True + + # Merge with `evaluation_config` (first try from command line, then from + # pkl file). + evaluation_config = copy.deepcopy( + args.config.get("evaluation_config", config.get( + "evaluation_config", {}))) + config = merge_dicts(config, evaluation_config) + # Merge with command line `--config` settings (if not already the same + # anyways). + config = merge_dicts(config, args.config) + if not args.env: + if not config.get("env"): + parser.error("the following arguments are required: --env") + args.env = config.get("env") + + # Make sure we have evaluation workers. + if not config.get("evaluation_num_workers"): + config["evaluation_num_workers"] = config.get("num_workers", 0) + if not config.get("evaluation_num_episodes"): + config["evaluation_num_episodes"] = 1 + # Hard-override this as it raises a warning by Trainer otherwise. + # Makes no sense anyways, to have it set to None as we don't call + # `Trainer.train()` here. + config["evaluation_interval"] = 1 + + # Rendering and video recording settings. + if args.no_render: + deprecation_warning(old="--no-render", new="--render", error=False) + args.render = False + config["render_env"] = args.render + config["record_env"] = args.video_dir + + ray.init(local_mode=args.local_mode) + + # Create the Trainer from config. + cls = get_trainable_cls(args.run) + agent = cls(env=args.env, config=config) + + # Load state from checkpoint, if provided. + if args.checkpoint: + agent.restore(args.checkpoint) + + num_steps = int(args.steps) + num_episodes = int(args.episodes) + + # Determine the video output directory. + video_dir = None + # Allow user to specify a video output path. + if args.video_dir: + video_dir = os.path.expanduser(args.video_dir) + + # Do the actual rollout. + with RolloutSaver( + args.out, + args.use_shelve, + write_update_file=args.track_progress, + target_steps=num_steps, + target_episodes=num_episodes, + save_info=args.save_info) as saver: + rollout(agent, args.env, num_steps, num_episodes, saver, + args.no_render, video_dir) + agent.stop() + + +class DefaultMapping(collections.defaultdict): + """default_factory now takes as an argument the missing key.""" + + def __missing__(self, key): + self[key] = value = self.default_factory(key) + return value + + +def default_policy_agent_mapping(unused_agent_id): + return DEFAULT_POLICY_ID + + +def keep_going(steps, num_steps, episodes, num_episodes): + """Determine whether we've collected enough data""" + # If num_episodes is set, stop if limit reached. + if num_episodes and episodes >= num_episodes: + return False + # If num_steps is set, stop if limit reached. + elif num_steps and steps >= num_steps: + return False + # Otherwise, keep going. 
+ return True + + +def rollout(agent, + env_name, + num_steps, + num_episodes=0, + saver=None, + no_render=True, + video_dir=None): + policy_agent_mapping = default_policy_agent_mapping + + if saver is None: + saver = RolloutSaver() + + # Normal case: Agent was setup correctly with an evaluation WorkerSet, + # which we will now use to rollout. + if hasattr(agent, "evaluation_workers") and isinstance( + agent.evaluation_workers, WorkerSet): + steps = 0 + episodes = 0 + while keep_going(steps, num_steps, episodes, num_episodes): + saver.begin_rollout() + eval_result = agent.evaluate()["evaluation"] + # Increase timestep and episode counters. + eps = agent.config["evaluation_num_episodes"] + episodes += eps + steps += eps * eval_result["episode_len_mean"] + # Print out results and continue. + print("Episode #{}: reward: {}".format( + episodes, eval_result["episode_reward_mean"])) + saver.end_rollout() + return + + # Agent has no evaluation workers, but RolloutWorkers. + elif hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet): + env = agent.workers.local_worker().env + multiagent = isinstance(env, MultiAgentEnv) + if agent.workers.local_worker().multiagent: + policy_agent_mapping = agent.config["multiagent"][ + "policy_mapping_fn"] + policy_map = agent.workers.local_worker().policy_map + state_init = {p: m.get_initial_state() for p, m in policy_map.items()} + use_lstm = {p: len(s) > 0 for p, s in state_init.items()} + + # Agent has neither evaluation- nor rollout workers. + else: + from gym import envs + if envs.registry.env_specs.get(agent.config["env"]): + # if environment is gym environment, load from gym + env = gym.make(agent.config["env"]) + else: + # if environment registered ray environment, load from ray + env_creator = _global_registry.get(ENV_CREATOR, + agent.config["env"]) + env_context = EnvContext( + agent.config["env_config"] or {}, worker_index=0) + env = env_creator(env_context) + multiagent = False + try: + policy_map = {DEFAULT_POLICY_ID: agent.policy} + except AttributeError: + raise AttributeError( + "Agent ({}) does not have a `policy` property! This is needed " + "for performing (trained) agent rollouts.".format(agent)) + use_lstm = {DEFAULT_POLICY_ID: False} + + action_init = { + p: flatten_to_single_ndarray(m.action_space.sample()) + for p, m in policy_map.items() + } + + # If monitoring has been requested, manually wrap our environment with a + # gym monitor, which is set to record every episode. + if video_dir: + env = gym_wrappers.Monitor( + env=env, + directory=video_dir, + video_callable=lambda _: True, + force=True) + + steps = 0 + episodes = 0 + while keep_going(steps, num_steps, episodes, num_episodes): + mapping_cache = {} # in case policy_agent_mapping is stochastic + saver.begin_rollout() + obs = env.reset() + agent_states = DefaultMapping( + lambda agent_id: state_init[mapping_cache[agent_id]]) + prev_actions = DefaultMapping( + lambda agent_id: action_init[mapping_cache[agent_id]]) + prev_rewards = collections.defaultdict(lambda: 0.) 
+ done = False + reward_total = 0.0 + while not done and keep_going(steps, num_steps, episodes, + num_episodes): + multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs} + action_dict = {} + for agent_id, a_obs in multi_obs.items(): + if a_obs is not None: + policy_id = mapping_cache.setdefault( + agent_id, policy_agent_mapping(agent_id)) + p_use_lstm = use_lstm[policy_id] + if p_use_lstm: + a_action, p_state, _ = agent.compute_single_action( + a_obs, + state=agent_states[agent_id], + prev_action=prev_actions[agent_id], + prev_reward=prev_rewards[agent_id], + policy_id=policy_id) + agent_states[agent_id] = p_state + else: + a_action = agent.compute_single_action( + a_obs, + prev_action=prev_actions[agent_id], + prev_reward=prev_rewards[agent_id], + policy_id=policy_id) + a_action = flatten_to_single_ndarray(a_action) + action_dict[agent_id] = a_action + prev_actions[agent_id] = a_action + action = action_dict + + action = action if multiagent else action[_DUMMY_AGENT_ID] + next_obs, reward, done, info = env.step(action) + if multiagent: + for agent_id, r in reward.items(): + prev_rewards[agent_id] = r + else: + prev_rewards[_DUMMY_AGENT_ID] = reward + + if multiagent: + done = done["__all__"] + reward_total += sum( + r for r in reward.values() if r is not None) + else: + reward_total += reward + if not no_render: + env.render() + saver.append_step(obs, action, next_obs, reward, done, info) + steps += 1 + obs = next_obs + saver.end_rollout() + print("Episode #{}: reward: {}".format(episodes, reward_total)) + if done: + episodes += 1 + + +def main(): + parser = create_parser() + args = parser.parse_args() + + # --use_shelve w/o --out option. + if args.use_shelve and not args.out: + raise ValueError( + "If you set --use-shelve, you must provide an output file via " + "--out as well!") + # --track-progress w/o --out option. 
+ if args.track_progress and not args.out: + raise ValueError( + "If you set --track-progress, you must provide an output file via " + "--out as well!") + + run(args, parser) + + +if __name__ == "__main__": + main() diff --git a/rllib/rollout.py b/rllib/rollout.py index ba2ee5fde..8b48e6808 100755 --- a/rllib/rollout.py +++ b/rllib/rollout.py @@ -1,529 +1,15 @@ #!/usr/bin/env python -import argparse -import collections -import copy -import gym -from gym import wrappers as gym_wrappers -import json -import os -from pathlib import Path -import shelve +from ray.rllib import evaluate +from ray.rllib.evaluate import rollout, RolloutSaver, run +from ray.rllib.utils.deprecation import deprecation_warning -import ray -import ray.cloudpickle as cloudpickle -from ray.rllib.agents.registry import get_trainer_class -from ray.rllib.env import MultiAgentEnv -from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.env.env_context import EnvContext -from ray.rllib.evaluation.worker_set import WorkerSet -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray -from ray.tune.utils import merge_dicts -from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR - -EXAMPLE_USAGE = """ -Example usage via RLlib CLI: - rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN - --env CartPole-v0 --steps 1000000 --out rollouts.pkl - -Example usage via executable: - ./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN - --env CartPole-v0 --steps 1000000 --out rollouts.pkl - -Example usage w/o checkpoint (for testing purposes): - ./rollout.py --run PPO --env CartPole-v0 --episodes 500 -""" - -# Note: if you use any custom models or envs, register them here first, e.g.: -# -# from ray.rllib.examples.env.parametric_actions_cartpole import \ -# ParametricActionsCartPole -# from ray.rllib.examples.model.parametric_actions_model import \ -# ParametricActionsModel -# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel) -# register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10)) - - -def create_parser(parser_creator=None): - parser_creator = parser_creator or argparse.ArgumentParser - parser = parser_creator( - formatter_class=argparse.RawDescriptionHelpFormatter, - description="Roll out a reinforcement learning agent " - "given a checkpoint.", - epilog=EXAMPLE_USAGE) - - parser.add_argument( - "checkpoint", - type=str, - nargs="?", - help="(Optional) checkpoint from which to roll out. " - "If none given, will use an initial (untrained) Trainer.") - - required_named = parser.add_argument_group("required named arguments") - required_named.add_argument( - "--run", - type=str, - required=True, - help="The algorithm or model to train. This may refer to the name " - "of a built-on algorithm (e.g. RLLib's `DQN` or `PPO`), or a " - "user-defined trainable function or class registered in the " - "tune registry.") - required_named.add_argument( - "--env", - type=str, - help="The environment specifier to use. This could be an openAI gym " - "specifier (e.g. `CartPole-v0`) or a full class-path (e.g. 
" - "`ray.rllib.examples.env.simple_corridor.SimpleCorridor`).") - parser.add_argument( - "--local-mode", - action="store_true", - help="Run ray in local mode for easier debugging.") - parser.add_argument( - "--no-render", - default=False, - action="store_const", - const=True, - help="Suppress rendering of the environment.") - parser.add_argument( - "--video-dir", - type=str, - default=None, - help="Specifies the directory into which videos of all episode " - "rollouts will be stored.") - parser.add_argument( - "--steps", - default=10000, - help="Number of timesteps to roll out. Rollout will also stop if " - "`--episodes` limit is reached first. A value of 0 means no " - "limitation on the number of timesteps run.") - parser.add_argument( - "--episodes", - default=0, - help="Number of complete episodes to roll out. Rollout will also stop " - "if `--steps` (timesteps) limit is reached first. A value of 0 means " - "no limitation on the number of episodes run.") - parser.add_argument("--out", default=None, help="Output filename.") - parser.add_argument( - "--config", - default="{}", - type=json.loads, - help="Algorithm-specific configuration (e.g. env, hyperparams). " - "Gets merged with loaded configuration from checkpoint file and " - "`evaluation_config` settings therein.") - parser.add_argument( - "--save-info", - default=False, - action="store_true", - help="Save the info field generated by the step() method, " - "as well as the action, observations, rewards and done fields.") - parser.add_argument( - "--use-shelve", - default=False, - action="store_true", - help="Save rollouts into a python shelf file (will save each episode " - "as it is generated). An output filename must be set using --out.") - parser.add_argument( - "--track-progress", - default=False, - action="store_true", - help="Write progress to a temporary file (updated " - "after each episode). An output filename must be set using --out; " - "the progress file will live in the same folder.") - return parser - - -class RolloutSaver: - """Utility class for storing rollouts. - - Currently supports two behaviours: the original, which - simply dumps everything to a pickle file once complete, - and a mode which stores each rollout as an entry in a Python - shelf db file. The latter mode is more robust to memory problems - or crashes part-way through the rollout generation. Each rollout - is stored with a key based on the episode number (0-indexed), - and the number of episodes is stored with the key "num_episodes", - so to load the shelf file, use something like: - - with shelve.open('rollouts.pkl') as rollouts: - for episode_index in range(rollouts["num_episodes"]): - rollout = rollouts[str(episode_index)] - - If outfile is None, this class does nothing. 
- """ - - def __init__(self, - outfile=None, - use_shelve=False, - write_update_file=False, - target_steps=None, - target_episodes=None, - save_info=False): - self._outfile = outfile - self._update_file = None - self._use_shelve = use_shelve - self._write_update_file = write_update_file - self._shelf = None - self._num_episodes = 0 - self._rollouts = [] - self._current_rollout = [] - self._total_steps = 0 - self._target_episodes = target_episodes - self._target_steps = target_steps - self._save_info = save_info - - def _get_tmp_progress_filename(self): - outpath = Path(self._outfile) - return outpath.parent / ("__progress_" + outpath.name) - - @property - def outfile(self): - return self._outfile - - def __enter__(self): - if self._outfile: - if self._use_shelve: - # Open a shelf file to store each rollout as they come in - self._shelf = shelve.open(self._outfile) - else: - # Original behaviour - keep all rollouts in memory and save - # them all at the end. - # But check we can actually write to the outfile before going - # through the effort of generating the rollouts: - try: - with open(self._outfile, "wb") as _: - pass - except IOError as x: - print("Can not open {} for writing - cancelling rollouts.". - format(self._outfile)) - raise x - if self._write_update_file: - # Open a file to track rollout progress: - self._update_file = self._get_tmp_progress_filename().open( - mode="w") - return self - - def __exit__(self, type, value, traceback): - if self._shelf: - # Close the shelf file, and store the number of episodes for ease - self._shelf["num_episodes"] = self._num_episodes - self._shelf.close() - elif self._outfile and not self._use_shelve: - # Dump everything as one big pickle: - cloudpickle.dump(self._rollouts, open(self._outfile, "wb")) - if self._update_file: - # Remove the temp progress file: - self._get_tmp_progress_filename().unlink() - self._update_file = None - - def _get_progress(self): - if self._target_episodes: - return "{} / {} episodes completed".format(self._num_episodes, - self._target_episodes) - elif self._target_steps: - return "{} / {} steps completed".format(self._total_steps, - self._target_steps) - else: - return "{} episodes completed".format(self._num_episodes) - - def begin_rollout(self): - self._current_rollout = [] - - def end_rollout(self): - if self._outfile: - if self._use_shelve: - # Save this episode as a new entry in the shelf database, - # using the episode number as the key. - self._shelf[str(self._num_episodes)] = self._current_rollout - else: - # Append this rollout to our list, to save laer. - self._rollouts.append(self._current_rollout) - self._num_episodes += 1 - if self._update_file: - self._update_file.seek(0) - self._update_file.write(self._get_progress() + "\n") - self._update_file.flush() - - def append_step(self, obs, action, next_obs, reward, done, info): - """Add a step to the current rollout, if we are saving them""" - if self._outfile: - if self._save_info: - self._current_rollout.append( - [obs, action, next_obs, reward, done, info]) - else: - self._current_rollout.append( - [obs, action, next_obs, reward, done]) - self._total_steps += 1 - - -def run(args, parser): - # Load configuration from checkpoint file. - config_path = "" - if args.checkpoint: - config_dir = os.path.dirname(args.checkpoint) - config_path = os.path.join(config_dir, "params.pkl") - # Try parent directory. - if not os.path.exists(config_path): - config_path = os.path.join(config_dir, "../params.pkl") - - # Load the config from pickled. 
- if os.path.exists(config_path): - with open(config_path, "rb") as f: - config = cloudpickle.load(f) - # If no pkl file found, require command line `--config`. - else: - # If no config in given checkpoint -> Error. - if args.checkpoint: - raise ValueError( - "Could not find params.pkl in either the checkpoint dir or " - "its parent directory AND no `--config` given on command " - "line!") - - # Use default config for given agent. - _, config = get_trainer_class(args.run, return_config=True) - - # Make sure worker 0 has an Env. - config["create_env_on_driver"] = True - - # Merge with `evaluation_config` (first try from command line, then from - # pkl file). - evaluation_config = copy.deepcopy( - args.config.get("evaluation_config", config.get( - "evaluation_config", {}))) - config = merge_dicts(config, evaluation_config) - # Merge with command line `--config` settings (if not already the same - # anyways). - config = merge_dicts(config, args.config) - if not args.env: - if not config.get("env"): - parser.error("the following arguments are required: --env") - args.env = config.get("env") - - # Make sure we have evaluation workers. - if not config.get("evaluation_num_workers"): - config["evaluation_num_workers"] = config.get("num_workers", 0) - if not config.get("evaluation_num_episodes"): - config["evaluation_num_episodes"] = 1 - config["render_env"] = not args.no_render - config["record_env"] = args.video_dir - - ray.init(local_mode=args.local_mode) - - # Create the Trainer from config. - cls = get_trainable_cls(args.run) - agent = cls(env=args.env, config=config) - - # Load state from checkpoint, if provided. - if args.checkpoint: - agent.restore(args.checkpoint) - - num_steps = int(args.steps) - num_episodes = int(args.episodes) - - # Determine the video output directory. - video_dir = None - # Allow user to specify a video output path. - if args.video_dir: - video_dir = os.path.expanduser(args.video_dir) - - # Do the actual rollout. - with RolloutSaver( - args.out, - args.use_shelve, - write_update_file=args.track_progress, - target_steps=num_steps, - target_episodes=num_episodes, - save_info=args.save_info) as saver: - rollout(agent, args.env, num_steps, num_episodes, saver, - args.no_render, video_dir) - agent.stop() - - -class DefaultMapping(collections.defaultdict): - """default_factory now takes as an argument the missing key.""" - - def __missing__(self, key): - self[key] = value = self.default_factory(key) - return value - - -def default_policy_agent_mapping(unused_agent_id): - return DEFAULT_POLICY_ID - - -def keep_going(steps, num_steps, episodes, num_episodes): - """Determine whether we've collected enough data""" - # If num_episodes is set, stop if limit reached. - if num_episodes and episodes >= num_episodes: - return False - # If num_steps is set, stop if limit reached. - elif num_steps and steps >= num_steps: - return False - # Otherwise, keep going. - return True - - -def rollout(agent, - env_name, - num_steps, - num_episodes=0, - saver=None, - no_render=True, - video_dir=None): - policy_agent_mapping = default_policy_agent_mapping - - if saver is None: - saver = RolloutSaver() - - # Normal case: Agent was setup correctly with an evaluation WorkerSet, - # which we will now use to rollout. 
- if hasattr(agent, "evaluation_workers") and isinstance( - agent.evaluation_workers, WorkerSet): - steps = 0 - episodes = 0 - while keep_going(steps, num_steps, episodes, num_episodes): - saver.begin_rollout() - eval_result = agent.evaluate()["evaluation"] - # Increase timestep and episode counters. - eps = agent.config["evaluation_num_episodes"] - episodes += eps - steps += eps * eval_result["episode_len_mean"] - # Print out results and continue. - print("Episode #{}: reward: {}".format( - episodes, eval_result["episode_reward_mean"])) - saver.end_rollout() - return - - # Agent has no evaluation workers, but RolloutWorkers. - elif hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet): - env = agent.workers.local_worker().env - multiagent = isinstance(env, MultiAgentEnv) - if agent.workers.local_worker().multiagent: - policy_agent_mapping = agent.config["multiagent"][ - "policy_mapping_fn"] - policy_map = agent.workers.local_worker().policy_map - state_init = {p: m.get_initial_state() for p, m in policy_map.items()} - use_lstm = {p: len(s) > 0 for p, s in state_init.items()} - - # Agent has neither evaluation- nor rollout workers. - else: - from gym import envs - if envs.registry.env_specs.get(agent.config["env"]): - # if environment is gym environment, load from gym - env = gym.make(agent.config["env"]) - else: - # if environment registered ray environment, load from ray - env_creator = _global_registry.get(ENV_CREATOR, - agent.config["env"]) - env_context = EnvContext( - agent.config["env_config"] or {}, worker_index=0) - env = env_creator(env_context) - multiagent = False - try: - policy_map = {DEFAULT_POLICY_ID: agent.policy} - except AttributeError: - raise AttributeError( - "Agent ({}) does not have a `policy` property! This is needed " - "for performing (trained) agent rollouts.".format(agent)) - use_lstm = {DEFAULT_POLICY_ID: False} - - action_init = { - p: flatten_to_single_ndarray(m.action_space.sample()) - for p, m in policy_map.items() - } - - # If monitoring has been requested, manually wrap our environment with a - # gym monitor, which is set to record every episode. - if video_dir: - env = gym_wrappers.Monitor( - env=env, - directory=video_dir, - video_callable=lambda _: True, - force=True) - - steps = 0 - episodes = 0 - while keep_going(steps, num_steps, episodes, num_episodes): - mapping_cache = {} # in case policy_agent_mapping is stochastic - saver.begin_rollout() - obs = env.reset() - agent_states = DefaultMapping( - lambda agent_id: state_init[mapping_cache[agent_id]]) - prev_actions = DefaultMapping( - lambda agent_id: action_init[mapping_cache[agent_id]]) - prev_rewards = collections.defaultdict(lambda: 0.) 
- done = False - reward_total = 0.0 - while not done and keep_going(steps, num_steps, episodes, - num_episodes): - multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs} - action_dict = {} - for agent_id, a_obs in multi_obs.items(): - if a_obs is not None: - policy_id = mapping_cache.setdefault( - agent_id, policy_agent_mapping(agent_id)) - p_use_lstm = use_lstm[policy_id] - if p_use_lstm: - a_action, p_state, _ = agent.compute_single_action( - a_obs, - state=agent_states[agent_id], - prev_action=prev_actions[agent_id], - prev_reward=prev_rewards[agent_id], - policy_id=policy_id) - agent_states[agent_id] = p_state - else: - a_action = agent.compute_single_action( - a_obs, - prev_action=prev_actions[agent_id], - prev_reward=prev_rewards[agent_id], - policy_id=policy_id) - a_action = flatten_to_single_ndarray(a_action) - action_dict[agent_id] = a_action - prev_actions[agent_id] = a_action - action = action_dict - - action = action if multiagent else action[_DUMMY_AGENT_ID] - next_obs, reward, done, info = env.step(action) - if multiagent: - for agent_id, r in reward.items(): - prev_rewards[agent_id] = r - else: - prev_rewards[_DUMMY_AGENT_ID] = reward - - if multiagent: - done = done["__all__"] - reward_total += sum( - r for r in reward.values() if r is not None) - else: - reward_total += reward - if not no_render: - env.render() - saver.append_step(obs, action, next_obs, reward, done, info) - steps += 1 - obs = next_obs - saver.end_rollout() - print("Episode #{}: reward: {}".format(episodes, reward_total)) - if done: - episodes += 1 - - -def main(): - parser = create_parser() - args = parser.parse_args() - - # --use_shelve w/o --out option. - if args.use_shelve and not args.out: - raise ValueError( - "If you set --use-shelve, you must provide an output file via " - "--out as well!") - # --track-progress w/o --out option. 
-    if args.track_progress and not args.out:
-        raise ValueError(
-            "If you set --track-progress, you must provide an output file via "
-            "--out as well!")
-
-    run(args, parser)
+deprecation_warning(old="rllib rollout", new="rllib evaluate", error=False)
+# For backward compatibility
+rollout = rollout
+RolloutSaver = RolloutSaver
+run = run
 
 if __name__ == "__main__":
-    main()
+    evaluate.main()
diff --git a/rllib/scripts.py b/rllib/scripts.py
index 5f7b5811e..80314ae16 100644
--- a/rllib/scripts.py
+++ b/rllib/scripts.py
@@ -2,37 +2,43 @@
 import argparse
 
-from ray.rllib import train
-from ray.rllib import rollout
+from ray.rllib import evaluate, train
+from ray.rllib.utils.deprecation import deprecation_warning
 
 EXAMPLE_USAGE = """
 Example usage for training:
     rllib train --run DQN --env CartPole-v0
 
-Example usage for rollout:
-    rllib rollout /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
+Example usage for evaluate (aka: "rollout"):
+    rllib evaluate /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
 """
 
 
 def cli():
     parser = argparse.ArgumentParser(
-        description="Train or Run an RLlib Trainer.",
+        description="Train or evaluate an RLlib Trainer.",
         formatter_class=argparse.RawDescriptionHelpFormatter,
         epilog=EXAMPLE_USAGE)
     subcommand_group = parser.add_subparsers(
-        help="Commands to train or run an RLlib agent.", dest="command")
+        help="Commands to train or evaluate an RLlib agent.", dest="command")
 
     # see _SubParsersAction.add_parser in
     # https://github.com/python/cpython/blob/master/Lib/argparse.py
     train_parser = train.create_parser(
         lambda **kwargs: subcommand_group.add_parser("train", **kwargs))
-    rollout_parser = rollout.create_parser(
+    evaluate_parser = evaluate.create_parser(
+        lambda **kwargs: subcommand_group.add_parser("evaluate", **kwargs))
+    rollout_parser = evaluate.create_parser(
         lambda **kwargs: subcommand_group.add_parser("rollout", **kwargs))
 
     options = parser.parse_args()
 
     if options.command == "train":
         train.run(options, train_parser)
+    elif options.command == "evaluate":
+        evaluate.run(options, evaluate_parser)
     elif options.command == "rollout":
-        rollout.run(options, rollout_parser)
+        deprecation_warning(
+            old="rllib rollout", new="rllib evaluate", error=False)
+        evaluate.run(options, rollout_parser)
     else:
         parser.print_help()
diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py
index abe03e885..cf27c6cc7 100644
--- a/rllib/tests/test_nested_observation_spaces.py
+++ b/rllib/tests/test_nested_observation_spaces.py
@@ -15,7 +15,7 @@ from ray.rllib.models import ModelCatalog
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.rollout import rollout
+from ray.rllib.evaluate import rollout
 from ray.rllib.tests.test_external_env import SimpleServing
 from ray.tune.registry import register_env
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
diff --git a/rllib/tests/test_rollout.py b/rllib/tests/test_rllib_train_and_evaluate.py
similarity index 89%
rename from rllib/tests/test_rollout.py
rename to rllib/tests/test_rllib_train_and_evaluate.py
index f0e7609bb..42eb11028 100644
--- a/rllib/tests/test_rollout.py
+++ b/rllib/tests/test_rllib_train_and_evaluate.py
@@ -10,7 +10,7 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.utils.test_utils import framework_iterator
 
 
-def rollout_test(algo, env="CartPole-v0",
test_episode_rollout=False): +def evaluate_test(algo, env="CartPole-v0", test_episode_rollout=False): extra_config = "" if algo == "ARS": extra_config = ",\"train_batch_size\": 10, \"noise_size\": 250000" @@ -46,27 +46,27 @@ def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False): print("Checkpoint path {} (exists)".format(checkpoint_path)) # Test rolling out n steps. - os.popen("python {}/rollout.py --run={} \"{}\" --steps=10 " + os.popen("python {}/evaluate.py --run={} \"{}\" --steps=10 " "--out=\"{}/rollouts_10steps.pkl\" --no-render".format( rllib_dir, algo, checkpoint_path, tmp_dir)).read() if not os.path.exists(tmp_dir + "/rollouts_10steps.pkl"): sys.exit(1) - print("rollout output (10 steps) exists!") + print("evaluate output (10 steps) exists!") # Test rolling out 1 episode. if test_episode_rollout: - os.popen("python {}/rollout.py --run={} \"{}\" --episodes=1 " + os.popen("python {}/evaluate.py --run={} \"{}\" --episodes=1 " "--out=\"{}/rollouts_1episode.pkl\" --no-render".format( rllib_dir, algo, checkpoint_path, tmp_dir)).read() if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"): sys.exit(1) - print("rollout output (1 ep) exists!") + print("evaluate output (1 ep) exists!") # Cleanup. os.popen("rm -rf \"{}\"".format(tmp_dir)).read() -def learn_test_plus_rollout(algo, env="CartPole-v0"): +def learn_test_plus_evaluate(algo, env="CartPole-v0"): for fw in framework_iterator(frameworks=("tf", "torch")): fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw) @@ -108,7 +108,7 @@ def learn_test_plus_rollout(algo, env="CartPole-v0"): # Test rolling out n steps. result = os.popen( - "python {}/rollout.py --run={} " + "python {}/evaluate.py --run={} " "--steps=400 " "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format( rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1] @@ -131,7 +131,7 @@ def learn_test_plus_rollout(algo, env="CartPole-v0"): os.popen("rm -rf \"{}\"".format(tmp_dir)).read() -def learn_test_multi_agent_plus_rollout(algo): +def learn_test_multi_agent_plus_evaluate(algo): for fw in framework_iterator(frameworks=("tf", "torch")): tmp_dir = os.popen("mktemp -d").read()[:-1] if not os.path.exists(tmp_dir): @@ -217,41 +217,41 @@ def learn_test_multi_agent_plus_rollout(algo): os.popen("rm -rf \"{}\"".format(tmp_dir)).read() -class TestRolloutSimple1(unittest.TestCase): +class TestEvaluate1(unittest.TestCase): def test_a3c(self): - rollout_test("A3C") + evaluate_test("A3C") def test_ddpg(self): - rollout_test("DDPG", env="Pendulum-v0") + evaluate_test("DDPG", env="Pendulum-v0") -class TestRolloutSimple2(unittest.TestCase): +class TestEvaluate2(unittest.TestCase): def test_dqn(self): - rollout_test("DQN") + evaluate_test("DQN") def test_es(self): - rollout_test("ES") + evaluate_test("ES") -class TestRolloutSimple3(unittest.TestCase): +class TestEvaluate3(unittest.TestCase): def test_impala(self): - rollout_test("IMPALA", env="CartPole-v0") + evaluate_test("IMPALA", env="CartPole-v0") def test_ppo(self): - rollout_test("PPO", env="CartPole-v0", test_episode_rollout=True) + evaluate_test("PPO", env="CartPole-v0", test_episode_rollout=True) -class TestRolloutSimple4(unittest.TestCase): +class TestEvaluate4(unittest.TestCase): def test_sac(self): - rollout_test("SAC", env="Pendulum-v0") + evaluate_test("SAC", env="Pendulum-v0") -class TestRolloutLearntPolicy(unittest.TestCase): +class TestTrainAndEvaluate(unittest.TestCase): def test_ppo_train_then_rollout(self): - learn_test_plus_rollout("PPO") + learn_test_plus_evaluate("PPO") def 
test_ppo_multi_agent_train_then_rollout(self): - learn_test_multi_agent_plus_rollout("PPO") + learn_test_multi_agent_plus_evaluate("PPO") if __name__ == "__main__":
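
For readers trying out this change, here is a minimal usage sketch (not part of the patch itself). It assumes Ray at roughly the version of this PR (1.x with `ray[rllib]` and gym installed) and simply rolls out an untrained PPO Trainer for two episodes via the renamed module; the old `ray.rllib.rollout` import path keeps working but emits the deprecation warning added above.

    # Sketch only: module and function names follow this patch; config values are arbitrary.
    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.evaluate import rollout, RolloutSaver  # new location
    # from ray.rllib.rollout import rollout               # old path, still importable

    ray.init(local_mode=True)
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})

    # Roll out two complete episodes; a RolloutSaver without an outfile keeps nothing on disk.
    with RolloutSaver() as saver:
        rollout(trainer, "CartPole-v0", num_steps=0, num_episodes=2, saver=saver)

    trainer.stop()
    ray.shutdown()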