[RLlib] Bug: If trainer config horizon is provided, should try to increase env steps to that value. (#7531)
This commit is contained in:
parent
80d314ae5e
commit
f165766813
2 changed files with 54 additions and 9 deletions
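The gist of the change, distilled: an explicitly configured `horizon` now takes precedence but is validated against the env's own `spec.max_episode_steps`; without a `horizon`, the env's limit is used; without either, episodes are unbounded. The standalone sketch below mirrors the precedence logic of the sampler hunk for illustration only (the `resolve_horizon` helper is hypothetical, not part of RLlib):

```python
def resolve_horizon(horizon, max_episode_steps):
    """Hypothetical helper mirroring the new precedence rules."""
    if horizon:
        # An explicit horizon must fit within the env's own limit.
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "`horizon` ({}) exceeds the env's timestep limit ({}).".
                format(horizon, max_episode_steps))
        return horizon
    elif max_episode_steps:
        # No horizon configured -> fall back to the env's limit.
        return max_episode_steps
    else:
        # Neither available -> episodes end only when the env says done.
        return float("inf")


assert resolve_horizon(4, None) == 4
assert resolve_horizon(None, 200) == 200
assert resolve_horizon(None, None) == float("inf")
```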
@@ -262,13 +262,33 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         terminal condition, and other fields as dictated by `policy`.
     """
 
+    # Try to get Env's max_episode_steps prop. If it doesn't exist, catch
+    # error and continue.
+    max_episode_steps = None
     try:
-        if not horizon:
-            horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
+        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
     except Exception:
-        logger.debug("No episode horizon specified, assuming inf.")
-    if not horizon:
+        pass
+
+    # Trainer has a given `horizon` setting.
+    if horizon:
+        # `horizon` is larger than env's limit -> Error and explain how
+        # to increase Env's own episode limit.
+        if max_episode_steps and horizon > max_episode_steps:
+            raise ValueError(
+                "Your `horizon` setting ({}) is larger than the Env's own "
+                "timestep limit ({})! Try to increase the Env's limit via "
+                "setting its `spec.max_episode_steps` property.".format(
+                    horizon, max_episode_steps))
+    # Otherwise, set Trainer's horizon to env's max-steps.
+    elif max_episode_steps:
+        horizon = max_episode_steps
+        logger.debug(
+            "No episode horizon specified, setting it to Env's limit ({}).".
+            format(max_episode_steps))
+    else:
         horizon = float("inf")
+        logger.debug("No episode horizon specified, assuming inf.")
 
     # Pool of batch builders, which can be shared across episodes to pack
     # trajectory data.
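The new ValueError advises raising the env's `spec.max_episode_steps` when the configured `horizon` exceeds it. With the gym versions of this era, one way to do that was to register a variant of the env under a new id with a longer time limit; a minimal sketch (the id `CartPoleLong-v0` is made up for illustration):

```python
import gym
from gym.envs.registration import register

# Register a CartPole variant whose TimeLimit allows up to 1000 steps,
# so a trainer `horizon` of up to 1000 no longer trips the new check.
register(
    id="CartPoleLong-v0",
    entry_point="gym.envs.classic_control:CartPoleEnv",
    max_episode_steps=1000,
)

env = gym.make("CartPoleLong-v0")
assert env.spec.max_episode_steps == 1000
```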
@@ -8,12 +8,13 @@ import unittest
 import ray
 from ray.rllib.agents.pg import PGTrainer
 from ray.rllib.agents.a3c import A2CTrainer
+from ray.rllib.env.vector_env import VectorEnv
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.evaluation.metrics import collect_metrics
+from ray.rllib.policy.tests.test_policy import TestPolicy
 from ray.rllib.evaluation.postprocessing import compute_advantages
-from ray.rllib.policy.tests.test_policy import TestPolicy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
-from ray.rllib.env.vector_env import VectorEnv
+from ray.rllib.utils.test_utils import check
 from ray.tune.registry import register_env
 
 
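Aside from reordering the `VectorEnv` and `TestPolicy` imports, the only net-new import is `check` from `ray.rllib.utils.test_utils`, used below to compare the sampled `dones` against an expected list. A minimal usage sketch, assuming the ray-0.8-era behavior of raising on mismatch:

```python
from ray.rllib.utils.test_utils import check

# Passes silently: both (possibly nested) structures match element-wise.
check([False, True, False], [False, True, False])

# A mismatch is expected to raise (an AssertionError in this era).
try:
    check([False, True], [True, True])
except AssertionError:
    print("mismatch detected")
```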
@@ -261,18 +262,42 @@ class TestRolloutWorker(unittest.TestCase):
 
     def test_hard_horizon(self):
         ev = RolloutWorker(
-            env_creator=lambda _: MockEnv(episode_length=10),
+            env_creator=lambda _: MockEnv2(episode_length=10),
             policy=MockPolicy,
             batch_mode="complete_episodes",
             batch_steps=10,
             episode_horizon=4,
             soft_horizon=False)
         samples = ev.sample()
-        # three logical episodes
+        # Three logical episodes and correct episode resets (always after 4
+        # steps).
         self.assertEqual(len(set(samples["eps_id"])), 3)
-        # 3 done values
+        for i in range(4):
+            self.assertEqual(np.argmax(samples["obs"][i]), i)
+        self.assertEqual(np.argmax(samples["obs"][4]), 0)
+        # 3 done values.
         self.assertEqual(sum(samples["dones"]), 3)
+
+        # A gym env's max_episode_steps is larger than Trainer's horizon.
+        ev = RolloutWorker(
+            env_creator=lambda _: gym.make("CartPole-v0"),
+            policy=MockPolicy,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=6,
+            soft_horizon=False)
+        samples = ev.sample()
+        # 12 steps due to `complete_episodes` batch_mode.
+        self.assertEqual(len(samples["eps_id"]), 12)
+        # Two logical episodes and correct episode resets (always after 6(!)
+        # steps).
+        self.assertEqual(len(set(samples["eps_id"])), 2)
+        # 2 done values after 6 and 12 steps.
+        check(samples["dones"], [
+            False, False, False, False, False, True, False, False, False,
+            False, False, True
+        ])
 
     def test_soft_horizon(self):
         ev = RolloutWorker(
             env_creator=lambda _: MockEnv(episode_length=10),
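In the first test, the new assertions `np.argmax(samples["obs"][i]) == i` only make sense if `MockEnv2` emits its step counter as a `Discrete` observation, which RLlib one-hot encodes in sample batches; `argmax` then recovers the step index, which drops back to 0 at the horizon-triggered reset after step 4. A sketch of such an env under that assumption (not necessarily RLlib's actual `MockEnv2`):

```python
import gym


class StepCounterEnv(gym.Env):
    """Hypothetical MockEnv2-like env whose observation is the step count."""

    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        # Discrete obs get one-hot encoded by RLlib's preprocessors, so
        # np.argmax(obs) recovers the raw step counter.
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return self.i

    def step(self, action):
        self.i += 1
        return self.i, 1.0, self.i >= self.episode_length, {}
```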