ray/rllib/algorithms/mbmpo/mbmpo.py


import logging
import numpy as np
from typing import List, Type
import ray
from ray.rllib.agents import with_common_config
from ray.rllib.algorithms.mbmpo.model_ensemble import DynamicsEnsembleCustomModel
from ray.rllib.algorithms.mbmpo.utils import calculate_gae_advantages, MBMPOExploration
from ray.rllib.agents.trainer import Trainer
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.wrappers.model_vector_env import model_vector_env
from ray.rllib.evaluation.metrics import (
collect_episodes,
collect_metrics,
get_learner_stats,
)
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_SAMPLED_COUNTER,
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_shared_metrics,
)
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import EnvType, TrainerConfigDict
from ray.util.iter import from_actors, LocalIterator
logger = logging.getLogger(__name__)
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
# rllib/agents/trainer.py (`COMMON_CONFIG` dict).
DEFAULT_CONFIG = with_common_config({
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# GAE(lambda) parameter.
"lambda": 1.0,
# Initial coefficient for KL divergence.
"kl_coeff": 0.0005,
# Size of batches collected from each worker.
"rollout_fragment_length": 200,
# Do create an actual env on the local worker (worker-idx=0).
"create_env_on_driver": True,
# Step size of SGD.
"lr": 1e-3,
# Coefficient of the value function loss.
"vf_loss_coeff": 0.5,
# Coefficient of the entropy regularizer.
"entropy_coeff": 0.0,
# PPO clip parameter.
"clip_param": 0.5,
# Clip param for the value function. Note that this is sensitive to the
# scale of the rewards. If your expected V is large, increase this.
"vf_clip_param": 10.0,
# If specified, clip the global norm of gradients by this amount.
"grad_clip": None,
# Target value for KL divergence.
"kl_target": 0.01,
# Whether to roll out "complete_episodes" or "truncate_episodes".
"batch_mode": "complete_episodes",
# Which observation filter to apply to the observation.
"observation_filter": "NoFilter",
# Number of Inner adaptation steps for the MAML algorithm.
"inner_adaptation_steps": 1,
# Number of MAML steps per meta-update iteration (PPO steps).
"maml_optimizer_steps": 8,
# Inner adaptation step size.
"inner_lr": 1e-3,
# Horizon of the environment (200 in MB-MPO paper).
"horizon": 200,
# Dynamics ensemble hyperparameters.
"dynamics_model": {
"custom_model": DynamicsEnsembleCustomModel,
# Number of Transition-Dynamics (TD) models in the ensemble.
"ensemble_size": 5,
# Hidden layers for each model in the TD-model ensemble.
"fcnet_hiddens": [512, 512, 512],
# Model learning rate.
"lr": 1e-3,
# Max number of training epochs per MBMPO iter.
"train_epochs": 500,
# Model batch size.
"batch_size": 500,
# Training/validation split.
"valid_split_ratio": 0.2,
# Normalize data (obs, action, and deltas).
"normalize_data": True,
},
# Exploration for MB-MPO is based on StochasticSampling, but uses 8000
# random timesteps up-front for worker=0.
"exploration_config": {
"type": MBMPOExploration,
"random_timesteps": 8000,
},
# Workers sample from dynamics models, not from actual envs.
"custom_vector_env": model_vector_env,
# How many iterations through MAML per MBMPO iteration.
"num_maml_steps": 10,
# Deprecated keys:
# Share layers for value function. If you set this to True, it's important
# to tune vf_loss_coeff.
# Use config.model.vf_share_layers instead.
"vf_share_layers": DEPRECATED_VALUE,
# Use `execution_plan` instead of `training_iteration`.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
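# Illustrative sketch (not part of the original file): how the default config
# above might be customized before constructing the Trainer. The override values
# are arbitrary examples, not tuned settings. Note that `.copy()` is shallow, so
# only top-level keys are overridden here.
def _example_mbmpo_config() -> TrainerConfigDict:
    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 4              # MB-MPO needs at least 1 remote worker.
    config["inner_adaptation_steps"] = 1   # Inner-adaptation steps per MAML iter.
    config["horizon"] = 200                # Episode horizon used in the MB-MPO paper.
    return config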
# Select Metric Keys for MAML Stats Tracing
METRICS_KEYS = ["episode_reward_mean", "episode_reward_min", "episode_reward_max"]
class MetaUpdate:
def __init__(self, workers, num_steps, maml_steps, metric_gen):
"""Computes the MetaUpdate step in MAML.
Adapted for MB-MPO to run multiple MAML iterations per MB-MPO iteration.
Args:
workers (WorkerSet): Set of Workers
num_steps (int): Number of meta-update steps per MAML Iteration
maml_steps (int): MAML Iterations per MBMPO Iteration
metric_gen (Iterator): Generates metrics dictionary
Returns:
metrics (dict): MBMPO metrics for logging.
"""
self.workers = workers
self.num_steps = num_steps
self.step_counter = 0
self.maml_optimizer_steps = maml_steps
self.metric_gen = metric_gen
self.metrics = {}
def __call__(self, data_tuple):
"""Args:
data_tuple (tuple): 1st element is samples collected from MAML
Inner adaptation steps and 2nd element is accumulated metrics
"""
# Metaupdate Step.
print("Meta-Update Step")
samples = data_tuple[0]
adapt_metrics_dict = data_tuple[1]
self.postprocess_metrics(
adapt_metrics_dict, prefix="MAMLIter{}".format(self.step_counter)
)
# MAML Meta-update.
fetches = None
for i in range(self.maml_optimizer_steps):
fetches = self.workers.local_worker().learn_on_batch(samples)
learner_stats = get_learner_stats(fetches)
# Update KLs.
def update(pi, pi_id):
assert "inner_kl" not in learner_stats, (
"inner_kl should be nested under policy id key",
learner_stats,
)
if pi_id in learner_stats:
assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id)
pi.update_kls(learner_stats[pi_id]["inner_kl"])
else:
logger.warning("No data for {}, not updating kl".format(pi_id))
self.workers.local_worker().foreach_policy_to_train(update)
# Modify Reporting Metrics.
metrics = _get_shared_metrics()
metrics.info[LEARNER_INFO] = fetches
metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
metrics.counters[STEPS_TRAINED_COUNTER] += samples.count
if self.step_counter == self.num_steps - 1:
td_metric = self.workers.local_worker().foreach_policy(fit_dynamics)[0]
# Sync workers with meta policy.
self.workers.sync_weights()
# Sync TD Models with workers.
sync_ensemble(self.workers)
sync_stats(self.workers)
metrics.counters[STEPS_SAMPLED_COUNTER] = td_metric[STEPS_SAMPLED_COUNTER]
# Modify to CollectMetrics.
res = self.metric_gen.__call__(None)
res.update(self.metrics)
self.step_counter = 0
print("MB-MPO Iteration Completed")
return [res]
else:
print("MAML Iteration {} Completed".format(self.step_counter))
self.step_counter += 1
# Sync workers with meta policy
print("Syncing Weights with Workers")
self.workers.sync_weights()
return []
def postprocess_metrics(self, metrics, prefix=""):
"""Appends prefix to current metrics
Args:
metrics (dict): Dictionary of current metrics
prefix (str): Prefix string to be appended
"""
for key in metrics.keys():
self.metrics[prefix + "_" + key] = metrics[key]
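# In brief, MetaUpdate.__call__ above: each call runs `maml_optimizer_steps`
# learn_on_batch() passes over one round of inner-adaptation data and updates the
# per-policy inner KL coefficients; every `num_steps` calls it additionally refits
# the TD-model ensemble (fit_dynamics), re-syncs policy weights, ensemble weights,
# and normalization stats to the workers, and emits a consolidated metrics dict.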
def post_process_metrics(prefix, workers, metrics):
"""Update current dataset metrics and filter out specific keys.
Args:
prefix (str): Prefix string to be appended
workers (WorkerSet): Set of workers
metrics (dict): Current metrics dictionary
"""
res = collect_metrics(remote_workers=workers.remote_workers())
for key in METRICS_KEYS:
metrics[prefix + "_" + key] = res[key]
return metrics
def inner_adaptation(workers: WorkerSet, samples: List[SampleBatch]):
"""Performs one gradient descend step on each remote worker.
Args:
workers (WorkerSet): The WorkerSet of the Trainer.
samples (List[SampleBatch]): The list of SampleBatches to perform
a training step on (one for each remote worker).
"""
for i, e in enumerate(workers.remote_workers()):
e.learn_on_batch.remote(samples[i])
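# Note on inner_adaptation() above: the ObjectRefs returned by
# learn_on_batch.remote() are intentionally not fetched; each remote worker
# applies the gradient step to its own policy copy, which is then used for the
# next round of sampling.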
def fit_dynamics(policy, pid):
return policy.dynamics_model.fit()
def sync_ensemble(workers: WorkerSet) -> None:
"""Syncs dynamics ensemble weights from driver (main) to workers.
Args:
workers (WorkerSet): Set of workers, including driver (main).
"""
def get_ensemble_weights(worker):
policy_map = worker.policy_map
policies = policy_map.keys()
def policy_ensemble_weights(policy):
model = policy.dynamics_model
return {k: v.cpu().detach().numpy() for k, v in model.state_dict().items()}
return {
pid: policy_ensemble_weights(policy)
for pid, policy in policy_map.items()
if pid in policies
}
def set_ensemble_weights(policy, pid, weights):
weights = weights[pid]
weights = convert_to_torch_tensor(weights, device=policy.device)
model = policy.dynamics_model
model.load_state_dict(weights)
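# Broadcast pattern used below: ray.put() stores the ensemble weights (and the
# setter function) once in the object store; Ray resolves the resulting
# ObjectRefs back into the actual objects when they are passed as arguments to
# foreach_policy.remote(), so the payload is serialized once and fetched by the
# workers rather than re-serialized per call.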
if workers.remote_workers():
weights = ray.put(get_ensemble_weights(workers.local_worker()))
set_func = ray.put(set_ensemble_weights)
for e in workers.remote_workers():
e.foreach_policy.remote(set_func, weights=weights)
def sync_stats(workers: WorkerSet) -> None:
def get_normalizations(worker):
policy = worker.policy_map[DEFAULT_POLICY_ID]
return policy.dynamics_model.normalizations
def set_normalizations(policy, pid, normalizations):
policy.dynamics_model.set_norms(normalizations)
if workers.remote_workers():
normalization_dict = ray.put(get_normalizations(workers.local_worker()))
set_func = ray.put(set_normalizations)
for e in workers.remote_workers():
e.foreach_policy.remote(set_func, normalizations=normalization_dict)
def post_process_samples(samples, config: TrainerConfigDict):
# Instead of a learned NN value function, fit a baseline by regression and
# compute GAE advantages from it.
split_lst = []
for sample in samples:
indexes = np.asarray(sample["dones"]).nonzero()[0]
indexes = indexes + 1
reward_list = np.split(sample["rewards"], indexes)[:-1]
observation_list = np.split(sample["obs"], indexes)[:-1]
paths = []
for i in range(0, len(reward_list)):
paths.append(
{"rewards": reward_list[i], "observations": observation_list[i]}
)
paths = calculate_gae_advantages(paths, config["gamma"], config["lambda"])
advantages = np.concatenate([path["advantages"] for path in paths])
sample["advantages"] = standardized(advantages)
split_lst.append(sample.count)
return samples, split_lst
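# Worked example of the episode-splitting logic in post_process_samples() above
# (illustrative): for dones = [False, False, True, False, True],
# np.asarray(dones).nonzero()[0] + 1 gives [3, 5], so np.split() yields the two
# complete episodes [0:3] and [3:5] plus a trailing empty chunk that [:-1]
# discards; GAE advantages are then computed per episode and standardized over
# the whole batch.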
class MBMPOTrainer(Trainer):
"""Model-Based Meta Policy Optimization (MB-MPO) Trainer.
This file defines the distributed Trainer class for model-based meta
policy optimization.
See `mbmpo_[tf|torch]_policy.py` for the definition of the policy loss.
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#mbmpo
"""
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
@override(Trainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MB-MPO!")
if config["framework"] != "torch":
logger.warning(
"MB-MPO only supported in PyTorch so far! Switching to "
"`framework=torch`."
)
config["framework"] = "torch"
if config["inner_adaptation_steps"] <= 0:
raise ValueError("Inner adaptation steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
raise ValueError("`batch_mode=truncate_episodes` not supported!")
if config["num_workers"] <= 0:
raise ValueError("Must have at least 1 worker/task.")
if config["create_env_on_driver"] is False:
raise ValueError(
"Must have an actual Env created on the driver "
"(local) worker! Set `create_env_on_driver` to True."
)
@override(Trainer)
def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
from ray.rllib.algorithms.mbmpo.mbmpo_torch_policy import MBMPOTorchPolicy
return MBMPOTorchPolicy
@staticmethod
@override(Trainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "MBMPO execution_plan does NOT take any additional parameters"
# Train TD Models on the driver.
workers.local_worker().foreach_policy(fit_dynamics)
# Sync driver's policy with workers.
workers.sync_weights()
# Sync TD Models and normalization stats with workers
sync_ensemble(workers)
sync_stats(workers)
# Dropping metrics from the first iteration
_, _ = collect_episodes(
workers.local_worker(), workers.remote_workers(), [], timeout_seconds=9999
)
# Metrics Collector.
metric_collect = CollectMetrics(
workers,
min_history=0,
timeout_seconds=config["metrics_episode_collection_timeout_s"],
)
num_inner_steps = config["inner_adaptation_steps"]
def inner_adaptation_steps(itr):
buf = []
split = []
metrics = {}
for samples in itr:
print("Collecting Samples, Inner Adaptation {}".format(len(split)))
# Processing Samples (Standardize Advantages)
samples, split_lst = post_process_samples(samples, config)
buf.extend(samples)
split.append(split_lst)
adapt_iter = len(split) - 1
prefix = "DynaTrajInner_" + str(adapt_iter)
metrics = post_process_metrics(prefix, workers, metrics)
if len(split) > num_inner_steps:
out = SampleBatch.concat_samples(buf)
out["split"] = np.array(split)
buf = []
split = []
yield out, metrics
metrics = {}
else:
inner_adaptation(workers, samples)
# Iterator for Inner Adaptation Data gathering (from pre->post
# adaptation).
rollouts = from_actors(workers.remote_workers())
rollouts = rollouts.batch_across_shards()
rollouts = rollouts.transform(inner_adaptation_steps)
# Meta update step with outer combine loop for multiple MAML
# iterations.
train_op = rollouts.combine(
MetaUpdate(
workers,
config["num_maml_steps"],
config["maml_optimizer_steps"],
metric_collect,
)
)
return train_op
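# In short, the plan above: the dynamics ensemble is fit on the driver and
# broadcast to the workers, which then roll out inside the learned models
# (see `custom_vector_env`); inner_adaptation_steps() interleaves sampling with
# per-worker adaptation steps, and MetaUpdate consumes each pre-/post-adaptation
# bundle to perform the MAML meta-gradient step, `num_maml_steps` times per
# MB-MPO iteration.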
@staticmethod
@override(Trainer)
def validate_env(env: EnvType, env_context: EnvContext) -> None:
"""Validates the local_worker's env object (after creation).
Args:
env: The env object to check (for worker=0 only).
env_context: The env context used for the instantiation of
the local worker's env (worker=0).
Raises:
ValueError: In case something is wrong with the config.
"""
if not hasattr(env, "reward") or not callable(env.reward):
raise ValueError(
f"Env {env} doest not have a `reward()` method, needed for "
"MB-MPO! This `reward()` method should return "
)
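# Illustrative sketch (not part of the original file): the kind of env
# validate_env() above expects. The model-based vector env scores imagined
# transitions with the real env's reward function; the batched
# (obs, action, next_obs) signature and the wrapper below mirror RLlib's MB-MPO
# example envs but are assumptions here, as is the helper name.
def _example_env_with_reward():
    import gym
    import numpy as np

    class CartPoleWithReward(gym.Wrapper):
        def reward(self, obs, action, obs_next):
            # CartPole pays +1 per step, independent of the transition.
            return np.ones(obs.shape[0])

    return CartPoleWithReward(gym.make("CartPole-v1"))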