[RLlib] Add QMIX support for complex obs spaces (Issue 8523). (#8533)

Sven Mika 2020-05-22 10:17:51 +02:00 committed by GitHub
parent 9823e15311
commit 8870270164
2 changed files with 33 additions and 16 deletions

rllib/agents/qmix/qmix_policy.py

@@ -13,11 +13,13 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.modelv2 import _unpack_obs
 from ray.rllib.env.constants import GROUP_REWARDS
+from ray.rllib.utils import try_import_tree
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.annotations import override
 
 # Torch must be installed.
 torch, nn = try_import_torch(error=True)
+tree = try_import_tree()
 
 logger = logging.getLogger(__name__)
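
(Note: `tree` here is DeepMind's dm-tree utility, which RLlib imports behind `try_import_tree()`. Its `flatten()` turns an arbitrarily nested structure into a flat list of leaf values, visiting dict entries in sorted-key order; the policy code below relies on exactly that. A minimal sketch of the behavior, with toy values not taken from this commit:)

    import tree  # dm-tree ("pip install dm-tree")

    # Leaves come back as a flat list; dict keys are visited in sorted order.
    nested = {"b": (1, 2), "a": {"x": 3}}
    print(tree.flatten(nested))  # -> [3, 1, 2]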
@@ -463,25 +465,28 @@ class QMixTorchPolicy(Policy):
             state (np.ndarray or None): state tensor of shape [B, state_size]
                 or None if it is not in the batch
         """
         unpacked = _unpack_obs(
             np.array(obs_batch, dtype=np.float32),
             self.observation_space.original_space,
             tensorlib=np)
+
+        if isinstance(unpacked[0], dict):
+            unpacked_obs = [
+                np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
+            ]
+        else:
+            unpacked_obs = unpacked
+
+        obs = np.concatenate(
+            unpacked_obs,
+            axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
+
         if self.has_action_mask:
-            obs = np.concatenate(
-                [o["obs"] for o in unpacked],
-                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
             action_mask = np.concatenate(
                 [o["action_mask"] for o in unpacked], axis=1).reshape(
                     [len(obs_batch), self.n_agents, self.n_actions])
         else:
-            if isinstance(unpacked[0], dict):
-                unpacked_obs = [u["obs"] for u in unpacked]
-            else:
-                unpacked_obs = unpacked
-            obs = np.concatenate(
-                unpacked_obs,
-                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
             action_mask = np.ones(
                 [len(obs_batch), self.n_agents, self.n_actions],
                 dtype=np.float32)
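
(What the new code path does, as a self-contained sketch with invented shapes: each agent's dict observation is flattened to its leaf arrays via `tree.flatten()` and joined along the feature axis, then the per-agent results are concatenated and reshaped to `[B, n_agents, obs_size]`.)

    import numpy as np
    import tree  # dm-tree

    # Hypothetical sizes: batch B=4, 2 agents, dict obs flattening to 5 features.
    B, n_agents, obs_size = 4, 2, 5
    unpacked = [
        {"obs": {"a": np.ones((B, 2)), "b": np.zeros((B, 3))}}  # one per agent
        for _ in range(n_agents)
    ]

    # Per agent: flatten the dict obs and join its leaves along the feature axis.
    unpacked_obs = [np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked]

    # Join the agents, then reshape to [B, n_agents, obs_size].
    obs = np.concatenate(unpacked_obs, axis=1).reshape([B, n_agents, obs_size])
    print(obs.shape)  # (4, 2, 5)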

rllib/agents/qmix/tests/test_avail_actions_qmix.py

@@ -1,4 +1,4 @@
-from gym.spaces import Tuple, Discrete, Dict, Box
+from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
 import numpy as np
 import unittest
@@ -9,10 +9,17 @@ from ray.rllib.agents.qmix import QMixTrainer
 class AvailActionsTestEnv(MultiAgentEnv):
-    action_space = Discrete(10)
+    num_actions = 10
+    action_space = Discrete(num_actions)
     observation_space = Dict({
-        "obs": Discrete(3),
-        "action_mask": Box(0, 1, (10, )),
+        "obs": Dict({
+            "test": Dict({
+                "a": Discrete(2),
+                "b": MultiDiscrete([2, 3, 4])
+            }),
+            "state": MultiDiscrete([2, 2, 2])
+        }),
+        "action_mask": Box(0, 1, (num_actions, )),
     })
 
     def __init__(self, env_config):
@@ -25,7 +32,7 @@ class AvailActionsTestEnv(MultiAgentEnv):
         self.state = 0
         return {
             "agent_1": {
-                "obs": self.state,
+                "obs": self.observation_space["obs"].sample(),
                 "action_mask": self.action_mask
             }
         }
@@ -36,7 +43,12 @@ class AvailActionsTestEnv(MultiAgentEnv):
             "Failed to obey available actions mask!"
         self.state += 1
         rewards = {"agent_1": 1}
-        obs = {"agent_1": {"obs": 0, "action_mask": self.action_mask}}
+        obs = {
+            "agent_1": {
+                "obs": self.observation_space["obs"].sample(),
+                "action_mask": self.action_mask
+            }
+        }
         dones = {"__all__": self.state > 20}
         return obs, rewards, dones, {}
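
(For the nested "obs" space above, RLlib's default preprocessing one-hot encodes Discrete and MultiDiscrete components, so the flattened per-agent observation width works out to 2 + (2 + 3 + 4) + (2 + 2 + 2) = 17. A small check of that arithmetic, using a hypothetical helper rather than anything from this commit:)

    from gym.spaces import Dict, Discrete, MultiDiscrete

    def one_hot_width(space):
        # Width of the one-hot encoding RLlib's preprocessor would produce
        # (assumption: Discrete -> n, MultiDiscrete -> sum(nvec), Dict -> sum).
        if isinstance(space, Discrete):
            return space.n
        if isinstance(space, MultiDiscrete):
            return int(sum(space.nvec))
        if isinstance(space, Dict):
            return sum(one_hot_width(s) for s in space.spaces.values())
        raise NotImplementedError(type(space))

    obs_space = Dict({
        "test": Dict({"a": Discrete(2), "b": MultiDiscrete([2, 3, 4])}),
        "state": MultiDiscrete([2, 2, 2]),
    })
    print(one_hot_width(obs_space))  # 2 + 9 + 6 = 17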