Mirror of https://github.com/vale981/ray, synced 2025-03-06 10:31:39 -05:00
[RLlib] Trajectory view API: enable by default for ES and ARS (#11826)

commit 0bd69edd71 (parent 6e6c680f14)

5 changed files with 18 additions and 10 deletions

@@ -47,6 +47,10 @@ DEFAULT_CONFIG = with_common_config({
         "num_envs_per_worker": 1,
         "observation_filter": "NoFilter"
     },
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
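
With this change, ARS collects samples through the trajectory view API by default. A minimal usage sketch, not taken from the diff: the module path, trainer class, and example environment are assumptions about the RLlib version this commit targets.

from ray.rllib.agents import ars

# Build an ARS trainer with the new default in place ("CartPole-v0" is only
# an example environment; any registered env would do).
config = ars.DEFAULT_CONFIG.copy()
assert config["_use_trajectory_view_api"] is True  # new default from this commit
trainer = ars.ARSTrainer(config=config, env="CartPole-v0")
print(trainer.train()["episode_reward_mean"])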

@@ -327,7 +331,7 @@ class ARSTrainer(Trainer):

     @override(Trainer)
     def compute_action(self, observation, *args, **kwargs):
-        action, _, _ = self.policy.compute_actions(observation, update=True)
+        action, _, _ = self.policy.compute_actions([observation], update=True)
         if kwargs.get("full_fetch"):
             return action[0], [], {}
         return action[0]
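
The Trainer-level API is unchanged (one observation in, one action out); the trainer now simply wraps the observation in a one-element list before handing it to the policy. A hedged usage sketch, with environment and trainer construction assumed rather than taken from the diff:

import gym
from ray.rllib.agents import ars

trainer = ars.ARSTrainer(config=ars.DEFAULT_CONFIG.copy(), env="CartPole-v0")
env = gym.make("CartPole-v0")

obs = env.reset()
# Public API: single observation in, single action out. Internally this now
# calls self.policy.compute_actions([obs], update=True), as shown in the diff.
action = trainer.compute_action(obs)
obs, reward, done, info = env.step(action)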

@@ -68,9 +68,9 @@ class ARSTFPolicy(Policy):
                         add_noise=False,
                         update=True,
                         **kwargs):
-        # Batch is given as list of one.
-        if isinstance(observation, list) and len(observation) == 1:
-            observation = observation[0]
+        # Squeeze batch dimension (we always calculate actions for only a
+        # single obs).
+        observation = observation[0]
         observation = self.preprocessor.transform(observation)
         observation = self.observation_filter(observation[None], update=update)

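
Both TF policies now assume the caller always passes a one-element list and squeeze it back to a single observation unconditionally. A minimal standalone illustration of that convention (plain NumPy, not RLlib code; preprocessing and filtering are only hinted at):

import numpy as np

def compute_actions_sketch(observation):
    # Callers (Trainer.compute_action and rollout()) now always pass [obs]:
    # squeeze the batch dimension back down to a single observation ...
    obs = observation[0]
    # ... then, after preprocessing/filtering, re-add the batch axis for the
    # model call -- the `observation[None]` seen in the diff's context lines.
    model_input = np.asarray(obs, dtype=np.float32)[None]
    return model_input

print(compute_actions_sketch([np.zeros(4)]).shape)  # -> (1, 4)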

@@ -45,6 +45,10 @@ DEFAULT_CONFIG = with_common_config({
         "num_envs_per_worker": 1,
         "observation_filter": "NoFilter"
     },
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
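
The same flag is added to the ES default config. A hedged sketch of running ES through Tune with the new default; the stopping criterion and environment are illustrative rather than taken from the diff:

from ray import tune

tune.run(
    "ES",  # registered RLlib trainer name
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        "num_workers": 2,
        # New default from this commit, shown explicitly here for clarity.
        "_use_trajectory_view_api": True,
    },
)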

@@ -324,7 +328,7 @@ class ESTrainer(Trainer):

     @override(Trainer)
     def compute_action(self, observation, *args, **kwargs):
-        action, _, _ = self.policy.compute_actions(observation, update=False)
+        action, _, _ = self.policy.compute_actions([observation], update=False)
         if kwargs.get("full_fetch"):
             return action[0], [], {}
         return action[0]

@@ -47,7 +47,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0.0):
     observation = env.reset()
     for _ in range(timestep_limit or max_timestep_limit):
         ac, _, _ = policy.compute_actions(
-            observation, add_noise=add_noise, update=True)
+            [observation], add_noise=add_noise, update=True)
         ac = ac[0]
         observation, r, done, _ = env.step(ac)
         if offset != 0.0:

@@ -118,9 +118,9 @@ class ESTFPolicy(Policy):
                         add_noise=False,
                         update=True,
                         **kwargs):
-        # Batch is given as list of one.
-        if isinstance(observation, list) and len(observation) == 1:
-            observation = observation[0]
+        # Squeeze batch dimension (we always calculate actions for only a
+        # single obs).
+        observation = observation[0]
         observation = self.preprocessor.transform(observation)
         observation = self.observation_filter(observation[None], update=update)
         # `actions` is a list of (component) batches.

@@ -18,7 +18,7 @@ class TestES(unittest.TestCase):
         config["num_workers"] = 1
         config["episodes_per_batch"] = 10
         config["train_batch_size"] = 100
-        # Test eval workers ("normal" Trainer eval WorkerSet, unusual for ARS).
+        # Test eval workers ("normal" Trainer eval WorkerSet).
         config["evaluation_interval"] = 1
         config["evaluation_num_workers"] = 2
