From 715ee8dfc95ab44909f800e8b917e7e44082a6ec Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Thu, 3 Sep 2020 17:27:05 +0200
Subject: [PATCH] [RLlib] Issue 10469: Callbacks should receive env idx ... (#10477)

---
 rllib/agents/callbacks.py                    | 26 ++++++++++------
 rllib/evaluation/sampler.py                  | 31 +++++++++++++++----
 rllib/evaluation/worker_set.py               |  4 +--
 .../examples/custom_metrics_and_callbacks.py | 27 +++++++++-------
 rllib/utils/exploration/curiosity.py         |  2 --
 src/ray/core_worker/lib/java/jni_utils.h     |  2 +-
 6 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/rllib/agents/callbacks.py b/rllib/agents/callbacks.py
index 53921a0b1..4621adef2 100644
--- a/rllib/agents/callbacks.py
+++ b/rllib/agents/callbacks.py
@@ -30,9 +30,9 @@ class DefaultCallbacks:
                 "a class extending rllib.agents.callbacks.DefaultCallbacks")
         self.legacy_callbacks = legacy_callbacks_dict or {}
 
-    def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv,
+    def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode, **kwargs):
+                         episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Callback run on the rollout worker before each episode starts.
 
         Args:
@@ -45,6 +45,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env that the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
         """
 
@@ -55,8 +57,8 @@ class DefaultCallbacks:
                 "episode": episode,
             })
 
-    def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv,
-                        episode: MultiAgentEpisode, **kwargs):
+    def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                        episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Runs on each episode step.
 
         Args:
@@ -67,6 +69,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env that the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
         """
 
@@ -76,9 +80,9 @@ class DefaultCallbacks:
                 "episode": episode
             })
 
-    def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv,
+    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy],
-                       episode: MultiAgentEpisode, **kwargs):
+                       episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Runs when an episode is done.
 
         Args:
@@ -91,6 +95,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env that the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
""" @@ -102,7 +108,7 @@ class DefaultCallbacks: }) def on_postprocess_trajectory( - self, worker: "RolloutWorker", episode: MultiAgentEpisode, + self, *, worker: "RolloutWorker", episode: MultiAgentEpisode, agent_id: AgentID, policy_id: PolicyID, policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[AgentID, SampleBatch], **kwargs): @@ -136,9 +142,9 @@ class DefaultCallbacks: "all_pre_batches": original_batches, }) - def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch, + def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch, **kwargs): - """Called at the end RolloutWorker.sample(). + """Called at the end of RolloutWorker.sample(). Args: worker (RolloutWorker): Reference to the current rollout worker. @@ -153,7 +159,7 @@ class DefaultCallbacks: "samples": samples, }) - def on_train_result(self, trainer, result: dict, **kwargs): + def on_train_result(self, *, trainer, result: dict, **kwargs): """Called at the end of Trainable.train(). Args: diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index eba66e450..892aecfc4 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -49,6 +49,15 @@ PolicyEvalData = namedtuple("PolicyEvalData", [ StateBatch = List[List[Any]] +class NewEpisodeDefaultDict(defaultdict): + def __missing__(self, env_index): + if self.default_factory is None: + raise KeyError(env_index) + else: + ret = self[env_index] = self.default_factory(env_index) + return ret + + class _PerfStats: """Sampler perf stats that will be included in rollout metrics.""" @@ -505,7 +514,7 @@ def _env_runner( return MultiAgentSampleBatchBuilder(policies, clip_rewards, callbacks) - def new_episode(): + def new_episode(env_index): episode = MultiAgentEpisode(policies, policy_mapping_fn, get_batch_builder, extra_batch_callback) # Call each policy's Exploration.on_episode_start method. @@ -521,10 +530,13 @@ def _env_runner( worker=worker, base_env=base_env, policies=policies, - episode=episode) + episode=episode, + env_index=env_index, + ) return episode - active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode) + active_episodes: Dict[str, MultiAgentEpisode] = \ + NewEpisodeDefaultDict(new_episode) eval_results = None while True: @@ -830,7 +842,10 @@ def _process_observations( # Invoke the step callback after the step is logged to the episode callbacks.on_episode_step( - worker=worker, base_env=base_env, episode=episode) + worker=worker, + base_env=base_env, + episode=episode, + env_index=env_id) # Cut the batch if ... 
         # - all-agents-done and not packing multiple episodes into one
@@ -869,7 +884,9 @@
                 worker=worker,
                 base_env=base_env,
                 policies=policies,
-                episode=episode)
+                episode=episode,
+                env_index=env_id,
+            )
             if hit_horizon and soft_horizon:
                 episode.soft_reset()
                 resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
@@ -1120,7 +1137,9 @@ def _process_observations_w_trajectory_view_api(
                 worker=worker,
                 base_env=base_env,
                 policies=policies,
-                episode=episode)
+                episode=episode,
+                env_index=env_id,
+            )
             if hit_horizon and soft_horizon:
                 episode.soft_reset()
                 resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index ffa10a2e8..0f0f4f057 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -64,11 +64,11 @@ class WorkerSet:
                 trainer_config,
                 {"tf_session_args": trainer_config["local_tf_session_args"]})
 
