mirror of https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00

[RLlib] Issue 10469: Callbacks should receive env idx ... (#10477)

parent d8ac4bc719
commit 715ee8dfc9

6 changed files with 59 additions and 33 deletions
@@ -30,9 +30,9 @@ class DefaultCallbacks:
                 "a class extending rllib.agents.callbacks.DefaultCallbacks")
         self.legacy_callbacks = legacy_callbacks_dict or {}

-    def on_episode_start(self, worker: "RolloutWorker", base_env: BaseEnv,
+    def on_episode_start(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                          policies: Dict[PolicyID, Policy],
-                         episode: MultiAgentEpisode, **kwargs):
+                         episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Callback run on the rollout worker before each episode starts.

         Args:
@@ -45,6 +45,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env, which the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
         """

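Note (not part of the diff): the leading `*` added above makes every argument keyword-only, and `env_index: int` is added as a new keyword argument. A minimal sketch of a user-defined callback adopting the new signature (class name and body are hypothetical, for illustration only):

    from ray.rllib.agents.callbacks import DefaultCallbacks


    class EnvIndexTracker(DefaultCallbacks):
        def on_episode_start(self, *, worker, base_env, policies, episode,
                             env_index, **kwargs):
            # `env_index` identifies which sub-env of the vectorized env this
            # episode belongs to; remember it for later inspection.
            episode.user_data["env_index"] = env_index

Because of the `*`, a positional call such as `cb.on_episode_start(worker, env)` raises a TypeError; RLlib itself passes every argument by keyword (see the sampler changes further down).
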
@@ -55,8 +57,8 @@ class DefaultCallbacks:
                 "episode": episode,
             })

-    def on_episode_step(self, worker: "RolloutWorker", base_env: BaseEnv,
-                        episode: MultiAgentEpisode, **kwargs):
+    def on_episode_step(self, *, worker: "RolloutWorker", base_env: BaseEnv,
+                        episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Runs on each episode step.

         Args:
@@ -67,6 +69,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env, which the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
         """

@@ -76,9 +80,9 @@ class DefaultCallbacks:
                 "episode": episode
             })

-    def on_episode_end(self, worker: "RolloutWorker", base_env: BaseEnv,
+    def on_episode_end(self, *, worker: "RolloutWorker", base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy],
-                       episode: MultiAgentEpisode, **kwargs):
+                       episode: MultiAgentEpisode, env_index: int, **kwargs):
         """Runs when an episode is done.

         Args:
@@ -91,6 +95,8 @@ class DefaultCallbacks:
                 state. You can use the `episode.user_data` dict to store
                 temporary data, and `episode.custom_metrics` to store custom
                 metrics for the episode.
+            env_index (int): The index of the (vectorized) env, which the
+                episode belongs to.
             kwargs: Forward compatibility placeholder.
         """

@@ -102,7 +108,7 @@ class DefaultCallbacks:
             })

     def on_postprocess_trajectory(
-            self, worker: "RolloutWorker", episode: MultiAgentEpisode,
+            self, *, worker: "RolloutWorker", episode: MultiAgentEpisode,
             agent_id: AgentID, policy_id: PolicyID,
             policies: Dict[PolicyID, Policy], postprocessed_batch: SampleBatch,
             original_batches: Dict[AgentID, SampleBatch], **kwargs):
@@ -136,9 +142,9 @@ class DefaultCallbacks:
                 "all_pre_batches": original_batches,
             })

-    def on_sample_end(self, worker: "RolloutWorker", samples: SampleBatch,
+    def on_sample_end(self, *, worker: "RolloutWorker", samples: SampleBatch,
                       **kwargs):
-        """Called at the end RolloutWorker.sample().
+        """Called at the end of RolloutWorker.sample().

         Args:
             worker (RolloutWorker): Reference to the current rollout worker.
@@ -153,7 +159,7 @@ class DefaultCallbacks:
                 "samples": samples,
             })

-    def on_train_result(self, trainer, result: dict, **kwargs):
+    def on_train_result(self, *, trainer, result: dict, **kwargs):
         """Called at the end of Trainable.train().

         Args:

@@ -49,6 +49,15 @@ PolicyEvalData = namedtuple("PolicyEvalData", [
 StateBatch = List[List[Any]]


+class NewEpisodeDefaultDict(defaultdict):
+    def __missing__(self, env_index):
+        if self.default_factory is None:
+            raise KeyError(env_index)
+        else:
+            ret = self[env_index] = self.default_factory(env_index)
+            return ret
+
+
 class _PerfStats:
     """Sampler perf stats that will be included in rollout metrics."""

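Note (not part of the diff): `NewEpisodeDefaultDict` is needed because a plain `collections.defaultdict` calls its factory with no arguments, so the factory cannot know which env index triggered the lookup. A self-contained sketch of the behavior; the dict returned by the factory is a placeholder standing in for the real `new_episode(env_index)`:

    from collections import defaultdict


    class NewEpisodeDefaultDict(defaultdict):
        # Same logic as the class added above: forward the missing key to
        # the factory instead of calling it with no arguments.
        def __missing__(self, env_index):
            if self.default_factory is None:
                raise KeyError(env_index)
            else:
                ret = self[env_index] = self.default_factory(env_index)
                return ret


    episodes = NewEpisodeDefaultDict(lambda env_index: {"env_index": env_index})
    assert episodes[2] == {"env_index": 2}  # The missing key reached the factory.
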
@@ -505,7 +514,7 @@ def _env_runner(
         return MultiAgentSampleBatchBuilder(policies, clip_rewards,
                                             callbacks)

-    def new_episode():
+    def new_episode(env_index):
         episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                     get_batch_builder, extra_batch_callback)
         # Call each policy's Exploration.on_episode_start method.
@@ -521,10 +530,13 @@ def _env_runner(
             worker=worker,
             base_env=base_env,
             policies=policies,
-            episode=episode)
+            episode=episode,
+            env_index=env_index,
+        )
         return episode

-    active_episodes: Dict[str, MultiAgentEpisode] = defaultdict(new_episode)
+    active_episodes: Dict[str, MultiAgentEpisode] = \
+        NewEpisodeDefaultDict(new_episode)
     eval_results = None

     while True:
@@ -830,7 +842,10 @@ def _process_observations(

         # Invoke the step callback after the step is logged to the episode
         callbacks.on_episode_step(
-            worker=worker, base_env=base_env, episode=episode)
+            worker=worker,
+            base_env=base_env,
+            episode=episode,
+            env_index=env_id)

         # Cut the batch if ...
         # - all-agents-done and not packing multiple episodes into one
@@ -869,7 +884,9 @@ def _process_observations(
                 worker=worker,
                 base_env=base_env,
                 policies=policies,
-                episode=episode)
+                episode=episode,
+                env_index=env_id,
+            )
             if hit_horizon and soft_horizon:
                 episode.soft_reset()
                 resetted_obs: Dict[AgentID, EnvObsType] = agent_obs
@@ -1120,7 +1137,9 @@ def _process_observations_w_trajectory_view_api(
                 worker=worker,
                 base_env=base_env,
                 policies=policies,
-                episode=episode)
+                episode=episode,
+                env_index=env_id,
+            )
             if hit_horizon and soft_horizon:
                 episode.soft_reset()
                 resetted_obs: Dict[AgentID, EnvObsType] = agent_obs

@@ -64,11 +64,11 @@ class WorkerSet:
                 trainer_config,
                 {"tf_session_args": trainer_config["local_tf_session_args"]})

-            # Create a number of remote workers
+            # Create a number of remote workers.
             self._remote_workers = []
             self.add_workers(num_workers)

-            # Always create a local worker
+            # Always create a local worker.
             self._local_worker = self._make_worker(RolloutWorker, env_creator,
                                                    self._policy_class, 0,
                                                    self._local_config)

@@ -18,41 +18,43 @@ from ray.rllib.agents.callbacks import DefaultCallbacks


 class MyCallbacks(DefaultCallbacks):
-    def on_episode_start(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                          policies: Dict[str, Policy],
-                         episode: MultiAgentEpisode, **kwargs):
-        print("episode {} started".format(episode.episode_id))
+                         episode: MultiAgentEpisode, env_index: int, **kwargs):
+        print("episode {} (env-idx={}) started.".format(
+            episode.episode_id, env_index))
         episode.user_data["pole_angles"] = []
         episode.hist_data["pole_angles"] = []

-    def on_episode_step(self, worker: RolloutWorker, base_env: BaseEnv,
-                        episode: MultiAgentEpisode, **kwargs):
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
+                        episode: MultiAgentEpisode, env_index: int, **kwargs):
         pole_angle = abs(episode.last_observation_for()[2])
         raw_angle = abs(episode.last_raw_obs_for()[2])
         assert pole_angle == raw_angle
         episode.user_data["pole_angles"].append(pole_angle)

-    def on_episode_end(self, worker: RolloutWorker, base_env: BaseEnv,
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        policies: Dict[str, Policy], episode: MultiAgentEpisode,
-                       **kwargs):
+                       env_index: int, **kwargs):
         pole_angle = np.mean(episode.user_data["pole_angles"])
-        print("episode {} ended with length {} and pole angles {}".format(
-            episode.episode_id, episode.length, pole_angle))
+        print("episode {} (env-idx={}) ended with length {} and pole "
+              "angles {}".format(episode.episode_id, env_index, episode.length,
+                                 pole_angle))
         episode.custom_metrics["pole_angle"] = pole_angle
         episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

-    def on_sample_end(self, worker: RolloutWorker, samples: SampleBatch,
+    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                       **kwargs):
         print("returned sample batch of size {}".format(samples.count))

-    def on_train_result(self, trainer, result: dict, **kwargs):
+    def on_train_result(self, *, trainer, result: dict, **kwargs):
         print("trainer.train() result: {} -> {} episodes".format(
             trainer, result["episodes_this_iter"]))
         # you can mutate the result dict to add new fields to return
         result["callback_ok"] = True

     def on_postprocess_trajectory(
-            self, worker: RolloutWorker, episode: MultiAgentEpisode,
+            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
             agent_id: str, policy_id: str, policies: Dict[str, Policy],
             postprocessed_batch: SampleBatch,
             original_batches: Dict[str, SampleBatch], **kwargs):
@@ -75,6 +77,7 @@ if __name__ == "__main__":
         },
         config={
             "env": "CartPole-v0",
+            "num_envs_per_worker": 2,
             "callbacks": MyCallbacks,
             "framework": "tf",
         },

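Note (not part of the diff): with "num_envs_per_worker": 2, each rollout worker runs two sub-envs, so the callbacks above are invoked with env_index 0 and 1, which is what the new "(env-idx=...)" log lines show. A hedged sketch of launching the example with this config (the "PG" trainer and the stop criterion are assumptions, not taken from this commit; MyCallbacks is the class defined in the example above):

    import ray
    from ray import tune


    ray.init()
    tune.run(
        "PG",  # Assumed trainer name, for illustration only.
        stop={"training_iteration": 2},
        config={
            "env": "CartPole-v0",
            "num_envs_per_worker": 2,  # Two sub-envs -> env_index is 0 or 1.
            "callbacks": MyCallbacks,  # The callbacks class defined above.
            "framework": "tf",
        },
    )
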
@@ -197,8 +197,6 @@ class Curiosity(Exploration):
             })
         phi, next_phi = phis[:batch_size], phis[batch_size:]

-        # Detach phi from graph (should not backpropagate through feature net
-        # for forward-loss).
         predicted_next_phi = self.model._curiosity_forward_fcnet(
             torch.cat(
                 [

@@ -528,4 +528,4 @@ inline std::string GetActorFullName(bool global, std::string name) {
   return global ? name
                 : ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() +
                       "-" + name;
-}
+}