[RLlib] Issue 18418: SAC w/ dict space not working. (#19101)

This commit is contained in:
Sven Mika 2021-10-06 09:05:50 +02:00 committed by GitHub
parent f8a91c7fad
commit 1f0646f658
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 19 deletions

View file

@ -1,6 +1,7 @@
import gym
from gym.spaces import Box, Discrete
import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional
from ray.rllib.models.catalog import ModelCatalog
@ -267,13 +268,18 @@ class SACTFModel(TFModelV2):
Returns:
TensorType: Distribution inputs for sampling actions.
"""
# Model outs may come as original Tuple observations, concat them
# Model outs may come as original Tuple/Dict observations, concat them
# here if this is the case.
if isinstance(self.action_model.obs_space, Box):
if isinstance(model_out, (list, tuple)):
model_out = tf.concat(model_out, axis=-1)
elif isinstance(model_out, dict):
model_out = tf.concat(list(model_out.values()), axis=-1)
model_out = tf.concat(
[
tf.expand_dims(val, 1) if len(val.shape) == 1 else val
for val in tree.flatten(model_out.values())
],
axis=-1)
out, _ = self.action_model({"obs": model_out}, [], None)
return out

View file

@ -6,7 +6,6 @@ import gym
from gym.spaces import Box, Discrete
from functools import partial
import logging
import numpy as np
from typing import Dict, List, Optional, Tuple, Type, Union
import ray
@ -53,9 +52,6 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
target model will be created in this function and assigned to
`policy.target_model`.
"""
# With separate state-preprocessor (before obs+action concat).
num_outputs = int(np.product(obs_space.shape))
# Force-ignore any additionally provided hidden layer sizes.
# Everything should be configured using SAC's "Q_model" and "policy_model"
# settings.
@ -70,7 +66,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
num_outputs=None,
model_config=config["model"],
framework=config["framework"],
default_model=default_model_cls,
@ -90,7 +86,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
policy.target_model = ModelCatalog.get_model_v2(
obs_space=obs_space,
action_space=action_space,
num_outputs=num_outputs,
num_outputs=None,
model_config=config["model"],
framework=config["framework"],
default_model=default_model_cls,

View file

@ -1,6 +1,7 @@
import gym
from gym.spaces import Box, Discrete
import numpy as np
import tree # pip install dm_tree
from typing import Dict, List, Optional
from ray.rllib.models.catalog import ModelCatalog
@ -281,7 +282,12 @@ class SACTorchModel(TorchModelV2, nn.Module):
if isinstance(model_out, (list, tuple)):
model_out = torch.cat(model_out, dim=-1)
elif isinstance(model_out, dict):
model_out = torch.cat(list(model_out.values()), dim=-1)
model_out = torch.cat(
[
torch.unsqueeze(val, 1) if len(val.shape) == 1 else val
for val in tree.flatten(model_out.values())
],
dim=-1)
out, _ = self.action_model({"obs": model_out}, [], None)
return out

View file

@ -1,5 +1,5 @@
from gym import Env
from gym.spaces import Box, Discrete, Tuple
from gym.spaces import Box, Dict, Discrete, Tuple
import numpy as np
import re
import unittest
@ -23,6 +23,7 @@ from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
check_train_results, framework_iterator
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
from ray import tune
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -90,22 +91,28 @@ class TestSAC(unittest.TestCase):
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
simple_space = Box(-1.0, 1.0, shape=(3, ))
tune.register_env(
"random_dict_env", lambda _: RandomEnv({
"observation_space": Dict({
"a": simple_space,
"b": Discrete(2),
"c": image_space, }),
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
tune.register_env(
"random_tuple_env", lambda _: RandomEnv({
"observation_space": Tuple([
simple_space, Discrete(2), image_space]),
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
for fw in framework_iterator(config):
# Test for different env types (discrete w/ and w/o image, + cont).
for env in [
RandomEnv,
"random_dict_env",
"random_tuple_env",
"MsPacmanNoFrameskip-v4",
"CartPole-v0",
]:
print("Env={}".format(env))
if env == RandomEnv:
config["env_config"] = {
"observation_space": Tuple((simple_space, Discrete(2),
image_space)),
"action_space": Box(-1.0, 1.0, shape=(1, )),
}
else:
config["env_config"] = {}
# Test making the Q-model a custom one for CartPole, otherwise,
# use the default model.
config["Q_model"]["custom_model"] = "batch_norm{}".format(