[RLlib] Trajectory view API: enable by default for ES and ARS (#11826)

Sven Mika 2020-11-12 19:33:10 +01:00 committed by GitHub
parent 6e6c680f14
commit 0bd69edd71
5 changed files with 18 additions and 10 deletions

@@ -47,6 +47,10 @@ DEFAULT_CONFIG = with_common_config({
"num_envs_per_worker": 1,
"observation_filter": "NoFilter"
},
# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
# yapf: enable
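
For context, a minimal sketch of overriding this flag when building an ARS trainer. The import path, environment name, and the other config key are assumptions based on RLlib's layout around this release, not part of the diff:

```python
# Hedged sketch: overriding the new default via the trainer config.
# Import path, env name, and "num_workers" value are assumptions.
import ray
from ray.rllib.agents.ars import ARSTrainer

ray.init()
config = {
    "num_workers": 2,
    # Now defaults to True for ARS; set to False to fall back to the
    # old sample-collection code path.
    "_use_trajectory_view_api": True,
}
trainer = ARSTrainer(config=config, env="CartPole-v0")
print(trainer.train())
```
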
@@ -327,7 +331,7 @@ class ARSTrainer(Trainer):
@override(Trainer)
def compute_action(self, observation, *args, **kwargs):
action, _, _ = self.policy.compute_actions(observation, update=True)
action, _, _ = self.policy.compute_actions([observation], update=True)
if kwargs.get("full_fetch"):
return action[0], [], {}
return action[0]

@@ -68,9 +68,9 @@ class ARSTFPolicy(Policy):
add_noise=False,
update=True,
**kwargs):
# Batch is given as list of one.
if isinstance(observation, list) and len(observation) == 1:
observation = observation[0]
# Squeeze batch dimension (we always calculate actions for only a
# single obs).
observation = observation[0]
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
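
Together with the trainer-side change above, the convention is now that callers always pass a batch (a list holding one observation) and the policy squeezes it unconditionally instead of checking isinstance. A small sketch of that calling convention, assuming `env` and `policy` objects already exist:

```python
# Sketch of the batch-of-one convention; `env` and `policy` are assumed.
obs = env.reset()
# Callers always wrap the single observation into a batch ...
actions, _, _ = policy.compute_actions([obs], update=True)
# ... and index into the returned batch to recover the single action.
action = actions[0]
obs, reward, done, info = env.step(action)
```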

@@ -45,6 +45,10 @@ DEFAULT_CONFIG = with_common_config({
"num_envs_per_worker": 1,
"observation_filter": "NoFilter"
},
# Use the new "trajectory view API" to collect samples and produce
# model- and policy inputs.
"_use_trajectory_view_api": True,
})
# __sphinx_doc_end__
# yapf: enable
@@ -324,7 +328,7 @@ class ESTrainer(Trainer):
@override(Trainer)
def compute_action(self, observation, *args, **kwargs):
action, _, _ = self.policy.compute_actions(observation, update=False)
action, _, _ = self.policy.compute_actions([observation], update=False)
if kwargs.get("full_fetch"):
return action[0], [], {}
return action[0]

@@ -47,7 +47,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0.0):
observation = env.reset()
for _ in range(timestep_limit or max_timestep_limit):
ac, _, _ = policy.compute_actions(
observation, add_noise=add_noise, update=True)
[observation], add_noise=add_noise, update=True)
ac = ac[0]
observation, r, done, _ = env.step(ac)
if offset != 0.0:
@@ -118,9 +118,9 @@ class ESTFPolicy(Policy):
add_noise=False,
update=True,
**kwargs):
# Batch is given as list of one.
if isinstance(observation, list) and len(observation) == 1:
observation = observation[0]
# Squeeze batch dimension (we always calculate actions for only a
# single obs).
observation = observation[0]
observation = self.preprocessor.transform(observation)
observation = self.observation_filter(observation[None], update=update)
# `actions` is a list of (component) batches.
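
Inside compute_actions the shape handling is: squeeze the incoming batch of one, preprocess the single observation, then re-add the batch dimension via observation[None] before filtering. A stand-in NumPy sketch of just that shape juggling (the real preprocessor and observation filter are RLlib objects, reduced to trivial ops here):

```python
import numpy as np

# Stand-in for the shape handling above; the real preprocessor and
# observation filter are RLlib objects, simplified to identity-like ops.
observations = [np.zeros(4, dtype=np.float32)]  # batch given as a list of one
obs = observations[0]           # squeeze the batch dimension
obs = obs.astype(np.float32)    # stand-in for self.preprocessor.transform(obs)
batch = obs[None]               # re-add the batch dim -> shape (1, 4)
assert batch.shape == (1, 4)
```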

@@ -18,7 +18,7 @@ class TestES(unittest.TestCase):
config["num_workers"] = 1
config["episodes_per_batch"] = 10
config["train_batch_size"] = 100
# Test eval workers ("normal" Trainer eval WorkerSet, unusual for ARS).
# Test eval workers ("normal" Trainer eval WorkerSet).
config["evaluation_interval"] = 1
config["evaluation_num_workers"] = 2