[rllib] Add option for RNN state and value estimates to span episodes (#4429)

* wip soft horizon
* tests

parent c2c548bdfd
commit 55a2d39409

5 changed files with 80 additions and 13 deletions
@@ -70,8 +70,13 @@ COMMON_CONFIG = {
     # === Environment ===
     # Discount factor of the MDP
     "gamma": 0.99,
-    # Number of steps after which the episode is forced to terminate
+    # Number of steps after which the episode is forced to terminate. Defaults
+    # to `env.spec.max_episode_steps` (if present) for Gym envs.
     "horizon": None,
+    # Calculate rewards but don't reset the environment when the horizon is
+    # hit. This allows value estimation and RNN state to span across logical
+    # episodes denoted by horizon. This only has an effect if horizon != inf.
+    "soft_horizon": False,
     # Arguments to pass to the env creator
     "env_config": {},
     # Environment name can also be passed via config
@@ -746,6 +751,7 @@ class Agent(Trainable):
             output_creator=output_creator,
             remote_worker_envs=config["remote_worker_envs"],
             remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
+            soft_horizon=config["soft_horizon"],
             _fake_sampler=config.get("_fake_sampler", False))
 
     @override(Trainable)
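
For context, a minimal usage sketch of the new `soft_horizon` option added to COMMON_CONFIG above. This is hedged: the PPOAgent class, import path, and env name are illustrative assumptions about the surrounding RLlib API of this era, not part of the commit.

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(
        env="CartPole-v0",
        config={
            # Force a logical episode boundary every 100 steps...
            "horizon": 100,
            # ...but don't reset the env there, so value estimation and RNN
            # state can span across the logical episodes.
            "soft_horizon": True,
        })
    result = agent.train()
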
@@ -65,6 +65,19 @@ class MultiAgentEpisode(object):
         self._agent_to_prev_action = {}
         self._agent_reward_history = defaultdict(list)
 
+    @DeveloperAPI
+    def soft_reset(self):
+        """Clears rewards and metrics, but retains RNN and other state.
+
+        This is used to carry state across multiple logical episodes in the
+        same env (i.e., if `soft_horizon` is set).
+        """
+        self.length = 0
+        self.episode_id = random.randrange(2e9)
+        self.total_reward = 0.0
+        self.agent_rewards = defaultdict(float)
+        self._agent_reward_history = defaultdict(list)
+
     @DeveloperAPI
     def policy_for(self, agent_id=_DUMMY_AGENT_ID):
         """Returns the policy graph for the specified agent.
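
The key property of `soft_reset()` above is what it does not touch. A standalone sketch of that semantics on a toy episode object (illustrative only, not the real MultiAgentEpisode): reward and length bookkeeping are cleared and a fresh episode id is drawn, while fields holding RNN or user state are deliberately left alone.

    import random
    from collections import defaultdict

    class ToyEpisode(object):
        def __init__(self):
            self.length = 0
            self.episode_id = random.randrange(int(2e9))
            self.total_reward = 0.0
            self.agent_rewards = defaultdict(float)
            self._agent_reward_history = defaultdict(list)
            self.rnn_state = {}  # survives a soft reset
            self.user_data = {}  # survives a soft reset

        def soft_reset(self):
            # Clear only the per-episode reward/length accounting; keep
            # rnn_state and user_data so they carry into the next logical
            # episode.
            self.length = 0
            self.episode_id = random.randrange(int(2e9))
            self.total_reward = 0.0
            self.agent_rewards = defaultdict(float)
            self._agent_reward_history = defaultdict(list)
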
@@ -125,6 +125,7 @@ class PolicyEvaluator(EvaluatorInterface):
                  output_creator=lambda ioctx: NoopOutput(),
                  remote_worker_envs=False,
                  remote_env_batch_wait_ms=0,
+                 soft_horizon=False,
                  _fake_sampler=False):
         """Initialize a policy evaluator.
 
@@ -208,6 +209,8 @@ class PolicyEvaluator(EvaluatorInterface):
                 least one env is ready) is a reasonable default, but optimal
                 value could be obtained by measuring your environment
                 step / reset and model inference perf.
+            soft_horizon (bool): Calculate rewards but don't reset the
+                environment when the horizon is hit.
             _fake_sampler (bool): Use a fake (inf speed) sampler for testing.
         """
 
@@ -372,7 +375,8 @@ class PolicyEvaluator(EvaluatorInterface):
                 pack=pack_episodes,
                 tf_sess=self.tf_sess,
                 clip_actions=clip_actions,
-                blackhole_outputs="simulation" in input_evaluation)
+                blackhole_outputs="simulation" in input_evaluation,
+                soft_horizon=soft_horizon)
             self.sampler.start()
         else:
             self.sampler = SyncSampler(
@@ -387,7 +391,8 @@ class PolicyEvaluator(EvaluatorInterface):
                 horizon=episode_horizon,
                 pack=pack_episodes,
                 tf_sess=self.tf_sess,
-                clip_actions=clip_actions)
+                clip_actions=clip_actions,
+                soft_horizon=soft_horizon)
 
         self.input_reader = input_creator(self.io_context)
         assert isinstance(self.input_reader, InputReader), self.input_reader
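
A hedged sketch of passing the new flag to PolicyEvaluator directly (this mirrors the tests added at the end of this diff; the CartPole env, PGPolicyGraph, and import paths are assumptions about the RLlib layout of this era, not shown in the diff):

    import gym
    from ray.rllib.evaluation import PolicyEvaluator
    from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph

    ev = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=PGPolicyGraph,
        batch_mode="complete_episodes",
        batch_steps=100,
        episode_horizon=20,
        soft_horizon=True)
    batch = ev.sample()
    # With soft_horizon=True, hitting the 20-step horizon neither resets the
    # env nor records done=True in batch["dones"]; only real terminations do.
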
@@ -78,7 +78,8 @@ class SyncSampler(SamplerInput):
                  horizon=None,
                  pack=False,
                  tf_sess=None,
-                 clip_actions=True):
+                 clip_actions=True,
+                 soft_horizon=False):
         self.base_env = BaseEnv.to_base_env(env)
         self.unroll_length = unroll_length
         self.horizon = horizon
@@ -92,7 +93,7 @@ class SyncSampler(SamplerInput):
             self.base_env, self.extra_batches.put, self.policies,
             self.policy_mapping_fn, self.unroll_length, self.horizon,
             self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
-            pack, callbacks, tf_sess, self.perf_stats)
+            pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
         self.metrics_queue = queue.Queue()
 
     def get_data(self):
@@ -137,7 +138,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
                  pack=False,
                  tf_sess=None,
                  clip_actions=True,
-                 blackhole_outputs=False):
+                 blackhole_outputs=False,
+                 soft_horizon=False):
         for _, f in obs_filters.items():
             assert getattr(f, "is_concurrent", False), \
                 "Observation Filter must support concurrent updates."
@@ -159,6 +161,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
         self.callbacks = callbacks
         self.clip_actions = clip_actions
         self.blackhole_outputs = blackhole_outputs
+        self.soft_horizon = soft_horizon
         self.perf_stats = PerfStats()
         self.shutdown = False
 
@@ -182,7 +185,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
             self.policy_mapping_fn, self.unroll_length, self.horizon,
             self.preprocessors, self.obs_filters, self.clip_rewards,
             self.clip_actions, self.pack, self.callbacks, self.tf_sess,
-            self.perf_stats)
+            self.perf_stats, self.soft_horizon)
         while not self.shutdown:
             # The timeout variable exists because apparently, if one worker
             # dies, the other workers won't die with it, unless the timeout is
@@ -227,7 +230,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
 def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                 unroll_length, horizon, preprocessors, obs_filters,
                 clip_rewards, clip_actions, pack, callbacks, tf_sess,
-                perf_stats):
+                perf_stats, soft_horizon):
     """This implements the common experience collection logic.
 
     Args:
@@ -252,6 +255,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         tf_sess (Session|None): Optional tensorflow session to use for batching
             TF policy evaluations.
         perf_stats (PerfStats): Record perf stats into this object.
+        soft_horizon (bool): Calculate rewards but don't reset the
+            environment when the horizon is hit.
 
     Yields:
         rollout (SampleBatch): Object containing state, action, reward,
@@ -307,7 +312,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
         active_envs, to_eval, outputs = _process_observations(
             base_env, policies, batch_builder_pool, active_episodes,
             unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
-            preprocessors, obs_filters, unroll_length, pack, callbacks)
+            preprocessors, obs_filters, unroll_length, pack, callbacks,
+            soft_horizon)
         perf_stats.processing_time += time.time() - t1
         for o in outputs:
             yield o
@@ -335,7 +341,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
 def _process_observations(base_env, policies, batch_builder_pool,
                           active_episodes, unfiltered_obs, rewards, dones,
                           infos, off_policy_actions, horizon, preprocessors,
-                          obs_filters, unroll_length, pack, callbacks):
+                          obs_filters, unroll_length, pack, callbacks,
+                          soft_horizon):
     """Record new data from the environment and prepare for policy evaluation.
 
     Returns:
@@ -372,6 +379,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
 
         # Check episode termination conditions
         if dones[env_id]["__all__"] or episode.length >= horizon:
+            hit_horizon = (episode.length >= horizon
+                           and not dones[env_id]["__all__"])
             all_done = True
             atari_metrics = _fetch_atari_metrics(base_env)
             if atari_metrics is not None:
@@ -384,6 +393,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
                                    dict(episode.agent_rewards),
                                    episode.custom_metrics, {}))
         else:
+            hit_horizon = False
             all_done = False
             active_envs.add(env_id)
 
@@ -427,7 +437,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
                 rewards=rewards[env_id][agent_id],
                 prev_actions=episode.prev_action_for(agent_id),
                 prev_rewards=episode.prev_reward_for(agent_id),
-                dones=agent_done,
+                dones=(False
+                       if (hit_horizon and soft_horizon) else agent_done),
                 infos=infos[env_id].get(agent_id, {}),
                 new_obs=filtered_obs,
                 **episode.last_pi_info_for(agent_id))
@@ -457,8 +468,12 @@ def _process_observations(base_env, policies, batch_builder_pool,
                 "policy": policies,
                 "episode": episode
             })
-            del active_episodes[env_id]
-            resetted_obs = base_env.try_reset(env_id)
+            if hit_horizon and soft_horizon:
+                episode.soft_reset()
+                resetted_obs = agent_obs
+            else:
+                del active_episodes[env_id]
+                resetted_obs = base_env.try_reset(env_id)
             if resetted_obs is None:
                 # Reset not supported, drop this env from the ready list
                 if horizon != float("inf"):
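
The heart of the sampler change above is that the done flag recorded in the batch is decoupled from the decision to end the logical episode. A simplified standalone sketch of that decision (single-agent and illustrative; the real code works per agent via `agent_done`):

    def episode_done_flags(env_done, episode_length, horizon, soft_horizon):
        """Return (all_done, recorded_done) for one env step.

        all_done      -> the sampler treats the logical episode as over
                         (metrics flushed, rewards soft- or hard-reset).
        recorded_done -> the done written into the sample batch; under a soft
                         horizon this stays False so value bootstrapping and
                         RNN state carry across the boundary.
        """
        hit_horizon = episode_length >= horizon and not env_done
        all_done = env_done or episode_length >= horizon
        recorded_done = False if (hit_horizon and soft_horizon) else all_done
        return all_done, recorded_done

    # Soft horizon: the logical episode ends, but no terminal is recorded.
    assert episode_done_flags(False, 4, 4, True) == (True, False)
    # Hard horizon: the same step is recorded as terminal.
    assert episode_done_flags(False, 4, 4, False) == (True, True)
    # A genuine env termination is always recorded.
    assert episode_done_flags(True, 2, 4, True) == (True, True)

This is what the new tests below check: with a horizon of 4 over a 10-step mock env, the hard-horizon run records 3 dones, the soft-horizon run only the single real one.
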
@@ -249,6 +249,34 @@ class TestPolicyEvaluator(unittest.TestCase):
         result2 = collect_metrics(ev2, [])
         self.assertEqual(result2["episode_reward_mean"], 1000)
 
+    def testHardHorizon(self):
+        ev = PolicyEvaluator(
+            env_creator=lambda _: MockEnv(episode_length=10),
+            policy_graph=MockPolicyGraph,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=4,
+            soft_horizon=False)
+        samples = ev.sample()
+        # three logical episodes
+        self.assertEqual(len(set(samples["eps_id"])), 3)
+        # 3 done values
+        self.assertEqual(sum(samples["dones"]), 3)
+
+    def testSoftHorizon(self):
+        ev = PolicyEvaluator(
+            env_creator=lambda _: MockEnv(episode_length=10),
+            policy_graph=MockPolicyGraph,
+            batch_mode="complete_episodes",
+            batch_steps=10,
+            episode_horizon=4,
+            soft_horizon=True)
+        samples = ev.sample()
+        # three logical episodes
+        self.assertEqual(len(set(samples["eps_id"])), 3)
+        # only 1 hard done value
+        self.assertEqual(sum(samples["dones"]), 1)
+
     def testMetrics(self):
         ev = PolicyEvaluator(
             env_creator=lambda _: MockEnv(episode_length=10),