[RLlib] Small bug fix (#27003)

This commit is contained in:
Jun Gong 2022-07-27 00:02:18 -07:00 committed by GitHub
parent 54df8bfe42
commit a22457b548
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 46 additions and 23 deletions

View file

@ -1959,7 +1959,11 @@ def _determine_spaces_for_multi_agent_dict(
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if isinstance(env, MultiAgentEnv) and env._spaces_in_preferred_format:
if (
isinstance(env, MultiAgentEnv)
and hasattr(env, "_spaces_in_preferred_format")
and env._spaces_in_preferred_format
):
obs_space = None
mapping_fn = policy_config.get("multiagent", {}).get(
"policy_mapping_fn", None
@ -2004,7 +2008,11 @@ def _determine_spaces_for_multi_agent_dict(
# Multi-agent case AND different agents have different spaces:
# Need to reverse map spaces (for the different agents) to certain
# policy IDs.
if isinstance(env, MultiAgentEnv) and env._spaces_in_preferred_format:
if (
isinstance(env, MultiAgentEnv)
and hasattr(env, "_spaces_in_preferred_format")
and env._spaces_in_preferred_format
):
act_space = None
mapping_fn = policy_config.get("multiagent", {}).get(
"policy_mapping_fn", None

View file

@ -22,7 +22,7 @@ from ray.rllib.examples.env.mock_env import (
)
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
@ -814,6 +814,41 @@ class TestRolloutWorker(unittest.TestCase):
self.assertEqual(seeds, [1, 2, 3])
ev.stop()
def test_determine_spaces_for_multi_agent_dict(self):
class MockMultiAgentEnv(MultiAgentEnv):
"""A mock testing MultiAgentEnv that doesn't call super.__init__()."""
def __init__(self):
# Intentinoally don't call super().__init__(),
# so this env doesn't have _spaces_in_preferred_format
# attribute.
self.observation_space = gym.spaces.Discrete(2)
self.action_space = gym.spaces.Discrete(2)
def reset(self):
pass
def step(self, action_dict):
obs = {1: [0, 0], 2: [1, 1]}
rewards = {1: 0, 2: 0}
dones = {1: False, 2: False, "__all__": False}
infos = {1: {}, 2: {}}
return obs, rewards, dones, infos
ev = RolloutWorker(
env_creator=lambda _: MockMultiAgentEnv(),
num_envs=3,
policy_spec={
"policy_1": PolicySpec(policy_class=MockPolicy),
"policy_2": PolicySpec(policy_class=MockPolicy),
},
seed=1,
)
# The fact that this RolloutWorker can be created without throwing
# exceptions means _determine_spaces_for_multi_agent_dict() is
# handling multiagent user environments properly.
self.assertIsNotNone(ev)
def test_wrap_multi_agent_env(self):
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(10),

View file

@ -24,11 +24,6 @@ cartpole-crashing-pg:
# Switch on resiliency for failed sub environments (within a vectorized stack).
restart_failed_sub_environments: true
# Disable env checking. Otherwise, RolloutWorkers will crash during
# initialization, which is not covered by the
# `restart_failed_sub_environments=True` failure tolerance mode.
disable_env_checking: true
evaluation_num_workers: 2
evaluation_interval: 1
evaluation_duration: 20

View file

@ -22,10 +22,5 @@ cartpole-crashing-with-remote-envs-pg:
# Use parallel remote envs.
remote_worker_envs: true
# Disable env checking. Otherwise, RolloutWorkers will crash during
# initialization, which is not covered by the
# `restart_failed_sub_environments=True` failure tolerance mode.
disable_env_checking: true
# Switch on resiliency for failed sub environments (within a vectorized stack).
restart_failed_sub_environments: true

View file

@ -25,11 +25,6 @@ multi-agent-cartpole-crashing-pg:
# Switch on resiliency for failed sub environments (within a vectorized stack).
restart_failed_sub_environments: true
# Disable env checking. Otherwise, RolloutWorkers will crash during
# initialization, which is not covered by the
# `restart_failed_sub_environments=True` failure tolerance mode.
disable_env_checking: true
evaluation_num_workers: 2
evaluation_interval: 1
evaluation_duration: 20

View file

@ -28,11 +28,6 @@ multi-agent-cartpole-crashing-pg:
# Switch on resiliency for failed sub environments (within a vectorized stack).
restart_failed_sub_environments: true
# Disable env checking. Otherwise, RolloutWorkers will crash during
# initialization, which is not covered by the
# `restart_failed_sub_environments=True` failure tolerance mode.
disable_env_checking: true
evaluation_num_workers: 1
evaluation_interval: 1
evaluation_duration: 10