mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Add QMIX support for complex obs spaces (Issue 8523). (#8533)
This commit is contained in:
parent
9823e15311
commit
8870270164
2 changed files with 33 additions and 16 deletions
|
@@ -13,11 +13,13 @@ from ray.rllib.policy.sample_batch import SampleBatch
|
|||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.modelv2 import _unpack_obs
|
||||
from ray.rllib.env.constants import GROUP_REWARDS
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
# Torch must be installed.
|
||||
torch, nn = try_import_torch(error=True)
|
||||
tree = try_import_tree()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@@ -463,25 +465,28 @@ class QMixTorchPolicy(Policy):
|
|||
state (np.ndarray or None): state tensor of shape [B, state_size]
|
||||
or None if it is not in the batch
|
||||
"""
|
||||
|
||||
unpacked = _unpack_obs(
|
||||
np.array(obs_batch, dtype=np.float32),
|
||||
self.observation_space.original_space,
|
||||
tensorlib=np)
|
||||
|
||||
if isinstance(unpacked[0], dict):
|
||||
unpacked_obs = [
|
||||
np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
|
||||
]
|
||||
else:
|
||||
unpacked_obs = unpacked
|
||||
|
||||
obs = np.concatenate(
|
||||
unpacked_obs,
|
||||
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
|
||||
|
||||
if self.has_action_mask:
|
||||
obs = np.concatenate(
|
||||
[o["obs"] for o in unpacked],
|
||||
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
|
||||
action_mask = np.concatenate(
|
||||
[o["action_mask"] for o in unpacked], axis=1).reshape(
|
||||
[len(obs_batch), self.n_agents, self.n_actions])
|
||||
else:
|
||||
if isinstance(unpacked[0], dict):
|
||||
unpacked_obs = [u["obs"] for u in unpacked]
|
||||
else:
|
||||
unpacked_obs = unpacked
|
||||
obs = np.concatenate(
|
||||
unpacked_obs,
|
||||
axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
|
||||
action_mask = np.ones(
|
||||
[len(obs_batch), self.n_agents, self.n_actions],
|
||||
dtype=np.float32)
|
||||
|
|
|
@@ -1,4 +1,4 @@
|
|||
from gym.spaces import Tuple, Discrete, Dict, Box
|
||||
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
@@ -9,10 +9,17 @@ from ray.rllib.agents.qmix import QMixTrainer
|
|||
|
||||
|
||||
class AvailActionsTestEnv(MultiAgentEnv):
|
||||
action_space = Discrete(10)
|
||||
num_actions = 10
|
||||
action_space = Discrete(num_actions)
|
||||
observation_space = Dict({
|
||||
"obs": Discrete(3),
|
||||
"action_mask": Box(0, 1, (10, )),
|
||||
"obs": Dict({
|
||||
"test": Dict({
|
||||
"a": Discrete(2),
|
||||
"b": MultiDiscrete([2, 3, 4])
|
||||
}),
|
||||
"state": MultiDiscrete([2, 2, 2])
|
||||
}),
|
||||
"action_mask": Box(0, 1, (num_actions, )),
|
||||
})
|
||||
|
||||
def __init__(self, env_config):
|
||||
|
@@ -25,7 +32,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
|
|||
self.state = 0
|
||||
return {
|
||||
"agent_1": {
|
||||
"obs": self.state,
|
||||
"obs": self.observation_space["obs"].sample(),
|
||||
"action_mask": self.action_mask
|
||||
}
|
||||
}
|
||||
|
@@ -36,7 +43,12 @@ class AvailActionsTestEnv(MultiAgentEnv):
|
|||
"Failed to obey available actions mask!"
|
||||
self.state += 1
|
||||
rewards = {"agent_1": 1}
|
||||
obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}}
|
||||
obs = {
|
||||
"agent_1": {
|
||||
"obs": self.observation_space["obs"].sample(),
|
||||
"action_mask": self.action_mask
|
||||
}
|
||||
}
|
||||
dones = {"__all__": self.state > 20}
|
||||
return obs, rewards, dones, {}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue