[RLlib] Bug: If trainer config horizon is provided, should try to increase env steps to that value. (#7531)

Sven Mika 2020-03-12 19:03:37 +01:00 committed by GitHub
parent 80d314ae5e
commit f165766813
2 changed files with 54 additions and 9 deletions
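
In short, the patch changes how the rollout sampler resolves the episode horizon: an explicit Trainer `horizon` takes precedence but must not exceed the env's own `spec.max_episode_steps`; without a config value, the env's limit (or infinity) is used. Below is a minimal standalone sketch of that precedence; the helper name `resolve_horizon` is hypothetical and not part of this PR.

def resolve_horizon(horizon, max_episode_steps):
    # Hypothetical helper mirroring the sampler logic added in this commit.
    # `horizon`: Trainer config value (None/0 if unset).
    # `max_episode_steps`: the env's own limit (None if it has no spec).
    if horizon:
        # An explicit horizon must fit within the env's own limit.
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "`horizon` ({}) is larger than the Env's timestep limit "
                "({}).".format(horizon, max_episode_steps))
        return horizon
    elif max_episode_steps:
        # No horizon given -> fall back to the env's limit.
        return max_episode_steps
    else:
        # Neither given -> episodes only end on the env's `done`.
        return float("inf")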

@@ -262,13 +262,33 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
             terminal condition, and other fields as dictated by `policy`.
     """
     # Try to get Env's max_episode_steps prop. If it doesn't exist, catch
     # error and continue.
+    max_episode_steps = None
     try:
-        if not horizon:
-            horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
+        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
     except Exception:
-        logger.debug("No episode horizon specified, assuming inf.")
-    if not horizon:
+        pass
+
+    # Trainer has a given `horizon` setting.
+    if horizon:
+        # `horizon` is larger than env's limit -> Error and explain how
+        # to increase Env's own episode limit.
+        if max_episode_steps and horizon > max_episode_steps:
+            raise ValueError(
+                "Your `horizon` setting ({}) is larger than the Env's own "
+                "timestep limit ({})! Try to increase the Env's limit via "
+                "setting its `spec.max_episode_steps` property.".format(
+                    horizon, max_episode_steps))
+    # Otherwise, set Trainer's horizon to env's max-steps.
+    elif max_episode_steps:
+        horizon = max_episode_steps
+        logger.debug(
+            "No episode horizon specified, setting it to Env's limit ({}).".
+            format(max_episode_steps))
+    else:
         horizon = float("inf")
+        logger.debug("No episode horizon specified, assuming inf.")

     # Pool of batch builders, which can be shared across episodes to pack
     # trajectory data.
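
The new ValueError above asks users to raise the env's own limit rather than silently truncating episodes. For a standard gym env, one way to do that is to register a variant with a higher `max_episode_steps` and point RLlib at it. This is a sketch, not part of the PR, and the env id `CartPole-v0-long` is made up for the example.

import gym
from gym.envs.registration import register
from ray.tune.registry import register_env

# Register a CartPole variant whose spec.max_episode_steps (1000) is large
# enough for a Trainer `horizon` of, say, 500.
register(
    id="CartPole-v0-long",  # made-up id for this example
    entry_point="gym.envs.classic_control:CartPoleEnv",
    max_episode_steps=1000,
)

# Make the env available to RLlib under a custom name.
register_env("cartpole_long", lambda config: gym.make("CartPole-v0-long"))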

@@ -8,12 +8,13 @@ import unittest

 import ray
 from ray.rllib.agents.pg import PGTrainer
 from ray.rllib.agents.a3c import A2CTrainer
-from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.metrics import collect_metrics
-from ray.rllib.policy.tests.test_policy import TestPolicy
 from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.policy.tests.test_policy import TestPolicy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
+from ray.rllib.env.vector_env import VectorEnv
+from ray.rllib.utils.test_utils import check
 from ray.tune.registry import register_env
@@ -261,18 +262,42 @@ class TestRolloutWorker(unittest.TestCase):
     def test_hard_horizon(self):
         ev = RolloutWorker(
-            env_creator=lambda _: MockEnv(episode_length=10),
+            env_creator=lambda _: MockEnv2(episode_length=10),
             policy=MockPolicy,
             batch_mode="complete_episodes",
             batch_steps=10,
             episode_horizon=4,
             soft_horizon=False)
         samples = ev.sample()
-        # three logical episodes
+        # Three logical episodes and correct episode resets (always after 4
+        # steps).
         self.assertEqual(len(set(samples["eps_id"])), 3)
-        # 3 done values
+        for i in range(4):
+            self.assertEqual(np.argmax(samples["obs"][i]), i)
+        self.assertEqual(np.argmax(samples["obs"][4]), 0)
+        # 3 done values.
         self.assertEqual(sum(samples["dones"]), 3)
+
+        # The gym env's limit (spec.max_episode_steps=200 for CartPole-v0) is
+        # larger than the Trainer's horizon (6), so the horizon cuts episodes.
+        ev = RolloutWorker(
+            env_creator=lambda _: gym.make("CartPole-v0"),
+            policy=MockPolicy,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=6,
+            soft_horizon=False)
+        samples = ev.sample()
+        # 12 steps due to `complete_episodes` batch_mode.
+        self.assertEqual(len(samples["eps_id"]), 12)
+        # Two logical episodes and correct episode resets (always after 6(!)
+        # steps).
+        self.assertEqual(len(set(samples["eps_id"])), 2)
+        # 2 done values after 6 and 12 steps.
+        check(samples["dones"], [
+            False, False, False, False, False, True, False, False, False,
+            False, False, True
+        ])

     def test_soft_horizon(self):
         ev = RolloutWorker(
             env_creator=lambda _: MockEnv(episode_length=10),