[RLlib] SAC tuple observation space fix (#17356)

This commit is contained in:
Julius Frost 2021-07-28 12:39:28 -04:00 committed by GitHub
parent 2618236167
commit d7a5ec1830
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 5 deletions

View file

@ -170,7 +170,7 @@ class SACTFModel(TFModelV2):
self.concat_obs_and_actions = True
else:
if isinstance(orig_space, gym.spaces.Tuple):
spaces = orig_space.spaces
spaces = list(orig_space.spaces)
elif isinstance(orig_space, gym.spaces.Dict):
spaces = list(orig_space.spaces.values())
else:

View file

@ -176,7 +176,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
self.concat_obs_and_actions = True
else:
if isinstance(orig_space, gym.spaces.Tuple):
spaces = orig_space.spaces
spaces = list(orig_space.spaces)
elif isinstance(orig_space, gym.spaces.Dict):
spaces = list(orig_space.spaces.values())
else:

View file

@ -100,9 +100,8 @@ class TestSAC(unittest.TestCase):
print("Env={}".format(env))
if env == RandomEnv:
config["env_config"] = {
"observation_space": Tuple(
[simple_space,
Discrete(2), image_space]),
"observation_space": Tuple((simple_space, Discrete(2),
image_space)),
"action_space": Box(-1.0, 1.0, shape=(1, )),
}
else: