[RLlib] Fix adding policies to RolloutWorkers with complex and discrete observation spaces. (#28133)

This commit is contained in:
Artur Niederfahrenhorst 2022-08-29 17:44:48 +02:00 committed by GitHub
parent 51d16b8ff9
commit 250a73a756
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -1241,10 +1241,17 @@ class RolloutWorker(ParallelIteratorWorker):
if policy_state:
new_policy.set_state(policy_state)
self.filters[policy_id] = get_filter(
self.observation_filter, new_policy.observation_space.shape
filter_shape = tree.map_structure(
lambda s: (
None
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
else np.array(s.shape)
),
new_policy.observation_space_struct,
)
self.filters[policy_id] = get_filter(self.observation_filter, filter_shape)
self.set_policy_mapping_fn(policy_mapping_fn)
if policies_to_train is not None:
self.set_is_policy_to_train(policies_to_train)