mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Issue 18418: SAC w/ dict space not working. (#19101)
This commit is contained in:
parent
f8a91c7fad
commit
1f0646f658
4 changed files with 34 additions and 19 deletions
|
@ -1,6 +1,7 @@
|
|||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import numpy as np
|
||||
import tree # pip install dm_tree
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
@ -267,13 +268,18 @@ class SACTFModel(TFModelV2):
|
|||
Returns:
|
||||
TensorType: Distribution inputs for sampling actions.
|
||||
"""
|
||||
# Model outs may come as original Tuple observations, concat them
|
||||
# Model outs may come as original Tuple/Dict observations, concat them
|
||||
# here if this is the case.
|
||||
if isinstance(self.action_model.obs_space, Box):
|
||||
if isinstance(model_out, (list, tuple)):
|
||||
model_out = tf.concat(model_out, axis=-1)
|
||||
elif isinstance(model_out, dict):
|
||||
model_out = tf.concat(list(model_out.values()), axis=-1)
|
||||
model_out = tf.concat(
|
||||
[
|
||||
tf.expand_dims(val, 1) if len(val.shape) == 1 else val
|
||||
for val in tree.flatten(model_out.values())
|
||||
],
|
||||
axis=-1)
|
||||
out, _ = self.action_model({"obs": model_out}, [], None)
|
||||
return out
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import gym
|
|||
from gym.spaces import Box, Discrete
|
||||
from functools import partial
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import ray
|
||||
|
@ -53,9 +52,6 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
|
|||
target model will be created in this function and assigned to
|
||||
`policy.target_model`.
|
||||
"""
|
||||
# With separate state-preprocessor (before obs+action concat).
|
||||
num_outputs = int(np.product(obs_space.shape))
|
||||
|
||||
# Force-ignore any additionally provided hidden layer sizes.
|
||||
# Everything should be configured using SAC's "Q_model" and "policy_model"
|
||||
# settings.
|
||||
|
@ -70,7 +66,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
|
|||
model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
num_outputs=None,
|
||||
model_config=config["model"],
|
||||
framework=config["framework"],
|
||||
default_model=default_model_cls,
|
||||
|
@ -90,7 +86,7 @@ def build_sac_model(policy: Policy, obs_space: gym.spaces.Space,
|
|||
policy.target_model = ModelCatalog.get_model_v2(
|
||||
obs_space=obs_space,
|
||||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
num_outputs=None,
|
||||
model_config=config["model"],
|
||||
framework=config["framework"],
|
||||
default_model=default_model_cls,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import gym
|
||||
from gym.spaces import Box, Discrete
|
||||
import numpy as np
|
||||
import tree # pip install dm_tree
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
@ -281,7 +282,12 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
|||
if isinstance(model_out, (list, tuple)):
|
||||
model_out = torch.cat(model_out, dim=-1)
|
||||
elif isinstance(model_out, dict):
|
||||
model_out = torch.cat(list(model_out.values()), dim=-1)
|
||||
model_out = torch.cat(
|
||||
[
|
||||
torch.unsqueeze(val, 1) if len(val.shape) == 1 else val
|
||||
for val in tree.flatten(model_out.values())
|
||||
],
|
||||
dim=-1)
|
||||
out, _ = self.action_model({"obs": model_out}, [], None)
|
||||
return out
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from gym import Env
|
||||
from gym.spaces import Box, Discrete, Tuple
|
||||
from gym.spaces import Box, Dict, Discrete, Tuple
|
||||
import numpy as np
|
||||
import re
|
||||
import unittest
|
||||
|
@ -23,6 +23,7 @@ from ray.rllib.utils.spaces.simplex import Simplex
|
|||
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
|
||||
check_train_results, framework_iterator
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
from ray import tune
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
@ -90,22 +91,28 @@ class TestSAC(unittest.TestCase):
|
|||
image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
|
||||
simple_space = Box(-1.0, 1.0, shape=(3, ))
|
||||
|
||||
tune.register_env(
|
||||
"random_dict_env", lambda _: RandomEnv({
|
||||
"observation_space": Dict({
|
||||
"a": simple_space,
|
||||
"b": Discrete(2),
|
||||
"c": image_space, }),
|
||||
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
|
||||
tune.register_env(
|
||||
"random_tuple_env", lambda _: RandomEnv({
|
||||
"observation_space": Tuple([
|
||||
simple_space, Discrete(2), image_space]),
|
||||
"action_space": Box(-1.0, 1.0, shape=(1, )), }))
|
||||
|
||||
for fw in framework_iterator(config):
|
||||
# Test for different env types (discrete w/ and w/o image, + cont).
|
||||
for env in [
|
||||
RandomEnv,
|
||||
"random_dict_env",
|
||||
"random_tuple_env",
|
||||
"MsPacmanNoFrameskip-v4",
|
||||
"CartPole-v0",
|
||||
]:
|
||||
print("Env={}".format(env))
|
||||
if env == RandomEnv:
|
||||
config["env_config"] = {
|
||||
"observation_space": Tuple((simple_space, Discrete(2),
|
||||
image_space)),
|
||||
"action_space": Box(-1.0, 1.0, shape=(1, )),
|
||||
}
|
||||
else:
|
||||
config["env_config"] = {}
|
||||
# Test making the Q-model a custom one for CartPole, otherwise,
|
||||
# use the default model.
|
||||
config["Q_model"]["custom_model"] = "batch_norm{}".format(
|
||||
|
|
Loading…
Add table
Reference in a new issue