[rllib] Add option for RNN state and value estimates to span episodes (#4429)

* wip soft horizon

* tests
Eric Liang 2019-04-02 02:44:15 -07:00 committed by GitHub
parent c2c548bdfd
commit 55a2d39409
5 changed files with 80 additions and 13 deletions

View file

@@ -70,8 +70,13 @@ COMMON_CONFIG = {
# === Environment ===
# Discount factor of the MDP
"gamma": 0.99,
# Number of steps after which the episode is forced to terminate
# Number of steps after which the episode is forced to terminate. Defaults
# to `env.spec.max_episode_steps` (if present) for Gym envs.
"horizon": None,
# Calculate rewards but don't reset the environment when the horizon is
# hit. This allows value estimation and RNN state to span across logical
# episodes denoted by horizon. This only has an effect if horizon != inf.
"soft_horizon": False,
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
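For context, a minimal usage sketch of the two options together (illustrative only, not part of this change; the PPO agent class and environment name are arbitrary choices, and the call assumes the Agent constructor API of this RLlib version):

    # Illustrative sketch; assumes the Agent API of this RLlib version.
    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(
        env="CartPole-v0",
        config={
            # Force a logical episode boundary every 200 steps...
            "horizon": 200,
            # ...but keep the underlying env running and carry RNN state
            # and value estimation across those boundaries.
            "soft_horizon": True,
        })
    result = agent.train()
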
@@ -746,6 +751,7 @@ class Agent(Trainable):
output_creator=output_creator,
remote_worker_envs=config["remote_worker_envs"],
remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
soft_horizon=config["soft_horizon"],
_fake_sampler=config.get("_fake_sampler", False))
@override(Trainable)

View file

@@ -65,6 +65,19 @@ class MultiAgentEpisode(object):
self._agent_to_prev_action = {}
self._agent_reward_history = defaultdict(list)
@DeveloperAPI
def soft_reset(self):
"""Clears rewards and metrics, but retains RNN and other state.
This is used to carry state across multiple logical episodes in the
same env (i.e., if `soft_horizon` is set).
"""
self.length = 0
self.episode_id = random.randrange(2e9)
self.total_reward = 0.0
self.agent_rewards = defaultdict(float)
self._agent_reward_history = defaultdict(list)
@DeveloperAPI
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the policy graph for the specified agent.

View file

@@ -125,6 +125,7 @@ class PolicyEvaluator(EvaluatorInterface):
output_creator=lambda ioctx: NoopOutput(),
remote_worker_envs=False,
remote_env_batch_wait_ms=0,
soft_horizon=False,
_fake_sampler=False):
"""Initialize a policy evaluator.
@@ -208,6 +209,8 @@ class PolicyEvaluator(EvaluatorInterface):
least one env is ready) is a reasonable default, but optimal
value could be obtained by measuring your environment
step / reset and model inference perf.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
"""
@@ -372,7 +375,8 @@ class PolicyEvaluator(EvaluatorInterface):
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions,
blackhole_outputs="simulation" in input_evaluation)
blackhole_outputs="simulation" in input_evaluation,
soft_horizon=soft_horizon)
self.sampler.start()
else:
self.sampler = SyncSampler(
@@ -387,7 +391,8 @@ class PolicyEvaluator(EvaluatorInterface):
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions)
clip_actions=clip_actions,
soft_horizon=soft_horizon)
self.input_reader = input_creator(self.io_context)
assert isinstance(self.input_reader, InputReader), self.input_reader
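For orientation, a standalone PolicyEvaluator can be constructed with the new flag roughly as follows (a sketch that mirrors the tests at the bottom of this diff; the gym environment and the PGPolicyGraph import path are illustrative assumptions):

    import gym
    from ray.rllib.evaluation import PolicyEvaluator
    from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph

    ev = PolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=PGPolicyGraph,
        batch_mode="complete_episodes",
        episode_horizon=100,   # logical episode boundary every 100 steps
        soft_horizon=True)     # ...without resetting the underlying env
    batch = ev.sample()        # eps_id changes at each soft boundary; dones stay False there
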

View file

@@ -78,7 +78,8 @@ class SyncSampler(SamplerInput):
horizon=None,
pack=False,
tf_sess=None,
clip_actions=True):
clip_actions=True,
soft_horizon=False):
self.base_env = BaseEnv.to_base_env(env)
self.unroll_length = unroll_length
self.horizon = horizon
@@ -92,7 +93,7 @@ class SyncSampler(SamplerInput):
self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess, self.perf_stats)
pack, callbacks, tf_sess, self.perf_stats, soft_horizon)
self.metrics_queue = queue.Queue()
def get_data(self):
@@ -137,7 +138,8 @@ class AsyncSampler(threading.Thread, SamplerInput):
pack=False,
tf_sess=None,
clip_actions=True,
blackhole_outputs=False):
blackhole_outputs=False,
soft_horizon=False):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@@ -159,6 +161,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.callbacks = callbacks
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.soft_horizon = soft_horizon
self.perf_stats = PerfStats()
self.shutdown = False
@@ -182,7 +185,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess,
self.perf_stats)
self.perf_stats, self.soft_horizon)
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
@@ -227,7 +230,7 @@ class AsyncSampler(threading.Thread, SamplerInput):
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
unroll_length, horizon, preprocessors, obs_filters,
clip_rewards, clip_actions, pack, callbacks, tf_sess,
perf_stats):
perf_stats, soft_horizon):
"""This implements the common experience collection logic.
Args:
@@ -252,6 +255,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
tf_sess (Session|None): Optional tensorflow session to use for batching
TF policy evaluations.
perf_stats (PerfStats): Record perf stats into this object.
soft_horizon (bool): Calculate rewards but don't reset the
environment when the horizon is hit.
Yields:
rollout (SampleBatch): Object containing state, action, reward,
@@ -307,7 +312,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
active_envs, to_eval, outputs = _process_observations(
base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, unroll_length, pack, callbacks)
preprocessors, obs_filters, unroll_length, pack, callbacks,
soft_horizon)
perf_stats.processing_time += time.time() - t1
for o in outputs:
yield o
@@ -335,7 +341,8 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
def _process_observations(base_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, preprocessors,
obs_filters, unroll_length, pack, callbacks):
obs_filters, unroll_length, pack, callbacks,
soft_horizon):
"""Record new data from the environment and prepare for policy evaluation.
Returns:
@@ -372,6 +379,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
hit_horizon = (episode.length >= horizon
and not dones[env_id]["__all__"])
all_done = True
atari_metrics = _fetch_atari_metrics(base_env)
if atari_metrics is not None:
@@ -384,6 +393,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
dict(episode.agent_rewards),
episode.custom_metrics, {}))
else:
hit_horizon = False
all_done = False
active_envs.add(env_id)
@@ -427,7 +437,8 @@ def _process_observations(base_env, policies, batch_builder_pool,
rewards=rewards[env_id][agent_id],
prev_actions=episode.prev_action_for(agent_id),
prev_rewards=episode.prev_reward_for(agent_id),
dones=agent_done,
dones=(False
if (hit_horizon and soft_horizon) else agent_done),
infos=infos[env_id].get(agent_id, {}),
new_obs=filtered_obs,
**episode.last_pi_info_for(agent_id))
@@ -457,8 +468,12 @@ def _process_observations(base_env, policies, batch_builder_pool,
"policy": policies,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = base_env.try_reset(env_id)
if hit_horizon and soft_horizon:
episode.soft_reset()
resetted_obs = agent_obs
else:
del active_episodes[env_id]
resetted_obs = base_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
if horizon != float("inf"):
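The reason the dones value is forced to False at a soft boundary shows up downstream: advantage and return estimators typically bootstrap from the value function only when a trajectory did not terminate, so a soft horizon keeps value estimates continuous across logical episodes. A toy illustration of that pattern (not RLlib's actual postprocessing code):

    import numpy as np

    def bootstrapped_returns(rewards, last_value, done, gamma=0.99):
        """Toy discounted-return calculation showing why `dones` matters.

        done=True (real termination or hard horizon): nothing follows the
        last step, so bootstrap from 0.  done=False (soft horizon): bootstrap
        from the value estimate of the next state, so returns span the
        logical episode boundary instead of being truncated at it.
        """
        returns = np.zeros(len(rewards), dtype=np.float32)
        running = 0.0 if done else last_value
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        return returns

    # e.g. a 3-step fragment ending at a soft horizon with V(s_next) = 5.0:
    # bootstrapped_returns([1.0, 1.0, 1.0], last_value=5.0, done=False)
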

View file

@@ -249,6 +249,34 @@ class TestPolicyEvaluator(unittest.TestCase):
result2 = collect_metrics(ev2, [])
self.assertEqual(result2["episode_reward_mean"], 1000)
def testHardHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
soft_horizon=False)
samples = ev.sample()
# three logical episodes
self.assertEqual(len(set(samples["eps_id"])), 3)
# 3 done values
self.assertEqual(sum(samples["dones"]), 3)
def testSoftHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
soft_horizon=True)
samples = ev.sample()
# three logical episodes
self.assertEqual(len(set(samples["eps_id"])), 3)
# only 1 hard done value
self.assertEqual(sum(samples["dones"]), 1)
def testMetrics(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),