[RLlib] Initial code/comment cleanups in preparation for decentralized multi-agent learner. (#21420)

Sven Mika 2022-01-10 11:22:55 +01:00 committed by GitHub
parent 4eaf70942d
commit 92f030331e
17 changed files with 128 additions and 55 deletions

View file

@@ -366,7 +366,7 @@ class ImpalaTrainer(Trainer):
                 {
                     # Evaluation (remote) workers.
                     # Note: The local eval worker is located on the driver
-                    # CPU.
+                    # CPU or not even created iff >0 eval workers.
                     "CPU": eval_config.get("num_cpus_per_worker",
                                            cf["num_cpus_per_worker"]),
                     "GPU": eval_config.get("num_gpus_per_worker",

View file

@@ -120,12 +120,6 @@ class APPOTrainer(impala.ImpalaTrainer):
             self.workers.local_worker().foreach_trainable_policy(
                 lambda p, _: p.update_target())

-    # TODO: Remove this once ImpalaTrainer directly inherits from Trainer
-    # (instead of being created by `build_trainer()` utility).
-    @override(impala.ImpalaTrainer)
-    def _init(self, *args, **kwargs):
-        raise NotImplementedError
-
     @classmethod
     @override(Trainer)
     def get_default_config(cls) -> TrainerConfigDict:

View file

@@ -29,6 +29,7 @@ from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
     MultiAgentReplayBuffer
+from ray.rllib.execution.common import WORKER_UPDATE_TIMER
 from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \
     synchronous_parallel_sample
 from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \
@@ -775,8 +776,8 @@ class Trainer(Trainable):
         # Set Trainer's seed after we have - if necessary - enabled
         # tf eager-execution.
-        update_global_seed_if_necessary(
-            config.get("framework"), config.get("seed"))
+        update_global_seed_if_necessary(self.config["framework"],
+                                        self.config["seed"])

         self.validate_config(self.config)
         if not callable(self.config["callbacks"]):
@@ -844,6 +845,14 @@ class Trainer(Trainable):
                 self.workers, self.config,
                 **self._kwargs_for_execution_plan())

+            # TODO: Now that workers have been created, update our policy
+            #  specs in the config[multiagent] dict with the correct spaces.
+            #  However, this leads to a problem with the evaluation
+            #  workers' observation one-hot preprocessor in
+            #  `examples/documentation/rllib_in_6sec.py` script.
+            # self.config["multiagent"]["policies"] = \
+            #     self.workers.local_worker().policy_map.policy_specs
+
         # Evaluation WorkerSet setup.
         # User would like to setup a separate evaluation worker set.
@@ -1295,6 +1304,12 @@ class Trainer(Trainable):
         else:
             train_results = multi_gpu_train_one_step(self, train_batch)

+        # Update weights - after learning on the local worker - on all remote
+        # workers.
+        if self.workers.remote_workers():
+            with self._timers[WORKER_UPDATE_TIMER]:
+                self.workers.sync_weights()
+
         return train_results

     @DeveloperAPI
@@ -1976,6 +1991,22 @@ class Trainer(Trainable):
             config2: PartialTrainerConfigDict,
             _allow_unknown_configs: Optional[bool] = None
     ) -> TrainerConfigDict:
+        """Merges a complete Trainer config with a partial override dict.
+
+        Respects nested structures within the config dicts. The values in the
+        partial override dict take priority.
+
+        Args:
+            config1: The complete Trainer's dict to be merged (overridden)
+                with `config2`.
+            config2: The partial override config dict to merge on top of
+                `config1`.
+            _allow_unknown_configs: If True, keys in `config2` that don't exist
+                in `config1` are allowed and will be added to the final config.
+
+        Returns:
+            The merged full trainer config dict.
+        """
         config1 = copy.deepcopy(config1)
         if "callbacks" in config2 and type(config2["callbacks"]) is dict:
             legacy_callbacks_dict = config2["callbacks"]
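
A minimal usage sketch of the merge semantics documented in the new docstring above (assuming the standard `Trainer`/`COMMON_CONFIG` imports; the override values are arbitrary examples, not taken from this commit):

    from ray.rllib.agents.trainer import COMMON_CONFIG, Trainer

    # Partial override dict: nested keys are deep-merged, and values in the
    # partial dict win over those in the full config.
    partial = {"lr": 0.0001, "model": {"fcnet_hiddens": [64, 64]}}
    merged = Trainer.merge_trainer_configs(COMMON_CONFIG, partial)
    assert merged["lr"] == 0.0001
    assert merged["model"]["fcnet_hiddens"] == [64, 64]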

View file

@@ -6,7 +6,7 @@ import numpy as np
 import platform
 import os
 import tree  # pip install dm_tree
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
     TYPE_CHECKING, Union

 import ray
@@ -46,7 +46,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
 from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
     ModelConfigDict, ModelGradients, ModelWeights, \
     MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
-    SampleBatchType
+    SampleBatchType, T
 from ray.util.debug import log_once, disable_log_once_globally, \
     enable_periodic_logging
 from ray.util.iter import ParallelIteratorWorker
@@ -56,9 +56,6 @@ if TYPE_CHECKING:
     from ray.rllib.evaluation.observation_function import ObservationFunction
     from ray.rllib.agents.callbacks import DefaultCallbacks  # noqa

-# Generic type var for foreach_* methods.
-T = TypeVar("T")
-
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -1436,19 +1433,26 @@ class RolloutWorker(ParallelIteratorWorker):
             sess.close()

     @DeveloperAPI
-    def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T],
-              *args) -> T:
+    def apply(
+            self,
+            func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T],
+            *args, **kwargs) -> T:
         """Calls the given function with this rollout worker instance.

+        Useful for when the RolloutWorker class has been converted into a
+        ActorHandle and the user needs to execute some functionality (e.g.
+        add a property) on the underlying policy object.
+
         Args:
-            func: The function to call with this RolloutWorker as first
-                argument.
+            func: The function to call, with this RolloutWorker as first
+                argument, followed by args, and kwargs.
             args: Optional additional args to pass to the function call.
+            kwargs: Optional additional kwargs to pass to the function call.

         Returns:
             The return value of the function call.
         """
-        return func(self, *args)
+        return func(self, *args, **kwargs)

     def setup_torch_data_parallel(self, url: str, world_rank: int,
                                   world_size: int, backend: str) -> None:
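
A hedged usage sketch for the reworked `RolloutWorker.apply()` above (the surrounding `trainer` object and the lambda are illustrative, not part of this commit):

    # Run an arbitrary function, with an extra positional arg, on every remote
    # RolloutWorker actor and collect the results on the driver.
    results = ray.get([
        w.apply.remote(
            lambda worker, tag: (tag, list(worker.policy_map.keys())),
            "ping")
        for w in trainer.workers.remote_workers()
    ])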

View file

@@ -159,13 +159,15 @@ class WorkerSet:
         if self.remote_workers() or from_worker is not None:
             weights = (from_worker
                        or self.local_worker()).get_weights(policies)
+            # Put weights only once into object store and use same object
+            # ref to synch to all workers.
             weights_ref = ray.put(weights)
             # Sync to all remote workers in this WorkerSet.
             for to_worker in self.remote_workers():
                 to_worker.set_weights.remote(weights_ref)

-        # If from_worker is provided, also sync to this WorkerSet's local
-        # worker.
+        # If `from_worker` is provided, also sync to this WorkerSet's
+        # local worker.
         if from_worker is not None and self.local_worker() is not None:
             self.local_worker().set_weights(weights)
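
The comment added above documents the broadcast pattern behind `WorkerSet.sync_weights()`: the weights dict is serialized into the object store exactly once, and the same `ObjectRef` is handed to every remote worker. A self-contained sketch of the same pattern (the `Echo` actor below is illustrative, not part of RLlib):

    import ray

    ray.init(ignore_reinit_error=True)

    @ray.remote
    class Echo:
        def set_weights(self, weights):
            # The ObjectRef is resolved to the stored dict before this runs.
            return len(weights)

    big = {"layer_{}".format(i): [0.0] * 1000 for i in range(10)}
    ref = ray.put(big)  # weights enter the object store exactly once
    workers = [Echo.remote() for _ in range(4)]
    print(ray.get([w.set_weights.remote(ref) for w in workers]))  # [10, 10, 10, 10]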

View file

@@ -1,11 +1,11 @@
 import collections
 import platform
-from typing import Dict, Any
+from typing import Any, Dict
 import numpy as np

 import ray
 from ray.rllib import SampleBatch
-from ray.rllib.execution import PrioritizedReplayBuffer
+from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
 from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
 from ray.rllib.policy.rnn_sequencing import \
     timeslice_along_seq_lens_with_overlap
@@ -54,7 +54,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
                 `self.replay_batch_size` will be set to the number of
                 sequences sampled (B).
             prioritized_replay_alpha (float): Alpha parameter for a prioritized
-                replay buffer.
+                replay buffer. Use 0.0 for no prioritization.
             prioritized_replay_beta (float): Beta parameter for a prioritized
                 replay buffer.
             prioritized_replay_eps (float): Epsilon parameter for a prioritized
@@ -108,8 +108,11 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
         ParallelIteratorWorker.__init__(self, gen_replay, False)

         def new_buffer():
-            return PrioritizedReplayBuffer(
-                self.capacity, alpha=prioritized_replay_alpha)
+            if prioritized_replay_alpha == 0.0:
+                return ReplayBuffer(self.capacity)
+            else:
+                return PrioritizedReplayBuffer(
+                    self.capacity, alpha=prioritized_replay_alpha)

         self.replay_buffers = collections.defaultdict(new_buffer)
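
A hedged sketch of what the new `prioritized_replay_alpha == 0.0` branch means for callers (keyword names taken from the docstring above; all other constructor defaults are assumed acceptable):

    from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
        MultiAgentReplayBuffer

    # alpha > 0.0 -> each policy gets a PrioritizedReplayBuffer (as before).
    prioritized = MultiAgentReplayBuffer(capacity=50000,
                                         prioritized_replay_alpha=0.6)
    # alpha == 0.0 -> each policy gets a plain (uniform) ReplayBuffer.
    uniform = MultiAgentReplayBuffer(capacity=50000,
                                     prioritized_replay_alpha=0.0)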

View file

@@ -52,7 +52,7 @@ class ReplayBuffer:
     def __init__(self,
                  capacity: int = 10000,
                  size: Optional[int] = DEPRECATED_VALUE):
-        """Initializes a Replaybuffer instance.
+        """Initializes a ReplayBuffer instance.

         Args:
             capacity: Max number of timesteps to store in the FIFO
@@ -84,6 +84,7 @@ class ReplayBuffer:
         self._est_size_bytes = 0

     def __len__(self) -> int:
+        """Returns the number of items currently stored in this buffer."""
         return len(self._storage)

     @DeveloperAPI
@@ -147,7 +148,7 @@ class ReplayBuffer:
         """Returns the stats of this buffer.

         Args:
-            debug: If true, adds sample eviction statistics to the returned
+            debug: If True, adds sample eviction statistics to the returned
                 stats dict.

         Returns:
@@ -253,7 +254,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     @DeveloperAPI
     @override(ReplayBuffer)
     def sample(self, num_items: int, beta: float) -> SampleBatchType:
-        """Sample a batch of experiences and return priority weights, indices.
+        """Sample `num_items` items from this buffer, including prio. weights.
+
+        If less than `num_items` records are in this buffer, some samples in
+        the results may be repeated to fulfil the batch size (`num_items`)
+        request.

         Args:
             num_items: Number of items to sample from this buffer.
@@ -272,11 +277,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         weights = []
         batch_indexes = []
         p_min = self._it_min.min() / self._it_sum.sum()
-        max_weight = (p_min * len(self._storage))**(-beta)
+        max_weight = (p_min * len(self))**(-beta)

         for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
-            weight = (p_sample * len(self._storage))**(-beta)
+            weight = (p_sample * len(self))**(-beta)
             count = self._storage[idx].count
             # If zero-padded, count will not be the actual batch size of the
             # data.
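
A standalone numeric sketch (not RLlib code) of the importance-sampling weight formula used in the hunk above, w_i = (N * P(i))^(-beta), normalized by the maximum weight so that all weights are <= 1.0:

    import numpy as np

    # Toy priorities for a buffer with N = 4 items (alpha already applied).
    priorities = np.array([0.1, 0.4, 0.2, 0.3])
    probs = priorities / priorities.sum()          # P(i)
    beta = 0.4
    N = len(probs)

    max_weight = (probs.min() * N) ** (-beta)      # same role as `max_weight`
    weights = (probs * N) ** (-beta) / max_weight  # normalized IS weights
    print(weights)  # rarely-sampled items get the largest correction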

View file

@@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict:
     trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
     trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

-    # Update weights - after learning on the local worker - on all remote
-    # workers.
-    if workers.remote_workers():
-        with trainer._timers[WORKER_UPDATE_TIMER]:
-            weights = ray.put(workers.local_worker().get_weights(policies))
-            for e in workers.remote_workers():
-                e.set_weights.remote(weights)
-
     return info

View file

@@ -4,8 +4,9 @@ import gym
 from gym.spaces import Box
 import logging
 import numpy as np
+import platform
 import tree  # pip install dm_tree
-from typing import Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING

 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.catalog import ModelCatalog
@@ -21,7 +22,7 @@ from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
     get_dummy_batch_for_space, unbatch
 from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
-    TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
+    T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -638,6 +639,27 @@ class Policy(metaclass=ABCMeta):
         self.set_weights(state["weights"])
         self.global_timestep = state["global_timestep"]

+    @ExperimentalAPI
+    def apply(self,
+              func: Callable[["Policy", Optional[Any], Optional[Any]], T],
+              *args, **kwargs) -> T:
+        """Calls the given function with this Policy instance.
+
+        Useful for when the Policy class has been converted into a ActorHandle
+        and the user needs to execute some functionality (e.g. add a property)
+        on the underlying policy object.
+
+        Args:
+            func: The function to call, with this Policy as first
+                argument, followed by args, and kwargs.
+            args: Optional additional args to pass to the function call.
+            kwargs: Optional additional kwargs to pass to the function call.
+
+        Returns:
+            The return value of the function call.
+        """
+        return func(self, *args, **kwargs)
+
     @DeveloperAPI
     def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
         """Called on an update to global vars.
@@ -697,6 +719,15 @@ class Policy(metaclass=ABCMeta):
         """
         return None

+    def get_host(self) -> str:
+        """Returns the computer's network name.
+
+        Returns:
+            The computer's networks name or an empty string, if the network
+            name could not be determined.
+        """
+        return platform.node()
+
     def _create_exploration(self) -> Exploration:
         """Creates the Policy's Exploration object.

View file

@@ -294,7 +294,7 @@ def chop_into_sequences(
                 f = np.array(f)
             length = len(seq_lens) * max_seq_len

-            if f.dtype == np.object or f.dtype.type is np.str_:
+            if f.dtype == object or f.dtype.type is np.str_:
                 f_pad = [None] * length
             else:
                 # Make sure type doesn't change.

View file

@@ -620,7 +620,7 @@ class SampleBatch(dict):
                     or path[0] == SampleBatch.SEQ_LENS:
                 return
             # Generate zero-filled primer of len=max_seq_len.
-            if value.dtype == np.object or value.dtype.type is np.str_:
+            if value.dtype == object or value.dtype.type is np.str_:
                 f_pad = [None] * length
             else:
                 # Make sure type doesn't change.
@@ -651,13 +651,13 @@ class SampleBatch(dict):
         return self

-    # Experimental method.
+    @ExperimentalAPI
     def to_device(self, device, framework="torch"):
         """TODO: transfer batch to given device as framework tensor."""
         if framework == "torch":
             assert torch is not None
             for k, v in self.items():
-                if isinstance(v, np.ndarray) and v.dtype != np.object:
+                if isinstance(v, np.ndarray) and v.dtype != object:
                     self[k] = torch.from_numpy(v).to(device)
         else:
             raise NotImplementedError

View file

@@ -1170,7 +1170,7 @@ class EntropyCoeffSchedule:

 @DeveloperAPI
 class DirectStepOptimizer:
-    """Typesafe method for indicating apply gradients can directly step the
+    """Typesafe method for indicating `apply_gradients` can directly step the
     optimizers with in-place gradients.
     """
     _instance = None

View file

@@ -42,7 +42,7 @@ def _summarize(obj):
         if obj.size == 0:
             return _StringValue("np.ndarray({}, dtype={})".format(
                 obj.shape, obj.dtype))
-        elif obj.dtype == np.object or obj.dtype.type is np.str_:
+        elif obj.dtype == object or obj.dtype.type is np.str_:
             return _StringValue("np.ndarray({}, dtype={}, head={})".format(
                 obj.shape, obj.dtype, _summarize(obj[0])))
         else:

View file

@@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
                 "ERROR: x ({}) is not the same as y ({})!".format(x, y)
     # String/byte comparisons.
     elif hasattr(x, "dtype") and \
-            (x.dtype == np.object or str(x.dtype).startswith("<U")):
+            (x.dtype == object or str(x.dtype).startswith("<U")):
         try:
             np.testing.assert_array_equal(x, y)
             if false is True:
@@ -307,11 +307,15 @@ def check_compute_single_action(trainer,
         ValueError: If anything unexpected happens.
     """
     # Have to import this here to avoid circular dependency.
-    from ray.rllib.policy.sample_batch import SampleBatch
+    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

     # Some Trainers may not abide to the standard API.
+    pid = DEFAULT_POLICY_ID
     try:
-        pol = trainer.get_policy()
+        # Multi-agent: Pick any policy (or DEFAULT_POLICY if it's the only
+        # one).
+        pid = next(iter(trainer.workers.local_worker().policy_map))
+        pol = trainer.get_policy(pid)
     except AttributeError:
         pol = trainer.policy
     # Get the policy's model.
@@ -324,6 +328,7 @@ def check_compute_single_action(trainer,
         call_kwargs = {}
         if what is trainer:
             call_kwargs["full_fetch"] = full_fetch
+            call_kwargs["policy_id"] = pid

         obs = obs_space.sample()
         if isinstance(obs_space, Box):
@@ -429,10 +434,10 @@ def check_compute_single_action(trainer,
             worker_set = getattr(trainer, "_workers", None)
             assert worker_set
             if isinstance(worker_set, list):
-                obs_space = trainer.get_policy().observation_space
+                obs_space = trainer.get_policy(pid).observation_space
             else:
                 obs_space = worker_set.local_worker().for_policy(
                     lambda p: p.observation_space)
-                    lambda p: p.observation_space)
+                    lambda p: p.observation_space, policy_id=pid)
             obs_space = getattr(obs_space, "original_space", obs_space)
         else:
             obs_space = pol.observation_space
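
A hedged sketch of the multi-agent-aware pattern used above, outside of the test utility (the `trainer` object is assumed to already exist):

    # Pick any policy ID from the local worker's policy map and pass it
    # explicitly when querying actions through the Trainer API.
    pid = next(iter(trainer.workers.local_worker().policy_map))
    obs = trainer.get_policy(pid).observation_space.sample()
    action = trainer.compute_single_action(obs, policy_id=pid, explore=False)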

View file

@@ -255,7 +255,8 @@ def get_tf_eager_cls_if_necessary(
             cls = orig_cls.as_eager()
             if config.get("eager_tracing"):
                 cls = cls.with_tracing()
-        # Could be some other type of policy.
+        # Could be some other type of policy or already
+        # eager-ized.
         elif not issubclass(orig_cls, TFPolicy):
             pass
         else:

View file

@@ -136,8 +136,8 @@ def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None):
                                   item.max_len)
         # Numpy arrays.
         if isinstance(item, np.ndarray):
-            # np.object_ type (e.g. info dicts in train batch): leave as-is.
-            if item.dtype == np.object_:
+            # Object type (e.g. info dicts in train batch): leave as-is.
+            if item.dtype == object:
                 return item
             # Non-writable numpy-arrays will cause PyTorch warning.
             elif item.flags.writeable is False:

View file

@@ -1,5 +1,6 @@
 import gym
-from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Tuple, Union, TypeVar, \
+    TYPE_CHECKING

 if TYPE_CHECKING:
     from ray.rllib.utils import try_import_tf, try_import_torch
@@ -123,3 +124,6 @@ TensorShape = Union[Tuple[int], List[int]]
 # A (possibly nested) space struct: Either a gym.spaces.Space or a
 # (possibly nested) dict|tuple of gym.space.Spaces.
 SpaceStruct = Union[gym.spaces.Space, dict, tuple]
+
+# Generic type var.
+T = TypeVar("T")
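
A hedged, self-contained sketch of how a shared generic type var like the new `T` ties the return type of the `apply`-style helpers to the passed-in function (the class below is illustrative, not RLlib code):

    from typing import Any, Callable, Optional, TypeVar

    T = TypeVar("T")  # plays the same role as ray.rllib.utils.typing.T

    class Applies:
        def apply(self, func: Callable[["Applies", Optional[Any]], T],
                  *args: Any, **kwargs: Any) -> T:
            return func(self, *args, **kwargs)

    # A type checker infers T = int here, so `n` is typed as int.
    n = Applies().apply(lambda obj, x: x + 1, 41)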