Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[RLlib] Issue 7401: In eval mode (if evaluation_episodes > 0), agent hangs if Env does not terminate. (#7448)
* Fix.
* Rollback.
* Fix issue 7421.
* Fix.
Parent: c38224d8e5
Commit: fddeb6809c

4 changed files with 17 additions and 22 deletions
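For context, the hang this commit addresses shows up with environments that never emit done=True while evaluation rollouts wait for complete episodes. A minimal, illustrative env of that kind (the class name and spaces are made up for this sketch; it is not part of the commit):

import gym
import numpy as np


class NeverDoneEnv(gym.Env):
    """Illustrative only: an env that never terminates on its own."""

    observation_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
    action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)

    def reset(self):
        return np.zeros(1, dtype=np.float32)

    def step(self, action):
        # done is always False; only a horizon (or soft_horizon) ends episodes.
        return np.zeros(1, dtype=np.float32), 0.0, False, {}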
@@ -29,26 +29,17 @@ DEFAULT_CONFIG = with_common_config({
     "normalize_actions": True,

     # === Learning ===
-    # Update the target by \tau * policy + (1-\tau) * target_policy
+    # Update the target by \tau * policy + (1-\tau) * target_policy.
     "tau": 5e-3,
     # Target entropy lower bound. This is the inverse of reward scale,
     # and will be optimized automatically.
     "target_entropy": "auto",
     # Disable setting done=True at end of episode.
     "no_done_at_end": True,
-    # N-step target updates
+    # N-step target updates.
     "n_step": 1,
-    # === Evaluation ===
-    # The evaluation stats will be reported under the "evaluation" metric key.
-    "evaluation_interval": 1,
-    # Number of episodes to run per evaluation period.
-    "evaluation_num_episodes": 1,
-    # Extra configuration that disables exploration.
-    "evaluation_config": {
-        "explore": False,
-    },

-    # Number of env steps to optimize for before returning
+    # Number of env steps to optimize for before returning.
     "timesteps_per_iteration": 100,

     # === Replay buffer ===
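Because these evaluation keys are no longer part of SAC's DEFAULT_CONFIG, evaluation now has to be requested explicitly. A rough sketch of doing so through the common trainer config (the tune.run call, env, and stop criterion are illustrative; the three evaluation keys are the ones removed above):

from ray import tune

tune.run(
    "SAC",
    stop={"timesteps_total": 10000},
    config={
        "env": "Pendulum-v0",
        # Evaluate every training iteration ...
        "evaluation_interval": 1,
        # ... for one episode ...
        "evaluation_num_episodes": 1,
        # ... with exploration switched off.
        "evaluation_config": {"explore": False},
    },
)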
@@ -95,9 +95,9 @@ class MultiAgentSampleBatchBuilder:
     def total(self):
         """Returns summed number of steps across all agent buffers."""

-        return sum(p.count for p in self.policy_builders.values())
+        return sum(a.count for a in self.agent_builders.values())

-    def has_pending_data(self):
+    def has_pending_agent_data(self):
         """Returns whether there is pending unprocessed data."""

         return len(self.agent_builders) > 0
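To make the rename easier to follow, here is a stripped-down stand-in for the builder class. Only the two methods touched above are shown; the class frame and __init__ are a simplification for illustration, not RLlib code:

class MultiAgentSampleBatchBuilderSketch:
    def __init__(self):
        # Maps agent_id -> per-agent builder; each builder tracks a .count
        # of collected rows, mirroring the real agent_builders dict.
        self.agent_builders = {}

    def total(self):
        """Returns summed number of steps across all agent buffers."""
        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self):
        """Returns whether any agent still has unprocessed rows buffered."""
        return len(self.agent_builders) > 0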
@@ -266,7 +266,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         if not horizon:
             horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
     except Exception:
-        logger.debug("no episode horizon specified, assuming inf")
+        logger.debug("No episode horizon specified, assuming inf.")
     if not horizon:
         horizon = float("inf")

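The hunk above sits inside the horizon resolution at the start of _env_runner: use the configured horizon if given, otherwise try the env spec's max_episode_steps, otherwise fall back to infinity. A self-contained sketch of that fallback (the helper name and the gym usage are for illustration only):

import logging

import gym

logger = logging.getLogger(__name__)


def resolve_horizon(env, horizon=None):
    try:
        if not horizon:
            # Same lookup as in the hunk, minus the BaseEnv unwrapping.
            horizon = env.spec.max_episode_steps
    except Exception:
        logger.debug("No episode horizon specified, assuming inf.")
    if not horizon:
        horizon = float("inf")
    return horizon


print(resolve_horizon(gym.make("Pendulum-v0")))  # 200, from the env spec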
@@ -354,6 +354,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
     active_envs = set()
     to_eval = defaultdict(list)
     outputs = []
+    large_batch_threshold = max(1000, unroll_length * 10) if \
+        unroll_length != float("inf") else 5000

     # For each environment
     for env_id, agent_obs in unfiltered_obs.items():
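The point of the new large_batch_threshold is that the old expression, max(1000, unroll_length * 10), is infinite whenever unroll_length is infinite (as with batch_mode=complete_episodes), so the warning below could never fire and a non-terminating Env just hung silently. With the fallback to 5000, the warning does fire. A tiny check of the expression:

def large_batch_threshold(unroll_length):
    # Same expression as the added lines above.
    return (max(1000, unroll_length * 10)
            if unroll_length != float("inf") else 5000)


assert large_batch_threshold(50) == 1000
assert large_batch_threshold(200) == 2000
assert large_batch_threshold(float("inf")) == 5000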
@@ -364,18 +366,21 @@ def _process_observations(base_env, policies, batch_builder_pool,
         episode.batch_builder.count += 1
         episode._add_agent_rewards(rewards[env_id])

-        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
+        if (episode.batch_builder.total() > large_batch_threshold
                 and log_once("large_batch_warning")):
             logger.warning(
                 "More than {} observations for {} env steps ".format(
                     episode.batch_builder.total(),
                     episode.batch_builder.count) + "are buffered in "
                 "the sampler. If this is more than you expected, check that "
-                "that you set a horizon on your environment correctly. Note "
-                "that in multi-agent environments, `sample_batch_size` sets "
+                "that you set a horizon on your environment correctly and that"
+                " it terminates at some point. "
+                "Note: In multi-agent environments, `sample_batch_size` sets "
                 "the batch size based on environment steps, not the steps of "
                 "individual agents, which can result in unexpectedly large "
-                "batches.")
+                "batches. Also, you may be in evaluation waiting for your Env "
+                "to terminate (batch_mode=`complete_episodes`). Make sure it "
+                "does at some point.")

         # Check episode termination conditions
         if dones[env_id]["__all__"] or episode.length >= horizon:
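For readers unfamiliar with log_once, the guard above emits the warning at most once per process. A simplified, self-contained version of that pattern (the log_once stand-in and maybe_warn helper are local to this sketch; RLlib has its own helper):

import logging

logger = logging.getLogger(__name__)
_seen_keys = set()


def log_once(key):
    # Local stand-in: True only the first time a given key is passed in.
    if key in _seen_keys:
        return False
    _seen_keys.add(key)
    return True


def maybe_warn(total_steps, env_steps, threshold):
    if total_steps > threshold and log_once("large_batch_warning"):
        logger.warning(
            "More than %s observations for %s env steps are buffered in the "
            "sampler; check the horizon/termination of your Env.",
            total_steps, env_steps)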
@@ -398,7 +403,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
             all_done = False
             active_envs.add(env_id)

-        # For each agent in the environment
+        # For each agent in the environment.
         for agent_id, raw_obs in agent_obs.items():
             policy_id = episode.policy_for(agent_id)
             prep_obs = _get_or_raise(preprocessors,
@@ -451,7 +456,7 @@ def _process_observations(base_env, policies, batch_builder_pool,

         # Cut the batch if we're not packing multiple episodes into one,
         # or if we've exceeded the requested batch size.
-        if episode.batch_builder.has_pending_data():
+        if episode.batch_builder.has_pending_agent_data():
             if dones[env_id]["__all__"] and not no_done_at_end:
                 episode.batch_builder.check_missing_dones()
             if (all_done and not pack) or \
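Note how this interacts with SAC's "no_done_at_end": True from the config hunk above: when that flag is set, the missing-dones check is skipped even for episodes that report __all__ done. A condensed sketch of just that gate (the function and argument names are illustrative):

def maybe_check_missing_dones(batch_builder, env_dones, no_done_at_end):
    # Mirrors the gate in the hunk: only enforce per-agent dones when the
    # episode really ended and done-at-end behavior is not disabled.
    if batch_builder.has_pending_agent_data():
        if env_dones["__all__"] and not no_done_at_end:
            batch_builder.check_missing_dones()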
@@ -5,7 +5,6 @@ pendulum-sac:
         episode_reward_mean: -300 # note that evaluation perf is higher
         timesteps_total: 10000
     config:
-        evaluation_interval: 1 # logged under evaluation/* metric keys
         soft_horizon: True
         clip_actions: False
         normalize_actions: True