mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] SAC tuple observation space fix (#17356)
This commit is contained in:
parent
2618236167
commit
d7a5ec1830
3 changed files with 4 additions and 5 deletions
|
@ -170,7 +170,7 @@ class SACTFModel(TFModelV2):
|
|||
self.concat_obs_and_actions = True
|
||||
else:
|
||||
if isinstance(orig_space, gym.spaces.Tuple):
|
||||
spaces = orig_space.spaces
|
||||
spaces = list(orig_space.spaces)
|
||||
elif isinstance(orig_space, gym.spaces.Dict):
|
||||
spaces = list(orig_space.spaces.values())
|
||||
else:
|
||||
|
|
|
@ -176,7 +176,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
|
|||
self.concat_obs_and_actions = True
|
||||
else:
|
||||
if isinstance(orig_space, gym.spaces.Tuple):
|
||||
spaces = orig_space.spaces
|
||||
spaces = list(orig_space.spaces)
|
||||
elif isinstance(orig_space, gym.spaces.Dict):
|
||||
spaces = list(orig_space.spaces.values())
|
||||
else:
|
||||
|
|
|
@ -100,9 +100,8 @@ class TestSAC(unittest.TestCase):
|
|||
print("Env={}".format(env))
|
||||
if env == RandomEnv:
|
||||
config["env_config"] = {
|
||||
"observation_space": Tuple(
|
||||
[simple_space,
|
||||
Discrete(2), image_space]),
|
||||
"observation_space": Tuple((simple_space, Discrete(2),
|
||||
image_space)),
|
||||
"action_space": Box(-1.0, 1.0, shape=(1, )),
|
||||
}
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue