mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Issue 7136: rollout not working for ES and ARS. (#7444)
* Fix.
* Fix issue #7136.
* ARS fix.
This commit is contained in:
parent
476b5c6196
commit
1989eed3bf
3 changed files with 17 additions and 11 deletions
|
@@ -165,8 +165,7 @@ class ARSTrainer(Trainer):
|
|||
# PyTorch check.
|
||||
if config["use_pytorch"]:
|
||||
raise ValueError(
|
||||
"ARS does not support PyTorch yet! Use tf instead."
|
||||
)
|
||||
"ARS does not support PyTorch yet! Use tf instead.")
|
||||
|
||||
env = env_creator(config["env_config"])
|
||||
from ray.rllib import models
|
||||
|
@@ -301,7 +300,7 @@ class ARSTrainer(Trainer):
|
|||
w.__ray_terminate__.remote()
|
||||
|
||||
@override(Trainer)
|
||||
def compute_action(self, observation):
|
||||
def compute_action(self, observation, *args, **kwargs):
|
||||
return self.policy.compute(observation, update=True)[0]
|
||||
|
||||
def _collect_results(self, theta_id, min_episodes):
|
||||
|
|
|
@@ -171,8 +171,7 @@ class ESTrainer(Trainer):
|
|||
# PyTorch check.
|
||||
if config["use_pytorch"]:
|
||||
raise ValueError(
|
||||
"ES does not support PyTorch yet! Use tf instead."
|
||||
)
|
||||
"ES does not support PyTorch yet! Use tf instead.")
|
||||
|
||||
policy_params = {"action_noise_std": 0.01}
|
||||
|
||||
|
@@ -292,7 +291,7 @@ class ESTrainer(Trainer):
|
|||
return result
|
||||
|
||||
@override(Trainer)
|
||||
def compute_action(self, observation):
|
||||
def compute_action(self, observation, *args, **kwargs):
|
||||
return self.policy.compute(observation, update=False)[0]
|
||||
|
||||
@override(Trainer)
|
||||
|
|
|
@@ -15,6 +15,7 @@ from ray.rllib.agents.registry import get_agent_class
|
|||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.episode import _flatten_action
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.tune.utils import merge_dicts
|
||||
|
@@ -339,7 +340,7 @@ def rollout(agent,
|
|||
if saver is None:
|
||||
saver = RolloutSaver()
|
||||
|
||||
if hasattr(agent, "workers"):
|
||||
if hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet):
|
||||
env = agent.workers.local_worker().env
|
||||
multiagent = isinstance(env, MultiAgentEnv)
|
||||
if agent.workers.local_worker().multiagent:
|
||||
|
@@ -349,15 +350,22 @@ def rollout(agent,
|
|||
policy_map = agent.workers.local_worker().policy_map
|
||||
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
|
||||
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
|
||||
action_init = {
|
||||
p: _flatten_action(m.action_space.sample())
|
||||
for p, m in policy_map.items()
|
||||
}
|
||||
else:
|
||||
env = gym.make(env_name)
|
||||
multiagent = False
|
||||
try:
|
||||
policy_map = {DEFAULT_POLICY_ID: agent.policy}
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"Agent ({}) does not have a `policy` property! This is needed "
|
||||
"for performing (trained) agent rollouts.".format(agent))
|
||||
use_lstm = {DEFAULT_POLICY_ID: False}
|
||||
|
||||
action_init = {
|
||||
p: _flatten_action(m.action_space.sample())
|
||||
for p, m in policy_map.items()
|
||||
}
|
||||
|
||||
# If monitoring has been requested, manually wrap our environment with a
|
||||
# gym monitor, which is set to record every episode.
|
||||
if video_dir:
|
||||
|
|
Loading…
Add table
Reference in a new issue