2022-01-05 11:29:44 +01:00
|
|
|
import gym
|
2022-01-29 18:41:57 -08:00
|
|
|
from typing import (
|
|
|
|
Any,
|
|
|
|
Callable,
|
|
|
|
Dict,
|
|
|
|
List,
|
|
|
|
Optional,
|
|
|
|
Tuple,
|
|
|
|
Union,
|
|
|
|
TypeVar,
|
|
|
|
TYPE_CHECKING,
|
|
|
|
)
|
2020-06-19 13:09:05 -07:00
|
|
|
|
2021-05-03 14:23:28 -07:00
|
|
|
if TYPE_CHECKING:
    # Everything in this branch is seen only by static type checkers:
    # TYPE_CHECKING is False at runtime, so none of these modules (nor tf or
    # torch) are imported when this module is loaded. The aliases below refer
    # to these names only through string forward references such as
    # "EnvContext", "PolicySpec" or "tf.keras.optimizers.Optimizer".
    from ray.rllib.env.env_context import EnvContext
    from ray.rllib.policy.policy import PolicySpec
    from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
    from ray.rllib.policy.view_requirement import ViewRequirement
    from ray.rllib.utils import try_import_tf, try_import_torch

    # Bind the framework module names used inside the forward references.
    _, tf, _ = try_import_tf()
    torch, _ = try_import_torch()
|
|
|
|
|
|
|
|
# Represents a generic tensor type.
# This could be an np.ndarray, tf.Tensor, or a torch.Tensor.
TensorType = Any

# Either a plain tensor, or a (possibly nested) dict or tuple of tensors.
# Note: since TensorType is Any, static type checkers accept any value for
# this union; the dict/tuple members mainly document intent.
TensorStructType = Union[TensorType, dict, tuple]

# A shape of a tensor, given as a tuple or list of ints, e.g. (84, 84, 3).
# `Tuple[int, ...]` is the variable-length form; a bare `Tuple[int]` would only
# describe 1-element tuples and reject any multi-dimensional tuple shape.
TensorShape = Union[Tuple[int, ...], List[int]]
|
2021-05-03 14:23:28 -07:00
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
# A complete (fully populated) Trainer config.
# Policies normally receive this same dict as their config, although in e.g.
# multi-agent setups (more than one Policy inside one Trainer) individual
# entries may be overridden per policy.
TrainerConfigDict = dict

# A Trainer config holding only user overrides; it must be merged with the
# default Trainer config before it can be used.
PartialTrainerConfigDict = dict

# The "model" sub-dict of a Trainer config, which is handed to the model
# catalog.
ModelConfigDict = dict

# Anything the `from_config()` utility can build an object from: a config dict
# containing a "type" key, a class path string, or the type itself.
FromConfigSpec = Union[Dict[str, Any], type, str]
|
|
|
|
|
2022-01-25 14:16:58 +01:00
|
|
|
# The "env_config" sub-dict of a Trainer config; it is passed on to the env
# constructor.
EnvConfigDict = dict

# Identifier of an environment, which is one of:
# - an int: index of a sub-env inside a vectorized env, or
# - a str: the ID of an external env, which changes(!) every episode.
EnvID = Union[int, str]

# Any env object: a BaseEnv, MultiAgentEnv, ExternalEnv,
# ExternalMultiAgentEnv, VectorEnv, gym.Env, or an ActorHandle.
EnvType = Any

# A callable that takes an EnvContext (the env config dict plus the properties
# `worker_index`, `vector_index`, `num_workers`, and `remote`) and returns an
# env object, or None if no env is used.
EnvCreator = Callable[["EnvContext"], Optional[EnvType]]
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
# Identifier of a single agent, e.g. "agent1". Any hashable value may be used.
AgentID = Any

# Identifier of a single policy, e.g. "pol1".
PolicyID = str

# Shape of config["multiagent"]["policies"] in multi-agent training: maps each
# policy ID to its PolicySpec.
MultiAgentPolicyConfigDict = Dict[PolicyID, "PolicySpec"]
|
2020-06-19 13:09:05 -07:00
|
|
|
|
2022-01-25 14:16:58 +01:00
|
|
|
# A Policy's state: maps string keys (for example "weights") to the
# corresponding state data, each entry being a TensorStructType.
PolicyState = Dict[str, TensorStructType]
|
2020-06-19 13:09:05 -07:00
|
|
|
|
2020-07-29 21:15:09 +02:00
|
|
|
# Integer identifier of an episode.
EpisodeID = int

# Integer identifier of an "unroll", which may span several sub-envs of a
# vectorized env.
UnrollID = int
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
# Per-agent mapping, keyed by agent ID: {"agent-1": value}.
MultiAgentDict = Dict[AgentID, Any]

# Two-level mapping keyed by env ID, whose values are per-agent dicts keyed by
# agent ID: {"env-1": {"agent-1": value}}.
MultiEnvDict = Dict[EnvID, MultiAgentDict]
|
|
|
|
|
|
|
|
# An observation, as returned by the env.
EnvObsType = Any

# An action, as passed to the env.
EnvActionType = Any

# The info dict that step() returns on gym envs; usually just an empty dict.
EnvInfoDict = dict

# A file object.
FileType = Any
|
|
|
|
|
2020-12-07 13:08:17 +01:00
|
|
|
# Mapping of column names (str) to their ViewRequirement objects.
ViewRequirementsDict = Dict[str, "ViewRequirement"]

# The result dict produced by Trainer.train().
ResultDict = dict
|
|
|
|
|
2020-08-19 17:49:50 +02:00
|
|
|
# A local optimizer object, from either tf (keras) or torch.
LocalOptimizer = Union["tf.keras.optimizers.Optimizer", "torch.optim.Optimizer"]

# Dict of tensors that a policy's gradient computation returns, for example
# {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}}. In the
# multi-agent case it is keyed by policy:
# {"policy1": {"learner_stats": ..., }, "policy2": ...}.
GradInfoDict = dict

# Learner statistics from a policy's gradient computation, for example
# {"vf_loss": ..., ...}. Always found under the "learner_stats" key(s) of a
# GradInfoDict; keyed by policy ID in the multi-agent case.
LearnerStatsDict = dict
|
|
|
|
|
2020-07-05 13:09:51 +02:00
|
|
|
# Model gradients as returned by compute_gradients(): a list of
# (gradient, variable) tuples under tf, or a plain list of gradient tensors
# under torch.
ModelGradients = Union[List[Tuple[TensorType, TensorType]], List[TensorType]]

# The dict that get_weights() returns, holding a model's weights.
ModelWeights = dict

# Input dict for calling a ModelV2 directly.
ModelInputDict = Dict[str, TensorType]
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
# Either kind of sample batch: a single-agent SampleBatch or a
# MultiAgentBatch.
SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
|
|
|
|
|
2022-01-05 11:29:44 +01:00
|
|
|
# A space struct, which is either a single gym.spaces.Space or a (possibly
# nested) dict or tuple of gym.spaces.Space objects.
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
|
2022-01-10 11:22:55 +01:00
|
|
|
|
|
|
|
# A general-purpose type variable for generic annotations.
T = TypeVar("T")
|