[RLlib] Issue 10469: Callbacks should receive env idx ... (#10477)

Sven Mika 2020-09-03 17:27:05 +02:00 committed by GitHub
parent d8ac4bc719
commit 715ee8dfc9
6 changed files with 59 additions and 33 deletions

View file

@@ -30,9 +30,9 @@ class DefaultCallbacks:
"a class extending rllib.agents.callbacks.DefaultCallbacks")
self.legacy_callbacks = legacy_callbacks_dict or {}
def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv,
def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode, **kwargs):
episode: MultiAgentEpisode, env_index: int, **kwargs):
"""Callback run on the rollout worker before each episode starts.
Args:
@@ -45,6 +45,8 @@ class DefaultCallbacks:
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index (int): The index of the (vectorized) env that the
episode belongs to.
kwargs: Forward compatibility placeholder.
"""
@@ -55,8 +57,8 @@ class DefaultCallbacks:
"episode": episode,
})
def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv,
episode: MultiAgentEpisode, **kwargs):
def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv,
episode: MultiAgentEpisode, env_index: int, **kwargs):
"""Runs on each episode step.
Args:
@@ -67,6 +69,8 @@ class DefaultCallbacks:
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index (int): The index of the (vectorized) env that the
episode belongs to.
kwargs: Forward compatibility placeholder.
"""
@@ -76,9 +80,9 @@ class DefaultCallbacks:
"episode": episode
})
def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv,
def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
policies: Dict[PolicyID, Policy],
episode: MultiAgentEpisode, **kwargs):
episode: MultiAgentEpisode, env_index: int, **kwargs):
"""Runs when an episode is done.
Args:
@@ -91,6 +95,8 @@ class DefaultCallbacks:
state. You can use the `episode.user_data` dict to store
temporary data, and `episode.custom_metrics` to store custom
metrics for the episode.
env_index (int): The index of the (vectorized) env that the
episode belongs to.
kwargs: Forward compatibility placeholder.
"""
@@ -102,7 +108,7 @@ class DefaultCallbacks:
})
def on_postprocess_trajectory(
self, worker: "RolloutWorker", episode: MultiAgentEpisode,
self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
agent_id: AgentID, policy_id: PolicyID,
policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
original_batches: Dict[AgentID, SampleBatch], **kwargs):
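The hunk above cuts off mid-signature. Purely as a hedged sketch (not this commit's code), a subclass override using the updated keyword-only `on_postprocess_trajectory` signature might look like this; the metric name is invented:

```python
from ray.rllib.agents.callbacks import DefaultCallbacks


class CountingCallbacks(DefaultCallbacks):
    def on_postprocess_trajectory(
            self, *, worker, episode, agent_id, policy_id, policies,
            postprocessed_batch, original_batches, **kwargs):
        # Count how many agent trajectories were postprocessed this episode.
        n = episode.custom_metrics.get("num_postprocessed_batches", 0)
        episode.custom_metrics["num_postprocessed_batches"] = n + 1
```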
@@ -136,9 +142,9 @@ class DefaultCallbacks:
"all_pre_batches": original_batches,
})
def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch,
def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
**kwargs):
"""Called at the end RolloutWorker.sample().
"""Called at the end of RolloutWorker.sample().
Args:
worker (RolloutWorker): Reference to the current rollout worker.
@@ -153,7 +159,7 @@ class DefaultCallbacks:
"samples": samples,
})
def on_train_result(self, trainer, result: dict, **kwargs):
def on_train_result(self, *, trainer, result: dict, **kwargs):
"""Called at the end of Trainable.train().
Args:

View file

@@ -49,6 +49,15 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
StateBatch = List[List[Any]]
class NewEpisodeDefaultDict(defaultdict):
def __missing__(self, env_index):
if self.default_factory is None:
raise KeyError(env_index)
else:
ret = self[env_index] = self.default_factory(env_index)
return ret
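Unlike a plain `collections.defaultdict`, whose factory takes no arguments, this subclass forwards the missing key to the factory, which is what lets `new_episode(env_index)` below receive the sub-env index. A self-contained sketch of the behavior (keys and values are illustrative only):

```python
from collections import defaultdict


class NewEpisodeDefaultDict(defaultdict):
    def __missing__(self, env_index):
        if self.default_factory is None:
            raise KeyError(env_index)
        ret = self[env_index] = self.default_factory(env_index)
        return ret


# A plain defaultdict calls its factory with *no* arguments; this subclass
# passes the missing key (here: the vectorized sub-env's index) instead.
episodes = NewEpisodeDefaultDict(lambda idx: "episode-for-env-{}".format(idx))
assert episodes[3] == "episode-for-env-3"
```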
class _PerfStats:
"""Sampler perf stats that will be included in rollout metrics."""
@@ -505,7 +514,7 @@ def _env_runner(
return MultiAgentSampleBatchBuilder(policies, clip_rewards,
callbacks)
def new_episode():
def new_episode(env_index):
episode = MultiAgentEpisode(policies, policy_mapping_fn,
get_batch_builder, extra_batch_callback)
# Call each policy's Exploration.on_episode_start method.
@@ -521,10 +530,13 @@ def _env_runner(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
episode=episode,
env_index=env_index,
)
return episode
active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)
active_episodes: Dict[str, MultiAgentEpisode] = \
NewEpisodeDefaultDict(new_episode)
eval_results = None
while True:
@@ -830,7 +842,10 @@ def _process_observations(
# Invoke the step callback after the step is logged to the episode
callbacks.on_episode_step(
worker=worker, base_env=base_env, episode=episode)
worker=worker,
base_env=base_env,
episode=episode,
env_index=env_id)
# Cut the batch if ...
# - all-agents-done and not packing multiple episodes into one
@@ -869,7 +884,9 @@ def _process_observations(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
episode=episode,
env_index=env_id,
)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
@@ -1120,7 +1137,9 @@ def _process_observations_w_trajectory_view_api(
worker=worker,
base_env=base_env,
policies=policies,
episode=episode)
episode=episode,
env_index=env_id,
)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs: Dict[AgentID, EnvObsType] = agent_obs

View file

@@ -64,11 +64,11 @@ class WorkerSet:
trainer_config,
{"tf_session_args": trainer_config["local_tf_session_args"]})
# Create a number of remote workers
# Create a number of remote workers.
self._remote_workers = []
self.add_workers(num_workers)
# Always create a local worker
# Always create a local worker.
self._local_worker = self._make_worker(RolloutWorker, env_creator,
self._policy_class, 0,
self._local_config)
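For context, a hedged sketch of how the local and remote workers created here are typically reached from a trainer, assuming the usual `workers` attribute with its `local_worker()` / `remote_workers()` accessors (trainer choice and config are illustrative):

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={"env": "CartPole-v0", "num_workers": 2})

# The WorkerSet always holds one local RolloutWorker ...
local_worker = trainer.workers.local_worker()
# ... plus `num_workers` remote RolloutWorker actors.
remote_workers = trainer.workers.remote_workers()
print(type(local_worker).__name__, len(remote_workers))
```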

View file

@@ -18,41 +18,43 @@ from ray.rllib.agents.callbacks import DefaultCallbacks
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, **kwargs):
print("episode {} started".format(episode.episode_id))
episode: MultiAgentEpisode, env_index: int, **kwargs):
print("episode {} (env-idx={}) started.".format(
episode.episode_id, env_index))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, **kwargs):
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, env_index: int, **kwargs):
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
**kwargs):
env_index: int, **kwargs):
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} ended with length {} and pole angles {}".format(
episode.episode_id, episode.length, pole_angle))
print("episode {} (env-idx={}) ended with length {} and pole "
"angles {}".format(episode.episode_id, env_index, episode.length,
pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
**kwargs):
print("returned sample batch of size {}".format(samples.count))
def on_train_result(self, trainer, result: dict, **kwargs):
def on_train_result(self, *, trainer, result: dict, **kwargs):
print("trainer.train() result: {} -> {} episodes".format(
trainer, result["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
def on_postprocess_trajectory(
self, worker: RolloutWorker, episode: MultiAgentEpisode,
self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
@@ -75,6 +77,7 @@ if __name__ == "__main__":
},
config={
"env": "CartPole-v0",
"num_envs_per_worker": 2,
"callbacks": MyCallbacks,
"framework": "tf",
},
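A minimal way to launch the updated example, sketched under the assumption that the script's `tune.run` call has roughly this shape and that the example module is importable as `ray.rllib.examples.custom_metrics_and_callbacks` (trainer name and stopping criterion are illustrative; `MyCallbacks` is the class from the diff above):

```python
from ray import tune
from ray.rllib.examples.custom_metrics_and_callbacks import MyCallbacks

tune.run(
    "PG",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        # Two sub-envs per worker, so callbacks see env_index 0 or 1.
        "num_envs_per_worker": 2,
        "callbacks": MyCallbacks,
        "framework": "tf",
    },
)
```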

View file

@@ -197,8 +197,6 @@ class Curiosity(Exploration):
})
phi, next_phi = phis[:batch_size], phis[batch_size:]
# Detach phi from graph (should not backpropagate through feature net
# for forward-loss).
predicted_next_phi = self.model._curiosity_forward_fcnet(
torch.cat(
[

View file

@@ -528,4 +528,4 @@ inline std::string GetActorFullName(bool global, std::string name) {
return global ? name
: ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() +
"-" + name;
}
}