diff --git a/rllib/agents/ars/ars.py b/rllib/agents/ars/ars.py
index ef74f1a3e..ef2195213 100644
--- a/rllib/agents/ars/ars.py
+++ b/rllib/agents/ars/ars.py
@@ -47,6 +47,10 @@ DEFAULT_CONFIG = with_common_config({
         "num_envs_per_worker": 1,
         "observation_filter": "NoFilter"
     },
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -327,7 +331,7 @@ class ARSTrainer(Trainer):
 
     @override(Trainer)
     def compute_action(self, observation, *args, **kwargs):
-        action, _, _ = self.policy.compute_actions(observation, update=True)
+        action, _, _ = self.policy.compute_actions([observation], update=True)
         if kwargs.get("full_fetch"):
             return action[0], [], {}
         return action[0]
diff --git a/rllib/agents/ars/ars_tf_policy.py b/rllib/agents/ars/ars_tf_policy.py
index b482449d8..9adcc1561 100644
--- a/rllib/agents/ars/ars_tf_policy.py
+++ b/rllib/agents/ars/ars_tf_policy.py
@@ -68,9 +68,9 @@ class ARSTFPolicy(Policy):
                         add_noise=False,
                         update=True,
                         **kwargs):
-        # Batch is given as list of one.
-        if isinstance(observation, list) and len(observation) == 1:
-            observation = observation[0]
+        # Squeeze batch dimension (we always calculate actions for only a
+        # single obs).
+        observation = observation[0]
         observation = self.preprocessor.transform(observation)
         observation = self.observation_filter(observation[None], update=update)
 
diff --git a/rllib/agents/es/es.py b/rllib/agents/es/es.py
index d669bd5d0..065f141f5 100644
--- a/rllib/agents/es/es.py
+++ b/rllib/agents/es/es.py
@@ -45,6 +45,10 @@ DEFAULT_CONFIG = with_common_config({
         "num_envs_per_worker": 1,
         "observation_filter": "NoFilter"
     },
+
+    # Use the new "trajectory view API" to collect samples and produce
+    # model- and policy inputs.
+    "_use_trajectory_view_api": True,
 })
 # __sphinx_doc_end__
 # yapf: enable
@@ -324,7 +328,7 @@ class ESTrainer(Trainer):
 
     @override(Trainer)
     def compute_action(self, observation, *args, **kwargs):
-        action, _, _ = self.policy.compute_actions(observation, update=False)
+        action, _, _ = self.policy.compute_actions([observation], update=False)
         if kwargs.get("full_fetch"):
             return action[0], [], {}
         return action[0]
diff --git a/rllib/agents/es/es_tf_policy.py b/rllib/agents/es/es_tf_policy.py
index ad19b7ba5..6c2bf206b 100644
--- a/rllib/agents/es/es_tf_policy.py
+++ b/rllib/agents/es/es_tf_policy.py
@@ -47,7 +47,7 @@ def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0.0):
     observation = env.reset()
     for _ in range(timestep_limit or max_timestep_limit):
         ac, _, _ = policy.compute_actions(
-            observation, add_noise=add_noise, update=True)
+            [observation], add_noise=add_noise, update=True)
         ac = ac[0]
         observation, r, done, _ = env.step(ac)
         if offset != 0.0:
@@ -118,9 +118,9 @@ class ESTFPolicy(Policy):
                         add_noise=False,
                         update=True,
                         **kwargs):
-        # Batch is given as list of one.
-        if isinstance(observation, list) and len(observation) == 1:
-            observation = observation[0]
+        # Squeeze batch dimension (we always calculate actions for only a
+        # single obs).
+        observation = observation[0]
         observation = self.preprocessor.transform(observation)
         observation = self.observation_filter(observation[None], update=update)
         # `actions` is a list of (component) batches.
diff --git a/rllib/agents/es/tests/test_es.py b/rllib/agents/es/tests/test_es.py
index c29761193..94d8b0f74 100644
--- a/rllib/agents/es/tests/test_es.py
+++ b/rllib/agents/es/tests/test_es.py
@@ -18,7 +18,7 @@ class TestES(unittest.TestCase):
         config["num_workers"] = 1
         config["episodes_per_batch"] = 10
         config["train_batch_size"] = 100
-        # Test eval workers ("normal" Trainer eval WorkerSet, unusual for ARS).
+        # Test eval workers ("normal" Trainer eval WorkerSet).
         config["evaluation_interval"] = 1
         config["evaluation_num_workers"] = 2
 
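
For illustration (not part of the patch above): the calling convention this change settles on is that Trainer.compute_action() receives one unbatched observation, wraps it into a one-element list before calling Policy.compute_actions(), and the policy squeezes that batch dimension back out via observation[0]. Below is a minimal, self-contained sketch of that convention; DummyPolicy and its tanh "inference" are placeholders for illustration only, not RLlib code.

import numpy as np


class DummyPolicy:
    """Stand-in for ARSTFPolicy/ESTFPolicy (illustration only)."""

    def compute_actions(self, observation, update=True, **kwargs):
        # Squeeze the batch dimension: callers always pass a list of one obs.
        single_obs = np.asarray(observation[0], dtype=np.float32)
        # Fake "inference": return a batched action plus empty state-outs and
        # infos, mirroring the (actions, state_outs, infos) return convention.
        action_batch = np.tanh(single_obs)[None]
        return action_batch, [], {}


def compute_action(policy, observation):
    """Mirrors ARSTrainer/ESTrainer.compute_action after this change."""
    # Wrap the single observation into a batch (list) of one ...
    action, _, _ = policy.compute_actions([observation], update=True)
    # ... and return the first (and only) action of the returned batch.
    return action[0]


if __name__ == "__main__":
    obs = np.array([0.1, -0.2, 0.3, 0.0])
    print(compute_action(DummyPolicy(), obs))

Keeping the list-wrapping on the Trainer side lets the policies drop the isinstance() check and treat their input as always batched, which is consistent with the batched inputs produced when "_use_trajectory_view_api" is enabled in the ARS/ES default configs.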