[RLlib] Issue 7136: rollout not working for ES and ARS. (#7444)

* Fix.

* Fix issue #7136.

* ARS fix.
This commit is contained in:
Eric Liang 2020-03-04 23:57:44 -08:00 committed by GitHub
parent 476b5c6196
commit 1989eed3bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 11 deletions

View file

@@ -165,8 +165,7 @@ class ARSTrainer(Trainer):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"ARS does not support PyTorch yet! Use tf instead."
)
"ARS does not support PyTorch yet! Use tf instead.")
env = env_creator(config["env_config"])
from ray.rllib import models
@@ -301,7 +300,7 @@ class ARSTrainer(Trainer):
w.__ray_terminate__.remote()
@override(Trainer)
def compute_action(self, observation):
def compute_action(self, observation, *args, **kwargs):
return self.policy.compute(observation, update=True)[0]
def _collect_results(self, theta_id, min_episodes):

View file

@@ -171,8 +171,7 @@ class ESTrainer(Trainer):
# PyTorch check.
if config["use_pytorch"]:
raise ValueError(
"ES does not support PyTorch yet! Use tf instead."
)
"ES does not support PyTorch yet! Use tf instead.")
policy_params = {"action_noise_std": 0.01}
@@ -292,7 +291,7 @@ class ESTrainer(Trainer):
return result
@override(Trainer)
def compute_action(self, observation):
def compute_action(self, observation, *args, **kwargs):
return self.policy.compute(observation, update=False)[0]
@override(Trainer)

View file

@@ -15,6 +15,7 @@ from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import _flatten_action
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import deprecation_warning
from ray.tune.utils import merge_dicts
@@ -339,7 +340,7 @@ def rollout(agent,
if saver is None:
saver = RolloutSaver()
if hasattr(agent, "workers"):
if hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet):
env = agent.workers.local_worker().env
multiagent = isinstance(env, MultiAgentEnv)
if agent.workers.local_worker().multiagent:
@@ -349,15 +350,22 @@
policy_map = agent.workers.local_worker().policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
action_init = {
p: _flatten_action(m.action_space.sample())
for p, m in policy_map.items()
}
else:
env = gym.make(env_name)
multiagent = False
try:
policy_map = {DEFAULT_POLICY_ID: agent.policy}
except AttributeError:
raise AttributeError(
"Agent ({}) does not have a `policy` property! This is needed "
"for performing (trained) agent rollouts.".format(agent))
use_lstm = {DEFAULT_POLICY_ID: False}
action_init = {
p: _flatten_action(m.action_space.sample())
for p, m in policy_map.items()
}
# If monitoring has been requested, manually wrap our environment with a
# gym monitor, which is set to record every episode.
if video_dir: