import gym
from gym.spaces import Box, Discrete
import logging
from typing import Tuple, Type

import ray
from ray.rllib.agents.a3c.a3c_torch_policy import vf_preds_fetches
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
    maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict

torch, nn = try_import_torch()
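# NOTE: `try_import_torch()` avoids a hard dependency on PyTorch at import
# time; it returns the `torch` and `torch.nn` modules when available (and
# `None` placeholders otherwise).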

logger = logging.getLogger(__name__)


def validate_spaces(policy: Policy, observation_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> None:
    """Validates the observation- and action spaces used for the Policy.

    Args:
        policy (Policy): The Policy whose spaces are being validated.
        observation_space (gym.spaces.Space): The observation space to
            validate.
        action_space (gym.spaces.Space): The action space to validate.
        config (TrainerConfigDict): The Policy's config dict.

    Raises:
        UnsupportedSpaceException: If one of the spaces is not supported.
    """
    # Only support single Box or single Discrete spaces.
    if not isinstance(action_space, (Box, Discrete)):
        raise UnsupportedSpaceException(
            "Action space ({}) of {} is not supported for "
            "MB-MPO. Must be [Box|Discrete].".format(action_space, policy))
    # If Box, make sure it's a 1D vector space.
    elif isinstance(action_space, Box) and len(action_space.shape) > 1:
        raise UnsupportedSpaceException(
            "Action space ({}) of {} has multiple dimensions "
            "{}. ".format(action_space, policy, action_space.shape) +
            "Consider reshaping this into a single dimension Box space "
            "or using the multi-agent API.")


def make_model_and_action_dist(
        policy: Policy,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict) -> \
        Tuple[ModelV2, Type[TorchDistributionWrapper]]:
    """Constructs the necessary ModelV2 and action dist class for the Policy.

    Args:
        policy (Policy): The Policy that will use the models.
        obs_space (gym.spaces.Space): The observation space.
        action_space (gym.spaces.Space): The action space.
        config (TrainerConfigDict): The Policy's config dict.

    Returns:
        Tuple[ModelV2, Type[TorchDistributionWrapper]]: The policy ("pi")
            ModelV2 to be used by the Policy and its action distribution
            class. Note: A dynamics model is also created in this function
            and assigned to `policy.dynamics_model`.
    """
    # Get the output distribution class for predicting rewards and next-obs.
    policy.distr_cls_next_obs, num_outputs = ModelCatalog.get_action_dist(
        obs_space, config, dist_type="deterministic", framework="torch")
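    # NOTE: `obs_space` is passed where an action space would normally go:
    # the dynamics model's outputs are next-observation predictions, so
    # `num_outputs` matches the observation dimensionality, and the
    # "deterministic" dist type yields point predictions rather than a
    # stochastic distribution.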

    # Build one dynamics model if we are a Worker.
    # If we are the main MAML learner, build n (num_workers) dynamics models
    # so that checkpoints of the current training state can be created.
    device = (torch.device("cuda")
              if torch.cuda.is_available() else torch.device("cpu"))
    policy.dynamics_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["dynamics_model"],
        framework="torch",
        name="dynamics_ensemble",
    ).to(device)
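    # NOTE (assumption, suggested by the model name above): the
    # `config["dynamics_model"]` settings typically describe an ensemble of
    # dynamics networks, each predicting the next observation from the
    # current (obs, action) pair.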

    action_dist, num_outputs = ModelCatalog.get_action_dist(
        action_space, config, framework="torch")
    # Create the pi-model and register it with the Policy.
    policy.pi = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        name="policy_model",
    )

    return policy.pi, action_dist


# Build a child class of `TorchPolicy`, given the custom functions defined
# above.
MBMPOTorchPolicy = build_policy_class(
    name="MBMPOTorchPolicy",
    framework="torch",
    get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
    make_model_and_action_dist=make_model_and_action_dist,
    loss_fn=maml_loss,
    stats_fn=maml_stats,
    optimizer_fn=maml_optimizer_fn,
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=compute_gae_for_sample_batch,
    extra_grad_process_fn=apply_grad_clipping,
    before_init=setup_config,
    after_init=setup_mixins,
    mixins=[KLCoeffMixin])
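
# Usage sketch (illustrative, assuming the MB-MPO Trainer is constructed the
# usual RLlib way; the env below is a hypothetical placeholder for an
# MB-MPO-compatible environment):
#
#   from ray.rllib.agents.mbmpo import MBMPOTrainer
#   trainer = MBMPOTrainer(
#       env=SomeMBMPOCompatibleEnv,
#       config={"framework": "torch"})
#   result = trainer.train()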