[RLlib] Issue 17706: AttributeError: 'numpy.ndarray' object has no attribute 'items'" on certain turn-based MultiAgentEnvs with Dict obs space. (#17735)

This commit is contained in:
Sven Mika 2021-08-11 12:33:35 +02:00 committed by GitHub
parent 4176e43ef2
commit 29f20cccb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,6 +5,7 @@ import numpy as np
import queue
import threading
import time
import tree # pip install dm_tree
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
Type, TYPE_CHECKING, Union
@ -778,7 +779,8 @@ def _process_observations(
obs_sp = worker.policy_map[episode.policy_for(
ag_id)].observation_space
obs_sp = getattr(obs_sp, "original_space", obs_sp)
all_agents_obs[ag_id] = np.zeros_like(obs_sp.sample())
all_agents_obs[ag_id] = tree.map_structure(
np.zeros_like, obs_sp.sample())
else:
hit_horizon = False
all_agents_done = False