mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Small bug fix (#27003)
This commit is contained in:
parent
54df8bfe42
commit
a22457b548
6 changed files with 46 additions and 23 deletions
|
@ -1959,7 +1959,11 @@ def _determine_spaces_for_multi_agent_dict(
|
|||
# Multi-agent case AND different agents have different spaces:
|
||||
# Need to reverse map spaces (for the different agents) to certain
|
||||
# policy IDs.
|
||||
if isinstance(env, MultiAgentEnv) and env._spaces_in_preferred_format:
|
||||
if (
|
||||
isinstance(env, MultiAgentEnv)
|
||||
and hasattr(env, "_spaces_in_preferred_format")
|
||||
and env._spaces_in_preferred_format
|
||||
):
|
||||
obs_space = None
|
||||
mapping_fn = policy_config.get("multiagent", {}).get(
|
||||
"policy_mapping_fn", None
|
||||
|
@ -2004,7 +2008,11 @@ def _determine_spaces_for_multi_agent_dict(
|
|||
# Multi-agent case AND different agents have different spaces:
|
||||
# Need to reverse map spaces (for the different agents) to certain
|
||||
# policy IDs.
|
||||
if isinstance(env, MultiAgentEnv) and env._spaces_in_preferred_format:
|
||||
if (
|
||||
isinstance(env, MultiAgentEnv)
|
||||
and hasattr(env, "_spaces_in_preferred_format")
|
||||
and env._spaces_in_preferred_format
|
||||
):
|
||||
act_space = None
|
||||
mapping_fn = policy_config.get("multiagent", {}).get(
|
||||
"policy_mapping_fn", None
|
||||
|
|
|
@ -22,7 +22,7 @@ from ray.rllib.examples.env.mock_env import (
|
|||
)
|
||||
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
|
||||
from ray.rllib.examples.policy.random_policy import RandomPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.policy import Policy, PolicySpec
|
||||
from ray.rllib.policy.sample_batch import (
|
||||
DEFAULT_POLICY_ID,
|
||||
MultiAgentBatch,
|
||||
|
@ -814,6 +814,41 @@ class TestRolloutWorker(unittest.TestCase):
|
|||
self.assertEqual(seeds, [1, 2, 3])
|
||||
ev.stop()
|
||||
|
||||
def test_determine_spaces_for_multi_agent_dict(self):
|
||||
class MockMultiAgentEnv(MultiAgentEnv):
|
||||
"""A mock testing MultiAgentEnv that doesn't call super().__init__()."""
|
||||
|
||||
def __init__(self):
|
||||
# Intentionally don't call super().__init__(),
|
||||
# so this env doesn't have _spaces_in_preferred_format
|
||||
# attribute.
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def step(self, action_dict):
|
||||
obs = {1: [0, 0], 2: [1, 1]}
|
||||
rewards = {1: 0, 2: 0}
|
||||
dones = {1: False, 2: False, "__all__": False}
|
||||
infos = {1: {}, 2: {}}
|
||||
return obs, rewards, dones, infos
|
||||
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: MockMultiAgentEnv(),
|
||||
num_envs=3,
|
||||
policy_spec={
|
||||
"policy_1": PolicySpec(policy_class=MockPolicy),
|
||||
"policy_2": PolicySpec(policy_class=MockPolicy),
|
||||
},
|
||||
seed=1,
|
||||
)
|
||||
# The fact that this RolloutWorker can be created without throwing
|
||||
# exceptions means _determine_spaces_for_multi_agent_dict() is
|
||||
# handling multiagent user environments properly.
|
||||
self.assertIsNotNone(ev)
|
||||
|
||||
def test_wrap_multi_agent_env(self):
|
||||
ev = RolloutWorker(
|
||||
env_creator=lambda _: BasicMultiAgent(10),
|
||||
|
|
|
@ -24,11 +24,6 @@ cartpole-crashing-pg:
|
|||
# Switch on resiliency for failed sub environments (within a vectorized stack).
|
||||
restart_failed_sub_environments: true
|
||||
|
||||
# Disable env checking. Otherwise, RolloutWorkers will crash during
|
||||
# initialization, which is not covered by the
|
||||
# `restart_failed_sub_environments=True` failure tolerance mode.
|
||||
disable_env_checking: true
|
||||
|
||||
evaluation_num_workers: 2
|
||||
evaluation_interval: 1
|
||||
evaluation_duration: 20
|
||||
|
|
|
@ -22,10 +22,5 @@ cartpole-crashing-with-remote-envs-pg:
|
|||
# Use parallel remote envs.
|
||||
remote_worker_envs: true
|
||||
|
||||
# Disable env checking. Otherwise, RolloutWorkers will crash during
|
||||
# initialization, which is not covered by the
|
||||
# `restart_failed_sub_environments=True` failure tolerance mode.
|
||||
disable_env_checking: true
|
||||
|
||||
# Switch on resiliency for failed sub environments (within a vectorized stack).
|
||||
restart_failed_sub_environments: true
|
||||
|
|
|
@ -25,11 +25,6 @@ multi-agent-cartpole-crashing-pg:
|
|||
# Switch on resiliency for failed sub environments (within a vectorized stack).
|
||||
restart_failed_sub_environments: true
|
||||
|
||||
# Disable env checking. Otherwise, RolloutWorkers will crash during
|
||||
# initialization, which is not covered by the
|
||||
# `restart_failed_sub_environments=True` failure tolerance mode.
|
||||
disable_env_checking: true
|
||||
|
||||
evaluation_num_workers: 2
|
||||
evaluation_interval: 1
|
||||
evaluation_duration: 20
|
||||
|
|
|
@ -28,11 +28,6 @@ multi-agent-cartpole-crashing-pg:
|
|||
# Switch on resiliency for failed sub environments (within a vectorized stack).
|
||||
restart_failed_sub_environments: true
|
||||
|
||||
# Disable env checking. Otherwise, RolloutWorkers will crash during
|
||||
# initialization, which is not covered by the
|
||||
# `restart_failed_sub_environments=True` failure tolerance mode.
|
||||
disable_env_checking: true
|
||||
|
||||
evaluation_num_workers: 1
|
||||
evaluation_interval: 1
|
||||
evaluation_duration: 10
|
||||
|
|
Loading…
Add table
Reference in a new issue