import logging
import numpy as np
from typing import Optional, Type

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import AlgorithmConfigDict
from ray.util.iter import from_actors, LocalIterator

logger = logging.getLogger(__name__)


class MAMLConfig(AlgorithmConfig):
    """Defines a configuration class from which a MAML Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.maml import MAMLConfig
        >>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.maml import MAMLConfig
        >>> from ray import tune
        >>> config = MAMLConfig()
        >>> # Print out some default values.
        >>> print(config.lr)
        >>> # Update the config object.
        >>> config.training(grad_clip=tune.grid_search([10.0, 40.0]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "MAML",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, algo_class=None):
        """Initializes a MAMLConfig instance."""
        super().__init__(algo_class=algo_class or MAML)

        # fmt: off
        # __sphinx_doc_begin__
        # MAML-specific config settings.
        self.use_gae = True
        self.lambda_ = 1.0
        self.kl_coeff = 0.0005
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.0
        self.clip_param = 0.3
        self.vf_clip_param = 10.0
        self.grad_clip = None
        self.kl_target = 0.01
        self.inner_adaptation_steps = 1
        self.maml_optimizer_steps = 5
        self.inner_lr = 0.1
        self.use_meta_env = True

        # Override some of AlgorithmConfig's default values with MAML-specific values.
        self.rollout_fragment_length = 200
        self.create_env_on_local_worker = True
        self.lr = 1e-3
        # Whether layers should be shared for the value function.
        self.model.update({
            "vf_share_layers": False,
        })
        self.batch_mode = "complete_episodes"
        self._disable_execution_plan_api = False
        # __sphinx_doc_end__
        # fmt: on

        # Deprecated keys:
        self.vf_share_layers = DEPRECATED_VALUE

    def training(
        self,
        *,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        kl_coeff: Optional[float] = None,
        vf_loss_coeff: Optional[float] = None,
        entropy_coeff: Optional[float] = None,
        clip_param: Optional[float] = None,
        vf_clip_param: Optional[float] = None,
        grad_clip: Optional[float] = None,
        kl_target: Optional[float] = None,
        inner_adaptation_steps: Optional[int] = None,
        maml_optimizer_steps: Optional[int] = None,
        inner_lr: Optional[float] = None,
        use_meta_env: Optional[bool] = None,
        **kwargs,
    ) -> "MAMLConfig":
        """Sets the training related configuration.

        Args:
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
            lambda_: The GAE (lambda) parameter.
            kl_coeff: Initial coefficient for KL divergence.
            vf_loss_coeff: Coefficient of the value function loss.
            entropy_coeff: Coefficient of the entropy regularizer.
            clip_param: PPO clip parameter.
            vf_clip_param: Clip param for the value function. Note that this is
                sensitive to the scale of the rewards. If your expected V is
                large, increase this.
            grad_clip: If specified, clip the global norm of gradients by this
                amount.
            kl_target: Target value for KL divergence.
            inner_adaptation_steps: Number of inner adaptation steps for the
                MAML algorithm.
            maml_optimizer_steps: Number of MAML steps per meta-update
                iteration (PPO steps).
            inner_lr: Inner adaptation step size.
            use_meta_env: Whether to use the meta-env template.

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if kl_coeff is not None:
            self.kl_coeff = kl_coeff
        if vf_loss_coeff is not None:
            self.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not None:
            self.entropy_coeff = entropy_coeff
        if clip_param is not None:
            self.clip_param = clip_param
        if vf_clip_param is not None:
            self.vf_clip_param = vf_clip_param
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if kl_target is not None:
            self.kl_target = kl_target
        if inner_adaptation_steps is not None:
            self.inner_adaptation_steps = inner_adaptation_steps
        if maml_optimizer_steps is not None:
            self.maml_optimizer_steps = maml_optimizer_steps
        if inner_lr is not None:
            self.inner_lr = inner_lr
        if use_meta_env is not None:
            self.use_meta_env = use_meta_env

        return self
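# Example (illustrative sketch only, not part of this module): tuning the
# MAML-specific knobs via the fluent `training()` method defined above.
# "my_registered_meta_env" is a placeholder for a user-registered env; with
# the default `use_meta_env=True`, MAML expects an env that supports task
# sampling (see the task-interface note further below).
#
#     config = (
#         MAMLConfig()
#         .training(
#             inner_adaptation_steps=1,   # inner-loop gradient steps per task
#             maml_optimizer_steps=5,     # PPO steps per meta-update
#             inner_lr=0.1,               # inner-loop step size
#         )
#         .environment(env="my_registered_meta_env")
#     )
#     algo = config.build()
#     print(algo.train())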
# @mluo: TODO
def set_worker_tasks(workers, use_meta_env):
    if use_meta_env:
        n_tasks = len(workers.remote_workers())
        tasks = (
            workers.local_worker().foreach_env(lambda x: x)[0].sample_tasks(n_tasks)
        )
        for i, worker in enumerate(workers.remote_workers()):
            worker.foreach_env.remote(lambda env: env.set_task(tasks[i]))


class MetaUpdate:
    def __init__(self, workers, maml_steps, metric_gen, use_meta_env):
        self.workers = workers
        self.maml_optimizer_steps = maml_steps
        self.metric_gen = metric_gen
        self.use_meta_env = use_meta_env

    def __call__(self, data_tuple):
        # Metaupdate Step
        samples = data_tuple[0]
        adapt_metrics_dict = data_tuple[1]

        # Metric Updating
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count
        fetches = None
        for i in range(self.maml_optimizer_steps):
            fetches = self.workers.local_worker().learn_on_batch(samples)
        learner_stats = get_learner_stats(fetches)

        # Sync workers with meta policy
        self.workers.sync_weights()

        # Set worker tasks
        set_worker_tasks(self.workers, self.use_meta_env)

        # Update KLs
        def update(pi, pi_id):
            assert "inner_kl" not in learner_stats, (
                "inner_kl should be nested under policy id key",
                learner_stats,
            )
            if pi_id in learner_stats:
                assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id)
                pi.update_kls(learner_stats[pi_id]["inner_kl"])
            else:
                logger.warning("No data for {}, not updating kl".format(pi_id))

        self.workers.local_worker().foreach_policy_to_train(update)

        # Modify Reporting Metrics
        metrics = _get_shared_metrics()
        metrics.info[LEARNER_INFO] = fetches
        metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count
        metrics.counters[STEPS_TRAINED_COUNTER] += samples.count

        res = self.metric_gen.__call__(None)
        res.update(adapt_metrics_dict)

        return res
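# NOTE: `set_worker_tasks()` above assumes a task-sampling interface on the
# env: `sample_tasks(n_tasks)` on the local worker's env and `set_task(task)`
# on each remote worker's env (RLlib's `TaskSettableEnv` API follows this
# pattern). A minimal sketch of such an env, purely for illustration (the
# class name and task representation below are made up):
#
#     class MyMetaEnv(gym.Env):
#         def sample_tasks(self, n_tasks):
#             # e.g. draw one goal position per task
#             return [self.np_random.uniform(-1.0, 1.0) for _ in range(n_tasks)]
#
#         def set_task(self, task):
#             self._goal = task
#
#         def get_task(self):
#             return self._goal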
def post_process_metrics(adapt_iter, workers, metrics):
    # Obtain current dataset metrics and filter out.
    name = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else ""

    # Only workers are collecting data.
    res = collect_metrics(remote_workers=workers.remote_workers())

    metrics["episode_reward_max" + str(name)] = res["episode_reward_max"]
    metrics["episode_reward_mean" + str(name)] = res["episode_reward_mean"]
    metrics["episode_reward_min" + str(name)] = res["episode_reward_min"]

    return metrics


def inner_adaptation(workers, samples):
    # Each worker performs one gradient descent step.
    for i, e in enumerate(workers.remote_workers()):
        e.learn_on_batch.remote(samples[i])


class MAML(Algorithm):
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return MAMLConfig().to_dict()

    @override(Algorithm)
    def validate_config(self, config: AlgorithmConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for MAML!")
        if config["inner_adaptation_steps"] <= 0:
            raise ValueError("`inner_adaptation_steps` must be >=1!")
        if config["maml_optimizer_steps"] <= 0:
            raise ValueError(
                "`maml_optimizer_steps` (PPO steps per meta-update) must be >=1!"
            )
        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >=0.0!")
        if config["batch_mode"] != "complete_episodes":
            raise ValueError("`batch_mode` must be set to `complete_episodes`!")
        if config["num_workers"] <= 0:
            raise ValueError("Must have at least 1 worker/task!")
        if config["create_env_on_driver"] is False:
            raise ValueError(
                "Must have an actual Env created on the driver "
                "(local) worker! Set `create_env_on_driver` to True."
            )

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy

            return MAMLTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTF1Policy

            return MAMLTF1Policy
        else:
            from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTF2Policy

            return MAMLTF2Policy
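    # The execution plan below implements the MAML outer loop:
    #   1. Sync the meta policy to all remote workers and sample one task per
    #      worker (`set_worker_tasks`).
    #   2. For each of `inner_adaptation_steps` rounds, standardize advantages
    #      and let every remote worker take one gradient step on its own
    #      rollouts (`inner_adaptation`).
    #   3. Concatenate the pre- and post-adaptation batches (a "split" array
    #      records the per-round, per-worker batch sizes) and hand them to
    #      `MetaUpdate`, which runs `maml_optimizer_steps` optimizer steps on
    #      the local worker, re-syncs weights, samples new tasks, and updates
    #      the per-policy inner KL coefficients.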
    @staticmethod
    @override(Algorithm)
    def execution_plan(
        workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "MAML execution_plan does NOT take any additional parameters"

        # Sync workers with meta policy.
        workers.sync_weights()

        # Samples and sets worker tasks.
        use_meta_env = config["use_meta_env"]
        set_worker_tasks(workers, use_meta_env)

        # Metric collector.
        metric_collect = CollectMetrics(
            workers,
            min_history=config["metrics_num_episodes_for_smoothing"],
            timeout_seconds=config["metrics_episode_collection_timeout_s"],
        )

        # Iterator for inner adaptation data gathering (from pre- to
        # post-adaptation).
        inner_steps = config["inner_adaptation_steps"]

        def inner_adaptation_steps(itr):
            buf = []
            split = []
            metrics = {}
            for samples in itr:
                # Processing samples (standardize advantages).
                split_lst = []
                for sample in samples:
                    sample["advantages"] = standardized(sample["advantages"])
                    split_lst.append(sample.count)

                buf.extend(samples)
                split.append(split_lst)

                adapt_iter = len(split) - 1
                metrics = post_process_metrics(adapt_iter, workers, metrics)
                if len(split) > inner_steps:
                    out = SampleBatch.concat_samples(buf)
                    out["split"] = np.array(split)
                    buf = []
                    split = []

                    # Reporting adaptation reward diff.
                    ep_rew_pre = metrics["episode_reward_mean"]
                    ep_rew_post = metrics[
                        "episode_reward_mean_adapt_" + str(inner_steps)
                    ]
                    metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
                    yield out, metrics
                    metrics = {}
                else:
                    inner_adaptation(workers, samples)

        rollouts = from_actors(workers.remote_workers())
        rollouts = rollouts.batch_across_shards()
        rollouts = rollouts.transform(inner_adaptation_steps)

        # Metaupdate step.
        train_op = rollouts.for_each(
            MetaUpdate(
                workers, config["maml_optimizer_steps"], metric_collect, use_meta_env
            )
        )

        return train_op


# Deprecated: Use ray.rllib.algorithms.maml.maml.MAMLConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(MAMLConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.maml.maml.MAMLConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()