-        # Create a number of remote workers
+        # Create a number of remote workers.
         self._remote_workers = []
         self.add_workers(num_workers)
 
-        # Always create a local worker
+        # Always create a local worker.
         self._local_worker = self._make_worker(RolloutWorker, env_creator,
                                                self._policy_class, 0,
                                                self._local_config)
diff --git a/rllib/examples/custom_metrics_and_callbacks.py b/rllib/examples/custom_metrics_and_callbacks.py
index 18f55e915..651f329e2 100644
--- a/rllib/examples/custom_metrics_and_callbacks.py
+++ b/rllib/examples/custom_metrics_and_callbacks.py
@@ -18,41 +18,43 @@ from ray.rllib.agents.callbacks import DefaultCallbacks
 
 
 class MyCallbacks(DefaultCallbacks):
-    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                          policies: Dict[str, Policy],
-                         episode: MultiAgentEpisode, **kwargs):
-        print("episode {} started".format(episode.episode_id))
+                         episode: MultiAgentEpisode, env_index: int, **kwargs):
+        print("episode {} (env-idx={}) started.".format(
+            episode.episode_id, env_index))
         episode.user_data["pole_angles"] = []
         episode.hist_data["pole_angles"] = []
 
-    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
-                        episode: MultiAgentEpisode, **kwargs):
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
+                        episode: MultiAgentEpisode, env_index: int, **kwargs):
         pole_angle = abs(episode.last_observation_for()[2])
         raw_angle = abs(episode.last_raw_obs_for()[2])
         assert pole_angle == raw_angle
         episode.user_data["pole_angles"].append(pole_angle)
 
-    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        policies: Dict[str, Policy], episode: MultiAgentEpisode,
-                       **kwargs):
+                       env_index: int, **kwargs):
         pole_angle = np.mean(episode.user_data["pole_angles"])
-        print("episode {} ended with length {} and pole angles {}".format(
-            episode.episode_id, episode.length, pole_angle))
+        print("episode {} (env-idx={}) ended with length {} and pole "
+              "angles {}".format(episode.episode_id, env_index, episode.length,
+                                 pole_angle))
         episode.custom_metrics["pole_angle"] = pole_angle
         episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
 
-    def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
+    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                       **kwargs):
         print("returned sample batch of size {}".format(samples.count))
 
-    def on_train_result(self, trainer, result: dict, **kwargs):
+    def on_train_result(self, *, trainer, result: dict, **kwargs):
print("trainer.train() result: {} -> {} episodes".format( trainer, result["episodes_this_iter"])) # you can mutate the result dict to add new fields to return result["callback_ok"] = True def on_postprocess_trajectory( - self, worker: RolloutWorker, episode: MultiAgentEpisode, + self, *, worker: RolloutWorker, episode: MultiAgentEpisode, agent_id: str, policy_id: str, policies: Dict[str, Policy], postprocessed_batch: SampleBatch, original_batches: Dict[str, SampleBatch], **kwargs): @@ -75,6 +77,7 @@ if __name__ == "__main__": }, config={ "env": "CartPole-v0", + "num_envs_per_worker": 2, "callbacks": MyCallbacks, "framework": "tf", }, diff --git a/rllib/utils/exploration/curiosity.py b/rllib/utils/exploration/curiosity.py index e0acf29ea..1b18b1c91 100644 --- a/rllib/utils/exploration/curiosity.py +++ b/rllib/utils/exploration/curiosity.py @@ -197,8 +197,6 @@ class Curiosity(Exploration): }) phi, next_phi = phis[:batch_size], phis[batch_size:] - # Detach phi from graph (should not backpropagate through feature net - # for forward-loss). predicted_next_phi = self.model._curiosity_forward_fcnet( torch.cat( [ diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 2f4b73b61..37b7b0544 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -528,4 +528,4 @@ inline std::string GetActorFullName(bool global, std::string name) { return global ? name : ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() + "-" + name; -} \ No newline at end of file +}