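"""MAML (Model-Agnostic Meta-Learning) Trainer for RLlib.

Defines the default config, the execution plan (inner adaptation on the
rollout workers followed by a meta-update on the local worker), and the
`MAMLTrainer` built via `build_trainer` at the bottom of this file.
"""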
import logging
import numpy as np
from ray.rllib.utils.sgd import standardized
from ray.rllib.agents import with_common_config
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.util.iter import from_actors

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # If true, use the Generalized Advantage Estimator (GAE)
    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,
    # GAE(lambda) parameter
    "lambda": 1.0,
    # Initial coefficient for KL divergence
    "kl_coeff": 0.0005,
    # Size of batches collected from each worker
    "rollout_fragment_length": 200,
    # Do create an actual env on the local worker (worker-idx=0).
    "create_env_on_driver": True,
    # Stepsize of SGD
    "lr": 1e-3,
    "model": {
        # Share layers for value function.
        "vf_share_layers": False,
    },
    # Coefficient of the value function loss
    "vf_loss_coeff": 0.5,
    # Coefficient of the entropy regularizer
    "entropy_coeff": 0.0,
    # PPO clip parameter
    "clip_param": 0.3,
    # Clip param for the value function. Note that this is sensitive to the
    # scale of the rewards. If your expected V is large, increase this.
    "vf_clip_param": 10.0,
    # If specified, clip the global norm of gradients by this amount
    "grad_clip": None,
    # Target value for KL divergence
    "kl_target": 0.01,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "complete_episodes",
    # Which observation filter to apply to the observation
    "observation_filter": "NoFilter",
    # Number of Inner adaptation steps for the MAML algorithm
    "inner_adaptation_steps": 1,
    # Number of MAML steps per meta-update iteration (PPO steps)
    "maml_optimizer_steps": 5,
    # Inner Adaptation Step size
    "inner_lr": 0.1,
    # Use Meta Env Template
    "use_meta_env": True,

    # Deprecated keys:
    # Share layers for value function. If you set this to True, it's important
    # to tune vf_loss_coeff.
    # Use config.model.vf_share_layers instead.
    "vf_share_layers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable


# @mluo: TODO
def set_worker_tasks(workers, use_meta_env):
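    """Samples a new task for each remote worker's env (if meta-envs are used).

    The local worker's env is expected to provide `sample_tasks(n)`, and each
    remote worker's env a `set_task(task)` method (see `use_meta_env`).
    """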
    if use_meta_env:
        n_tasks = len(workers.remote_workers())
        tasks = workers.local_worker().foreach_env(lambda x: x)[
            0].sample_tasks(n_tasks)
        for i, worker in enumerate(workers.remote_workers()):
            worker.foreach_env.remote(lambda env: env.set_task(tasks[i]))


class MetaUpdate:
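    """Performs the MAML meta-update step at the end of each iteration.

    Called on the (already inner-adapted) sample batch: runs
    `maml_optimizer_steps` meta-gradient steps on the local worker, syncs the
    new meta-policy weights to the remote workers, samples fresh tasks, and
    updates the per-policy inner KL coefficients from the learner stats.
    """
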
    def __init__(self, workers, maml_steps, metric_gen, use_meta_env):
        self.workers = workers
        self.maml_optimizer_steps = maml_steps
        self.metric_gen = metric_gen
        self.use_meta_env = use_meta_env

    def __call__(self, data_tuple):
        # Metaupdate Step
        samples = data_tuple[0]
        adapt_metrics_dict = data_tuple[1]

        # Metric Updating
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count
        for i in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        fetches = get_learner_stats(fetches)

        # Sync workers with meta policy
        self.workers.sync_weights()

        # Set worker tasks
        set_worker_tasks(self.workers, self.use_meta_env)

        # Update KLS
        def update(pi, pi_id):
            assert "inner_kl" not in fetches, (
                "inner_kl should be nested under policy id key", fetches)
            if pi_id in fetches:
                assert "inner_kl" in fetches[pi_id], (fetches, pi_id)
                pi.update_kls(fetches[pi_id]["inner_kl"])
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

        self.workers.local_worker().foreach_trainable_policy(update)

        # Modify Reporting Metrics
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

        res = self.metric_gen.__call__(None)
        res.update(adapt_metrics_dict)

        return res


def post_process_metrics(adapt_iter, workers, metrics):
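    """Collects episode-reward stats from the remote workers.

    Stores max/mean/min episode reward in `metrics`, suffixed with
    "_adapt_<n>" for results gathered after the n-th inner-adaptation step.
    """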
    # Obtain Current Dataset Metrics and filter out
    name = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else ""

    # Only workers are collecting data
    res = collect_metrics(remote_workers=workers.remote_workers())

    metrics["episode_reward_max" + str(name)] = res["episode_reward_max"]
    metrics["episode_reward_mean" + str(name)] = res["episode_reward_mean"]
    metrics["episode_reward_min" + str(name)] = res["episode_reward_min"]

    return metrics


def inner_adaptation(workers, samples):
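    """Kicks off one (asynchronous) inner-adaptation gradient step per worker."""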
    # Each worker performs one gradient descent
    for i, e in enumerate(workers.remote_workers()):
        e.learn_on_batch.remote(samples[i])


def execution_plan(workers, config):
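    """Builds the MAML execution plan.

    Rollouts are gathered from all remote workers; for the first
    `inner_adaptation_steps` rounds the workers adapt locally to their own
    task, after which the concatenated pre-/post-adaptation batches are fed
    to `MetaUpdate` on the local (meta-)worker.
    """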
    # Sync workers with meta policy
    workers.sync_weights()

    # Samples and sets worker tasks
    use_meta_env = config["use_meta_env"]
    set_worker_tasks(workers, use_meta_env)

    # Metric Collector
    metric_collect = CollectMetrics(
        workers,
        min_history=config["metrics_smoothing_episodes"],
        timeout_seconds=config["collect_metrics_timeout"])

    # Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
    inner_steps = config["inner_adaptation_steps"]
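
    # Rough shape of the generator below: with `inner_adaptation_steps=1`, the
    # first batch from the workers is pre-adaptation data -> each worker takes
    # one local gradient step; the next batch is post-adaptation data -> both
    # rounds are concatenated (with a "split" array recording per-worker batch
    # sizes) and yielded to the meta-update, together with pre/post reward
    # metrics and their difference ("adaptation_delta").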
    def inner_adaptation_steps(itr):
        buf = []
        split = []
        metrics = {}
        for samples in itr:
            # Processing Samples (Standardize Advantages)
            split_lst = []
            for sample in samples:
                sample["advantages"] = standardized(sample["advantages"])
                split_lst.append(sample.count)

            buf.extend(samples)
            split.append(split_lst)

            adapt_iter = len(split) - 1
            metrics = post_process_metrics(adapt_iter, workers, metrics)

            if len(split) > inner_steps:
                out = SampleBatch.concat_samples(buf)
                out["split"] = np.array(split)
                buf = []
                split = []

                # Reporting Adaptation Rew Diff
                ep_rew_pre = metrics["episode_reward_mean"]
                ep_rew_post = metrics["episode_reward_mean_adapt_" +
                                      str(inner_steps)]
                metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
                yield out, metrics
                metrics = {}
            else:
                inner_adaptation(workers, samples)

    rollouts = from_actors(workers.remote_workers())
    rollouts = rollouts.batch_across_shards()
    rollouts = rollouts.transform(inner_adaptation_steps)

    # Metaupdate Step
    train_op = rollouts.for_each(
        MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect,
                   use_meta_env))

    return train_op


def get_policy_class(config):
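    """Returns the torch policy for `framework: torch`, else the TF policy."""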
if config["framework"] == "torch":
return MAMLTorchPolicy
return MAMLTFPolicy
def validate_config(config):
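    """Rejects config values that MAML cannot run with."""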
if config["num_gpus"] > 1:
raise ValueError("`num_gpus` > 1 not yet supported for MAML!")
if config["inner_adaptation_steps"] <= 0:
raise ValueError("Inner Adaptation Steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
raise ValueError("`batch_mode`=truncate_episodes not supported!")
if config["num_workers"] <= 0:
raise ValueError("Must have at least 1 worker/task!")
if config["create_env_on_driver"] is False:
raise ValueError("Must have an actual Env created on the driver "
"(local) worker! Set `create_env_on_driver` to True.")
MAMLTrainer = build_trainer(
    name="MAML",
    default_config=DEFAULT_CONFIG,
    default_policy=MAMLTFPolicy,
    get_policy_class=get_policy_class,
    execution_plan=execution_plan,
    validate_config=validate_config)
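

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): how the
    # `MAMLTrainer` built above might be instantiated. The env module/class
    # below is hypothetical -- any env used with `use_meta_env=True` must
    # expose `sample_tasks(n)` and `set_task(task)` (see `set_worker_tasks`).
    import ray
    from ray.tune.registry import register_env

    from my_meta_envs import MyTaskDistributionEnv  # hypothetical env class

    ray.init()
    register_env("my_meta_env", lambda env_config: MyTaskDistributionEnv())

    trainer = MAMLTrainer(
        env="my_meta_env",
        config={
            "num_workers": 2,  # one sampled task per rollout worker
            "inner_adaptation_steps": 1,
            "maml_optimizer_steps": 5,
            "inner_lr": 0.1,
        })
    for _ in range(3):
        result = trainer.train()
        # `adaptation_delta` (post- minus pre-adaptation mean reward) is
        # merged into the result dict by `MetaUpdate` above.
        print(result["episode_reward_mean"], result.get("adaptation_delta"))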