mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Fix adding policies to RolloutWorkers with complex and discrete observation spaces. (#28133)
This commit is contained in:
parent
51d16b8ff9
commit
250a73a756
1 changed file with 9 additions and 2 deletions
|
@@ -1241,10 +1241,17 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
if policy_state:
|
||||
new_policy.set_state(policy_state)
|
||||
|
||||
self.filters[policy_id] = get_filter(
|
||||
self.observation_filter, new_policy.observation_space.shape
|
||||
filter_shape = tree.map_structure(
|
||||
lambda s: (
|
||||
None
|
||||
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
|
||||
else np.array(s.shape)
|
||||
),
|
||||
new_policy.observation_space_struct,
|
||||
)
|
||||
|
||||
self.filters[policy_id] = get_filter(self.observation_filter, filter_shape)
|
||||
|
||||
self.set_policy_mapping_fn(policy_mapping_fn)
|
||||
if policies_to_train is not None:
|
||||
self.set_is_policy_to_train(policies_to_train)
|
||||
|
|
Loading…
Add table
Reference in a new issue