[RLlib] Issue 7401: In eval mode (if evaluation_episodes > 0), agent hangs if Env does not terminate. (#7448)

* Fix.

* Rollback.

* Fix issue 7421.

* Fix.
Eric Liang 2020-03-04 12:58:34 -08:00 committed by GitHub
parent c38224d8e5
commit fddeb6809c
4 changed files with 17 additions and 22 deletions

@@ -29,26 +29,17 @@ DEFAULT_CONFIG = with_common_config({
     "normalize_actions": True,
     # === Learning ===
-    # Update the target by \tau * policy + (1-\tau) * target_policy
+    # Update the target by \tau * policy + (1-\tau) * target_policy.
     "tau": 5e-3,
     # Target entropy lower bound. This is the inverse of reward scale,
     # and will be optimized automatically.
     "target_entropy": "auto",
     # Disable setting done=True at end of episode.
     "no_done_at_end": True,
-    # N-step target updates
+    # N-step target updates.
     "n_step": 1,
-    # === Evaluation ===
-    # The evaluation stats will be reported under the "evaluation" metric key.
-    "evaluation_interval": 1,
-    # Number of episodes to run per evaluation period.
-    "evaluation_num_episodes": 1,
-    # Extra configuration that disables exploration.
-    "evaluation_config": {
-        "explore": False,
-    },
-    # Number of env steps to optimize for before returning
+    # Number of env steps to optimize for before returning.
     "timesteps_per_iteration": 100,
     # === Replay buffer ===
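This hunk removes the evaluation defaults (`evaluation_interval`, `evaluation_num_episodes`, `evaluation_config`) from SAC's DEFAULT_CONFIG, making evaluation opt-in again. Below is a minimal usage sketch, not part of this commit, of how a user could re-enable those settings explicitly; the env, horizon, and stopping criteria are placeholders.

import ray
from ray import tune

# Sketch only: re-enable evaluation explicitly now that SAC no longer
# turns it on by default. Keys mirror the removed defaults.
config = {
    "env": "Pendulum-v0",
    "evaluation_interval": 1,       # evaluate after every training iteration
    "evaluation_num_episodes": 1,   # episodes per evaluation run
    "evaluation_config": {
        "explore": False,           # act greedily while evaluating
    },
    # Give the env a finite horizon so evaluation episodes are guaranteed
    # to terminate (the hang this PR addresses).
    "horizon": 200,
}

ray.init()
tune.run("SAC", config=config, stop={"timesteps_total": 10000})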

@@ -95,9 +95,9 @@ class MultiAgentSampleBatchBuilder:
     def total(self):
         """Returns summed number of steps across all agent buffers."""
-        return sum(p.count for p in self.policy_builders.values())
+        return sum(a.count for a in self.agent_builders.values())

-    def has_pending_data(self):
+    def has_pending_agent_data(self):
         """Returns whether there is pending unprocessed data."""
         return len(self.agent_builders) > 0
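For orientation, a simplified stand-in (not the actual RLlib class) showing what the two helpers do after this change; both now look at the per-agent builders, and the rename makes that explicit.

class MultiAgentSampleBatchBuilderSketch:
    """Illustrative stand-in for MultiAgentSampleBatchBuilder."""

    def __init__(self):
        # agent_id -> builder holding that agent's not-yet-postprocessed rows.
        self.agent_builders = {}

    def total(self):
        """Summed number of steps across all agent buffers."""
        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self):
        """Whether any per-agent rows still await postprocessing."""
        return len(self.agent_builders) > 0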

@@ -266,7 +266,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         if not horizon:
             horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
     except Exception:
-        logger.debug("no episode horizon specified, assuming inf")
+        logger.debug("No episode horizon specified, assuming inf.")
     if not horizon:
         horizon = float("inf")
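Taken together with the hunks below, the relevant flow is: if the env spec provides no max_episode_steps, the horizon stays infinite, and a non-terminating env in complete_episodes mode keeps the sampler buffering forever. A small standalone sketch of that fallback (the function name and signature are illustrative, not RLlib's):

import logging

logger = logging.getLogger(__name__)

def resolve_horizon(env, horizon=None):
    # Prefer the env spec's max_episode_steps; otherwise assume an
    # infinite horizon, the case the new warning below is aimed at.
    if not horizon:
        try:
            horizon = env.spec.max_episode_steps
        except Exception:
            logger.debug("No episode horizon specified, assuming inf.")
    return horizon or float("inf")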
@@ -354,6 +354,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
     active_envs = set()
     to_eval = defaultdict(list)
     outputs = []
+    large_batch_threshold = max(1000, unroll_length * 10) if \
+        unroll_length != float("inf") else 5000

     # For each environment
     for env_id, agent_obs in unfiltered_obs.items():

@@ -364,18 +366,21 @@ def _process_observations(base_env, policies, batch_builder_pool,
             episode.batch_builder.count += 1
             episode._add_agent_rewards(rewards[env_id])

-        if (episode.batch_builder.total() > max(1000, unroll_length * 10)
+        if (episode.batch_builder.total() > large_batch_threshold
                 and log_once("large_batch_warning")):
             logger.warning(
                 "More than {} observations for {} env steps ".format(
                     episode.batch_builder.total(),
                     episode.batch_builder.count) + "are buffered in "
                 "the sampler. If this is more than you expected, check that "
-                "that you set a horizon on your environment correctly. Note "
-                "that in multi-agent environments, `sample_batch_size` sets "
+                "that you set a horizon on your environment correctly and that"
+                " it terminates at some point. "
+                "Note: In multi-agent environments, `sample_batch_size` sets "
                 "the batch size based on environment steps, not the steps of "
                 "individual agents, which can result in unexpectedly large "
-                "batches.")
+                "batches. Also, you may be in evaluation waiting for your Env "
+                "to terminate (batch_mode=`complete_episodes`). Make sure it "
+                "does at some point.")

         # Check episode termination conditions
         if dones[env_id]["__all__"] or episode.length >= horizon:
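The warning now compares against a precomputed `large_batch_threshold`: ten times the unroll length (floored at 1000 env steps) when the unroll length is finite, and a fixed 5000 when it is infinite, e.g. with batch_mode="complete_episodes" during evaluation, where the old expression max(1000, inf * 10) could never trigger. A self-contained sketch of that arithmetic:

def large_batch_threshold(unroll_length):
    # Mirrors the threshold computed in this hunk (illustrative only).
    if unroll_length != float("inf"):
        return max(1000, unroll_length * 10)
    return 5000

assert large_batch_threshold(50) == 1000             # floored at 1000 steps
assert large_batch_threshold(200) == 2000            # 10x a finite unroll length
assert large_batch_threshold(float("inf")) == 5000   # complete_episodes / eval case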
@@ -398,7 +403,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
             all_done = False
             active_envs.add(env_id)

-        # For each agent in the environment
+        # For each agent in the environment.
         for agent_id, raw_obs in agent_obs.items():
             policy_id = episode.policy_for(agent_id)
             prep_obs = _get_or_raise(preprocessors,
@@ -451,7 +456,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
         # Cut the batch if we're not packing multiple episodes into one,
         # or if we've exceeded the requested batch size.
-        if episode.batch_builder.has_pending_data():
+        if episode.batch_builder.has_pending_agent_data():
             if dones[env_id]["__all__"] and not no_done_at_end:
                 episode.batch_builder.check_missing_dones()

         if (all_done and not pack) or \

@@ -5,7 +5,6 @@ pendulum-sac:
         episode_reward_mean: -300  # note that evaluation perf is higher
         timesteps_total: 10000
     config:
-        evaluation_interval: 1  # logged under evaluation/* metric keys
         soft_horizon: True
         clip_actions: False
         normalize_actions: True