diff --git a/rllib/agents/sac/sac_tf_model.py b/rllib/agents/sac/sac_tf_model.py index b457f1e94..546de04ab 100644 --- a/rllib/agents/sac/sac_tf_model.py +++ b/rllib/agents/sac/sac_tf_model.py @@ -170,7 +170,7 @@ class SACTFModel(TFModelV2): self.concat_obs_and_actions = True else: if isinstance(orig_space, gym.spaces.Tuple): - spaces = orig_space.spaces + spaces = list(orig_space.spaces) elif isinstance(orig_space, gym.spaces.Dict): spaces = list(orig_space.spaces.values()) else: diff --git a/rllib/agents/sac/sac_torch_model.py b/rllib/agents/sac/sac_torch_model.py index 1288d20da..b1ba49199 100644 --- a/rllib/agents/sac/sac_torch_model.py +++ b/rllib/agents/sac/sac_torch_model.py @@ -176,7 +176,7 @@ class SACTorchModel(TorchModelV2, nn.Module): self.concat_obs_and_actions = True else: if isinstance(orig_space, gym.spaces.Tuple): - spaces = orig_space.spaces + spaces = list(orig_space.spaces) elif isinstance(orig_space, gym.spaces.Dict): spaces = list(orig_space.spaces.values()) else: diff --git a/rllib/agents/sac/tests/test_sac.py b/rllib/agents/sac/tests/test_sac.py index 33e816ee9..77923db83 100644 --- a/rllib/agents/sac/tests/test_sac.py +++ b/rllib/agents/sac/tests/test_sac.py @@ -100,9 +100,8 @@ class TestSAC(unittest.TestCase): print("Env={}".format(env)) if env == RandomEnv: config["env_config"] = { - "observation_space": Tuple( - [simple_space, - Discrete(2), image_space]), + "observation_space": Tuple((simple_space, Discrete(2), + image_space)), "action_space": Box(-1.0, 1.0, shape=(1, )), } else: