mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Issue 17706: AttributeError: 'numpy.ndarray' object has no attribute 'items'" on certain turn-based MultiAgentEnvs with Dict obs space. (#17735)
This commit is contained in:
parent
4176e43ef2
commit
29f20cccb6
1 changed files with 3 additions and 1 deletions
|
@ -5,6 +5,7 @@ import numpy as np
|
|||
import queue
|
||||
import threading
|
||||
import time
|
||||
import tree # pip install dm_tree
|
||||
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
|
||||
Type, TYPE_CHECKING, Union
|
||||
|
||||
|
@ -778,7 +779,8 @@ def _process_observations(
|
|||
obs_sp = worker.policy_map[episode.policy_for(
|
||||
ag_id)].observation_space
|
||||
obs_sp = getattr(obs_sp, "original_space", obs_sp)
|
||||
all_agents_obs[ag_id] = np.zeros_like(obs_sp.sample())
|
||||
all_agents_obs[ag_id] = tree.map_structure(
|
||||
np.zeros_like, obs_sp.sample())
|
||||
else:
|
||||
hit_horizon = False
|
||||
all_agents_done = False
|
||||
|
|
Loading…
Add table
Reference in a new issue