import logging
import numpy as np
from typing import List, Optional, Type

import ray
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.wrappers.model_vector_env import model_vector_env
from ray.rllib.evaluation.metrics import (
    collect_episodes,
    collect_metrics,
    get_learner_stats,
)
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.annotations import Deprecated, override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import EnvType, AlgorithmConfigDict
from ray.util.iter import from_actors, LocalIterator

logger = logging.getLogger(__name__)


class MBMPOConfig(AlgorithmConfig):
    """Defines a configuration class from which an MBMPO Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.mbmpo import MBMPOConfig
        >>> config = MBMPOConfig().training(lr=0.0003, train_batch_size=512)\
        ...     .resources(num_gpus=4)\
        ...     .rollouts(num_rollout_workers=64)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.mbmpo import MBMPOConfig
        >>> from ray import tune
        >>> config = MBMPOConfig()
        >>> # Print out some default values.
        >>> print(config.clip_param)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "MBMPO",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, algo_class=None):
        """Initializes a MBMPOConfig instance."""
        super().__init__(algo_class=algo_class or MBMPO)

        # fmt: off
        # __sphinx_doc_begin__

        # MBMPO specific config settings:
        # If true, use the Generalized Advantage Estimator (GAE)
        # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
        self.use_gae = True
        # GAE(lambda) parameter.
        self.lambda_ = 1.0
        # Initial coefficient for KL divergence.
        self.kl_coeff = 0.0005

        # Coefficient of the value function loss.
        self.vf_loss_coeff = 0.5
        # Coefficient of the entropy regularizer.
        self.entropy_coeff = 0.0
        # PPO clip parameter.
        self.clip_param = 0.5
        # Clip param for the value function. Note that this is sensitive to the
        # scale of the rewards. If your expected V is large, increase this.
        self.vf_clip_param = 10.0
        # If specified, clip the global norm of gradients by this amount.
        self.grad_clip = None
        # Target value for KL divergence.
        self.kl_target = 0.01
        # Number of inner adaptation steps for the MAML algorithm.
        self.inner_adaptation_steps = 1
        # Number of MAML steps per meta-update iteration (PPO steps).
        self.maml_optimizer_steps = 8
        # Inner adaptation step size.
        self.inner_lr = 1e-3
        # Horizon of the environment (200 in the MB-MPO paper).
        self.horizon = 200
        # Dynamics ensemble hyperparameters.
        self.dynamics_model = {
            "custom_model": DynamicsEnsembleCustomModel,
            # Number of Transition-Dynamics (TD) models in the ensemble.
            "ensemble_size": 5,
            # Hidden layers for each model in the TD-model ensemble.
            "fcnet_hiddens": [512, 512, 512],
            # Model learning rate.
            "lr": 1e-3,
            # Max number of training epochs per MBMPO iter.
            "train_epochs": 500,
            # Model batch size.
            "batch_size": 500,
            # Training/validation split.
            "valid_split_ratio": 0.2,
            # Normalize data (obs, action, and deltas).
            "normalize_data": True,
        }
        # Workers sample from dynamics models, not from actual envs.
        self.custom_vector_env = model_vector_env
        # How many iterations through MAML per MBMPO iteration.
        self.num_maml_steps = 10

        # Override some of AlgorithmConfig's default values with MBMPO-specific
        # values.
        self.batch_mode = "complete_episodes"
        # Size of batches collected from each worker.
        self.rollout_fragment_length = 200
        # Do create an actual env on the local worker (worker-idx=0).
        self.create_env_on_local_worker = True
        # Step size of SGD.
        self.lr = 1e-3
        # Exploration for MB-MPO is based on StochasticSampling, but uses 8000
        # random timesteps up-front for worker=0.
        self.exploration_config = {
            "type": MBMPOExploration,
            "random_timesteps": 8000,
        }

        # __sphinx_doc_end__
        # fmt: on

        self.vf_share_layers = DEPRECATED_VALUE
        self._disable_execution_plan_api = False
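        # Note (added for clarity): MB-MPO still runs through the old
        # `execution_plan` API defined in `MBMPO.execution_plan()` below,
        # which is why the newer training-iteration code path stays disabled
        # here.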

    @override(AlgorithmConfig)
    def training(
        self,
        *,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        kl_coeff: Optional[float] = None,
        vf_loss_coeff: Optional[float] = None,
        entropy_coeff: Optional[float] = None,
        clip_param: Optional[float] = None,
        vf_clip_param: Optional[float] = None,
        grad_clip: Optional[float] = None,
        kl_target: Optional[float] = None,
        inner_adaptation_steps: Optional[int] = None,
        maml_optimizer_steps: Optional[int] = None,
        inner_lr: Optional[float] = None,
        horizon: Optional[int] = None,
        dynamics_model: Optional[dict] = None,
        custom_vector_env: Optional[type] = None,
        num_maml_steps: Optional[int] = None,
        **kwargs,
    ) -> "MBMPOConfig":
        """Sets the training related configuration.

        Args:
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
            lambda_: The GAE (lambda) parameter.
            kl_coeff: Initial coefficient for KL divergence.
            vf_loss_coeff: Coefficient of the value function loss.
            entropy_coeff: Coefficient of the entropy regularizer.
            clip_param: PPO clip parameter.
            vf_clip_param: Clip param for the value function. Note that this is
                sensitive to the scale of the rewards. If your expected V is
                large, increase this.
            grad_clip: If specified, clip the global norm of gradients by this
                amount.
            kl_target: Target value for KL divergence.
            inner_adaptation_steps: Number of inner adaptation steps for the
                MAML algorithm.
            maml_optimizer_steps: Number of MAML steps per meta-update
                iteration (PPO steps).
            inner_lr: Inner adaptation step size.
            horizon: Horizon of the environment (200 in the MB-MPO paper).
            dynamics_model: Dynamics ensemble hyperparameters.
            custom_vector_env: Workers sample from dynamics models, not from
                actual envs.
            num_maml_steps: How many iterations through MAML per MBMPO
                iteration.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if kl_coeff is not None:
            self.kl_coeff = kl_coeff
        if vf_loss_coeff is not None:
            self.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not None:
            self.entropy_coeff = entropy_coeff
        if clip_param is not None:
            self.clip_param = clip_param
        if vf_clip_param is not None:
            self.vf_clip_param = vf_clip_param
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if kl_target is not None:
            self.kl_target = kl_target
        if inner_adaptation_steps is not None:
            self.inner_adaptation_steps = inner_adaptation_steps
        if maml_optimizer_steps is not None:
            self.maml_optimizer_steps = maml_optimizer_steps
        if inner_lr is not None:
            self.inner_lr = inner_lr
        if horizon is not None:
            self.horizon = horizon
        if dynamics_model is not None:
            self.dynamics_model = dynamics_model
        if custom_vector_env is not None:
            self.custom_vector_env = custom_vector_env
        if num_maml_steps is not None:
            self.num_maml_steps = num_maml_steps

        return self
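
    # A usage sketch (added for illustration; not part of the original API
    # docs): since `training(dynamics_model=...)` replaces the whole ensemble
    # dict set in `__init__`, pass a complete dict when overriding it. All
    # values below are examples only.
    #
    #   config = (
    #       MBMPOConfig()
    #       .training(
    #           inner_lr=1e-3,
    #           num_maml_steps=5,
    #           dynamics_model={
    #               "custom_model": DynamicsEnsembleCustomModel,
    #               "ensemble_size": 3,
    #               "fcnet_hiddens": [256, 256],
    #               "lr": 5e-4,
    #               "train_epochs": 300,
    #               "batch_size": 256,
    #               "valid_split_ratio": 0.2,
    #               "normalize_data": True,
    #           },
    #       )
    #       .rollouts(num_rollout_workers=4)
    #   )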


# Select Metric Keys for MAML Stats Tracing
METRICS_KEYS = ["episode_reward_mean", "episode_reward_min", "episode_reward_max"]


class MetaUpdate:
    def __init__(self, workers, num_steps, maml_steps, metric_gen):
        """Computes the MetaUpdate step in MAML.

        Adapted for MBMPO for multiple MAML iterations.

        Args:
            workers: Set of Workers.
            num_steps: Number of meta-update steps per MAML iteration.
            maml_steps: MAML iterations per MBMPO iteration.
            metric_gen: Generates the metrics dictionary.
        """
        self.workers = workers
        self.num_steps = num_steps
        self.step_counter = 0
        self.maml_optimizer_steps = maml_steps
        self.metric_gen = metric_gen
        self.metrics = {}

    def __call__(self, data_tuple):
        """Performs the meta-update step for one MAML iteration.

        Args:
            data_tuple: Tuple whose 1st element is the samples collected from
                the MAML inner adaptation steps and whose 2nd element is the
                accumulated metrics.

        Returns:
            A one-element list with the MBMPO metrics dict for logging once all
            MAML iterations of the current MBMPO iteration are done, otherwise
            an empty list.
        """
        # Meta-update step.
        print("Meta-Update Step")
        samples = data_tuple[0]
        adapt_metrics_dict = data_tuple[1]
        self.postprocess_metrics(
            adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter)
        )

        # MAML Meta-update.
        fetches = None
        for i in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        learner_stats = get_learner_stats(fetches)

        # Update KLs.
        def update(pi, pi_id):
            assert "inner_kl" not in learner_stats, (
                "inner_kl should be nested under policy id key",
                learner_stats,
            )
            if pi_id in learner_stats:
                assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id)
                pi.update_kls(learner_stats[pi_id]["inner_kl"])
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

        self.workers.local_worker().foreach_policy_to_train(update)

        # Modify Reporting Metrics.
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

        if self.step_counter == self.num_steps - 1:
            td_metric = self.workers.local_worker().foreach_policy(fit_dynamics)[0]

            # Sync workers with meta policy.
            self.workers.sync_weights()

            # Sync TD Models with workers.
            sync_ensemble(self.workers)
            sync_stats(self.workers)

            metrics.counters[STEPS_SAMPLED_COUNTER] = td_metric[STEPS_SAMPLED_COUNTER]

            # Modify to CollectMetrics.
            res = self.metric_gen.__call__(None)
            res.update(self.metrics)
            self.step_counter = 0
            print("MB-MPO Iteration Completed")
            return [res]
        else:
            print("MAML Iteration {} Completed".format(self.step_counter))
            self.step_counter += 1

            # Sync workers with meta policy
            print("Syncing Weights with Workers")
            self.workers.sync_weights()
            return []

    def postprocess_metrics(self, metrics, prefix=""):
        """Stores the given metrics in `self.metrics`, prefixing each key.

        Args:
            metrics: Dictionary of current metrics.
            prefix: Prefix string to prepend to each metric key.
        """
        for key in metrics.keys():
            self.metrics[prefix + "_" + key] = metrics[key]


def post_process_metrics(prefix, workers, metrics):
    """Updates the current dataset metrics and filters out specific keys.

    Args:
        prefix: Prefix string to prepend to each metric key.
        workers: Set of workers.
        metrics: Current metrics dictionary.
    """
    res = collect_metrics(remote_workers=workers.remote_workers())
    for key in METRICS_KEYS:
        metrics[prefix + "_" + key] = res[key]
    return metrics


def inner_adaptation(workers: WorkerSet, samples: List[SampleBatch]):
    """Performs one gradient descent step on each remote worker.

    Args:
        workers: The WorkerSet of the Algorithm.
        samples: The list of SampleBatches to perform a training step on (one
            for each remote worker).
    """
    for i, e in enumerate(workers.remote_workers()):
        e.learn_on_batch.remote(samples[i])
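    # Note (added for clarity): these `learn_on_batch.remote()` calls are not
    # awaited here (fire-and-forget); Ray actors process tasks in submission
    # order, so each worker applies this update before serving its next
    # sampling request.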


def fit_dynamics(policy, pid):
    """Fits the policy's dynamics-model (TD) ensemble and returns its fitting metrics."""
    return policy.dynamics_model.fit()


def sync_ensemble(workers: WorkerSet) -> None:
    """Syncs dynamics ensemble weights from driver (main) to workers.

    Args:
        workers: Set of workers, including driver (main).
    """

    def get_ensemble_weights(worker):
        policy_map = worker.policy_map
        policies = policy_map.keys()

        def policy_ensemble_weights(policy):
            model = policy.dynamics_model
            return {
                k: v.cpu().detach().numpy() for k, v in model.state_dict().items()
            }

        return {
            pid: policy_ensemble_weights(policy)
            for pid, policy in policy_map.items()
            if pid in policies
        }

    def set_ensemble_weights(policy, pid, weights):
        weights = weights[pid]
        weights = convert_to_torch_tensor(weights, device=policy.device)
        model = policy.dynamics_model
        model.load_state_dict(weights)

    if workers.remote_workers():
        weights = ray.put(get_ensemble_weights(workers.local_worker()))
        set_func = ray.put(set_ensemble_weights)
        for e in workers.remote_workers():
            e.foreach_policy.remote(set_func, weights=weights)


def sync_stats(workers: WorkerSet) -> None:
    """Syncs the dynamics models' normalization stats from driver to workers."""

    def get_normalizations(worker):
        policy = worker.policy_map[DEFAULT_POLICY_ID]
        return policy.dynamics_model.normalizations

    def set_normalizations(policy, pid, normalizations):
        policy.dynamics_model.set_norms(normalizations)

    if workers.remote_workers():
        normalization_dict = ray.put(get_normalizations(workers.local_worker()))
        set_func = ray.put(set_normalizations)
        for e in workers.remote_workers():
            e.foreach_policy.remote(set_func, normalizations=normalization_dict)


def post_process_samples(samples, config: AlgorithmConfigDict):
    # Instead of using a NN for the value function, we use a regression-based
    # baseline to compute GAE advantages (see `calculate_gae_advantages()`).
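    # For reference (standard GAE(lambda), stated here for readability), per
    # completed trajectory:
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t = sum_{l >= 0} (gamma * lambda)^l * delta_{t+l}
    # with V the fitted baseline; `config["gamma"]` and `config["lambda"]`
    # below supply gamma and lambda.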
    split_lst = []
    for sample in samples:
        indexes = np.asarray(sample["dones"]).nonzero()[0]
        indexes = indexes + 1

        reward_list = np.split(sample["rewards"], indexes)[:-1]
        observation_list = np.split(sample["obs"], indexes)[:-1]

        paths = []
        for i in range(0, len(reward_list)):
            paths.append(
                {"rewards": reward_list[i], "observations": observation_list[i]}
            )

        paths = calculate_gae_advantages(paths, config["gamma"], config["lambda"])

        advantages = np.concatenate([path["advantages"] for path in paths])
        sample["advantages"] = standardized(advantages)
        split_lst.append(sample.count)
    return samples, split_lst


class MBMPO(Algorithm):
    """Model-Based Meta Policy Optimization (MB-MPO) Algorithm.

    This class implements the distributed Algorithm for model-based meta
    policy optimization.
    See `mbmpo_torch_policy.py` for the definition of the policy loss.

    Detailed documentation:
    https://docs.ray.io/en/master/rllib-algorithms.html#mbmpo
    """

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return DEFAULT_CONFIG

    @override(Algorithm)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for MB-MPO!")
        if config["framework"] != "torch":
            logger.warning(
                "MB-MPO only supported in PyTorch so far! Switching to "
                "`framework=torch`."
            )
            config["framework"] = "torch"
        if config["inner_adaptation_steps"] <= 0:
            raise ValueError("Inner adaptation steps must be >=1!")
        if config["maml_optimizer_steps"] <= 0:
            raise ValueError(
                "PPO steps for the meta-update (`maml_optimizer_steps`) must "
                "be >=1!"
            )
        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >=0.0!")
        if config["batch_mode"] != "complete_episodes":
            raise ValueError("`batch_mode=truncate_episodes` not supported!")
        if config["num_workers"] <= 0:
            raise ValueError("Must have at least 1 worker/task.")
        if config["create_env_on_driver"] is False:
            raise ValueError(
                "Must have an actual Env created on the driver "
                "(local) worker! Set `create_env_on_driver` to True."
            )

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy

        return MBMPOTorchPolicy

    @staticmethod
    @override(Algorithm)
    def execution_plan(
        workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "MBMPO execution_plan does NOT take any additional parameters"

        # Train TD Models on the driver.
        workers.local_worker().foreach_policy(fit_dynamics)

        # Sync driver's policy with workers.
        workers.sync_weights()

        # Sync TD Models and normalization stats with workers.
        sync_ensemble(workers)
        sync_stats(workers)

        # Dropping metrics from the first iteration.
        _, _ = collect_episodes(
            workers.local_worker(), workers.remote_workers(), [], timeout_seconds=9999
        )

        # Metrics Collector.
        metric_collect = CollectMetrics(
            workers,
            min_history=0,
            timeout_seconds=config["metrics_episode_collection_timeout_s"],
        )

        num_inner_steps = config["inner_adaptation_steps"]

        def inner_adaptation_steps(itr):
            buf = []
            split = []
            metrics = {}
            for samples in itr:
                print("Collecting Samples, Inner Adaptation {}".format(len(split)))
                # Processing Samples (Standardize Advantages).
                samples, split_lst = post_process_samples(samples, config)

                buf.extend(samples)
                split.append(split_lst)

                adapt_iter = len(split) - 1
                prefix = "DynaTrajInner_" + str(adapt_iter)
                metrics = post_process_metrics(prefix, workers, metrics)

                if len(split) > num_inner_steps:
                    out = SampleBatch.concat_samples(buf)
                    out["split"] = np.array(split)
                    buf = []
                    split = []

                    yield out, metrics
                    metrics = {}
                else:
                    inner_adaptation(workers, samples)

        # Iterator for Inner Adaptation Data gathering (from pre->post
        # adaptation).
        rollouts = from_actors(workers.remote_workers())
        rollouts = rollouts.batch_across_shards()
        rollouts = rollouts.transform(inner_adaptation_steps)

        # Meta update step with outer combine loop for multiple MAML
        # iterations.
        train_op = rollouts.combine(
            MetaUpdate(
                workers,
                config["num_maml_steps"],
                config["maml_optimizer_steps"],
                metric_collect,
            )
        )
        return train_op
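
    # Summary of the execution plan above (added for readability, no change in
    # behavior): each remote worker samples trajectories from its learned
    # dynamics models (via `custom_vector_env=model_vector_env`); after every
    # batch, `inner_adaptation()` adapts the workers' policies in place; once
    # `inner_adaptation_steps` rounds have accumulated, the concatenated batch
    # is handed to `MetaUpdate`, which runs `maml_optimizer_steps` PPO-style
    # meta-updates on the driver and, every `num_maml_steps` MAML iterations,
    # refits the TD-model ensemble and re-syncs weights and stats to workers.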

    @staticmethod
    @override(Algorithm)
    def validate_env(env: EnvType, env_context: EnvContext) -> None:
        """Validates the local_worker's env object (after creation).

        Args:
            env: The env object to check (for worker=0 only).
            env_context: The env context used for the instantiation of
                the local worker's env (worker=0).

        Raises:
            ValueError: In case something is wrong with the config.
        """
        if not hasattr(env, "reward") or not callable(env.reward):
            raise ValueError(
                f"Env {env} does not have a `reward()` method, needed for "
                "MB-MPO! This `reward()` method should return rewards for a "
                "batch of transitions, given batches of observations, actions, "
                "and next observations."
            )
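
    # A sketch (added for illustration; not part of the original module) of an
    # env exposing the batch-wise `reward()` method checked above. The wrapper
    # name and the Pendulum-style reward are assumptions, only meant to show
    # the expected shape of the API:
    #
    #   import gym
    #   import numpy as np
    #
    #   class PendulumRewardWrapper(gym.Wrapper):
    #       def reward(self, obs, action, obs_next):
    #           # Recompute rewards for whole batches of transitions so that
    #           # rollouts imagined by the learned dynamics models can be
    #           # labeled without stepping the real env.
    #           theta = np.arctan2(obs[:, 1], obs[:, 0])
    #           theta_dot = obs[:, 2]
    #           u = np.clip(np.asarray(action).reshape(-1), -2.0, 2.0)
    #           return -(theta ** 2 + 0.1 * theta_dot ** 2 + 0.001 * u ** 2)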


# Deprecated: Use ray.rllib.algorithms.mbmpo.MBMPOConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(MBMPOConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.mbmpo.mbmpo.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.mbmpo.mbmpo.MBMPOConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()