mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] MAML config objects. (#25066)
parent baf8c2fa1e
commit dea9b86a16
3 changed files with 168 additions and 64 deletions
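In short, this commit replaces MAML's module-level DEFAULT_CONFIG dict with a MAMLConfig class. A minimal sketch of the old versus the new style, drawn from the docstring and test changes below (CartPole-v1 is only the illustrative env used in that docstring; real MAML runs need a meta-env):

# Old style: copy and mutate the module-level config dict.
from ray.rllib.algorithms.maml import DEFAULT_CONFIG
config_dict = DEFAULT_CONFIG.copy()
config_dict["inner_lr"] = 0.1

# New style: chain typed setters on a config object, then build the trainer.
from ray.rllib.algorithms.maml import MAMLConfig
config = MAMLConfig().training(inner_lr=0.1).resources(num_gpus=0)
trainer = config.build(env="CartPole-v1")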
@@ -1,6 +1,7 @@
-from ray.rllib.algorithms.maml.maml import MAMLTrainer, DEFAULT_CONFIG
+from ray.rllib.algorithms.maml.maml import MAMLConfig, MAMLTrainer, DEFAULT_CONFIG
 
 __all__ = [
+    "MAMLConfig",
     "MAMLTrainer",
     "DEFAULT_CONFIG",
 ]
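The package thus re-exports the new config class alongside the legacy symbols, so both import styles keep working, e.g.:

# Both the new config class and the legacy entries are importable from the package:
from ray.rllib.algorithms.maml import DEFAULT_CONFIG, MAMLConfig, MAMLTrainer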
@@ -1,10 +1,9 @@
 import logging
 import numpy as np
-from typing import Type
+from typing import Optional, Type
 
-from ray.rllib.utils.sgd import standardized
-from ray.rllib.agents import with_common_config
 from ray.rllib.agents.trainer import Trainer
+from ray.rllib.agents.trainer_config import TrainerConfig
 from ray.rllib.evaluation.metrics import get_learner_stats
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.common import (
@@ -18,70 +17,158 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.execution.metric_ops import CollectMetrics
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE
+from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE
 from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
+from ray.rllib.utils.sgd import standardized
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import from_actors, LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# fmt: off
-# __sphinx_doc_begin__
-DEFAULT_CONFIG = with_common_config({
-    # If true, use the Generalized Advantage Estimator (GAE)
-    # with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
-    "use_gae": True,
-    # GAE(lambda) parameter
-    "lambda": 1.0,
-    # Initial coefficient for KL divergence
-    "kl_coeff": 0.0005,
-    # Size of batches collected from each worker
-    "rollout_fragment_length": 200,
-    # Do create an actual env on the local worker (worker-idx=0).
-    "create_env_on_driver": True,
-    # Stepsize of SGD
-    "lr": 1e-3,
-    "model": {
+
+class MAMLConfig(TrainerConfig):
+    """Defines a configuration class from which a MAMLTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.algorithms.maml import MAMLConfig
+        >>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.algorithms.maml import MAMLConfig
+        >>> from ray import tune
+        >>> config = MAMLConfig()
+        >>> # Print out some default values.
+        >>> print(config.lr)
+        >>> # Update the config object.
+        >>> config.training(grad_clip=tune.grid_search([10.0, 40.0]))
+        >>> # Set the config object's env.
+        >>> config.environment(env="CartPole-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "MAML",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a MAMLConfig instance."""
+        super().__init__(trainer_class=trainer_class or MAMLTrainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+        # MAML-specific config settings.
+        self.use_gae = True
+        self.lambda_ = 1.0
+        self.kl_coeff = 0.0005
+        self.vf_loss_coeff = 0.5
+        self.entropy_coeff = 0.0
+        self.clip_param = 0.3
+        self.vf_clip_param = 10.0
+        self.grad_clip = None
+        self.kl_target = 0.01
+        self.inner_adaptation_steps = 1
+        self.maml_optimizer_steps = 5
+        self.inner_lr = 0.1
+        self.use_meta_env = True
+
+        # Override some of TrainerConfig's default values with MAML-specific values.
+        self.rollout_fragment_length = 200
+        self.create_env_on_local_worker = True
+        self.lr = 1e-3
 
-        # Share layers for value function.
-        "vf_share_layers": False,
-    },
-    # Coefficient of the value function loss
-    "vf_loss_coeff": 0.5,
-    # Coefficient of the entropy regularizer
-    "entropy_coeff": 0.0,
-    # PPO clip parameter
-    "clip_param": 0.3,
-    # Clip param for the value function. Note that this is sensitive to the
-    # scale of the rewards. If your expected V is large, increase this.
-    "vf_clip_param": 10.0,
-    # If specified, clip the global norm of gradients by this amount
-    "grad_clip": None,
-    # Target value for KL divergence
-    "kl_target": 0.01,
-    # Whether to rollout "complete_episodes" or "truncate_episodes"
-    "batch_mode": "complete_episodes",
-    # Which observation filter to apply to the observation
-    "observation_filter": "NoFilter",
-    # Number of Inner adaptation steps for the MAML algorithm
-    "inner_adaptation_steps": 1,
-    # Number of MAML steps per meta-update iteration (PPO steps)
-    "maml_optimizer_steps": 5,
-    # Inner Adaptation Step size
-    "inner_lr": 0.1,
-    # Use Meta Env Template
-    "use_meta_env": True,
+        self.model.update({
+            "vf_share_layers": False,
+        })
 
-    # Deprecated keys:
-    # Share layers for value function. If you set this to True, it's important
-    # to tune vf_loss_coeff.
-    # Use config.model.vf_share_layers instead.
-    "vf_share_layers": DEPRECATED_VALUE,
+        self.batch_mode = "complete_episodes"
+        self._disable_execution_plan_api = False
+        # __sphinx_doc_end__
+        # fmt: on
 
-    # Use `execution_plan` instead of `training_iteration`.
-    "_disable_execution_plan_api": False,
-})
-# __sphinx_doc_end__
-# fmt: on
+        # Deprecated keys:
+        self.vf_share_layers = DEPRECATED_VALUE
 
+    def training(
+        self,
+        *,
+        use_gae: Optional[bool] = None,
+        lambda_: Optional[float] = None,
+        kl_coeff: Optional[float] = None,
+        vf_loss_coeff: Optional[float] = None,
+        entropy_coeff: Optional[float] = None,
+        clip_param: Optional[float] = None,
+        vf_clip_param: Optional[float] = None,
+        grad_clip: Optional[float] = None,
+        kl_target: Optional[float] = None,
+        inner_adaptation_steps: Optional[int] = None,
+        maml_optimizer_steps: Optional[int] = None,
+        inner_lr: Optional[float] = None,
+        use_meta_env: Optional[bool] = None,
+        **kwargs,
+    ) -> "MAMLConfig":
+        """Sets the training related configuration.
+
+        Args:
+            use_gae: If true, use the Generalized Advantage Estimator (GAE)
+                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
+            lambda_: The GAE (lambda) parameter.
+            kl_coeff: Initial coefficient for KL divergence.
+            vf_loss_coeff: Coefficient of the value function loss.
+            entropy_coeff: Coefficient of the entropy regularizer.
+            clip_param: PPO clip parameter.
+            vf_clip_param: Clip param for the value function. Note that this is
+                sensitive to the scale of the rewards. If your expected V is large,
+                increase this.
+            grad_clip: If specified, clip the global norm of gradients by this amount.
+            kl_target: Target value for KL divergence.
+            inner_adaptation_steps: Number of Inner adaptation steps for the MAML
+                algorithm.
+            maml_optimizer_steps: Number of MAML steps per meta-update iteration
+                (PPO steps).
+            inner_lr: Inner Adaptation Step size.
+            use_meta_env: Use Meta Env Template.
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        if use_gae is not None:
+            self.use_gae = use_gae
+        if lambda_ is not None:
+            self.lambda_ = lambda_
+        if kl_coeff is not None:
+            self.kl_coeff = kl_coeff
+        if vf_loss_coeff is not None:
+            self.vf_loss_coeff = vf_loss_coeff
+        if entropy_coeff is not None:
+            self.entropy_coeff = entropy_coeff
+        if clip_param is not None:
+            self.clip_param = clip_param
+        if vf_clip_param is not None:
+            self.vf_clip_param = vf_clip_param
+        if grad_clip is not None:
+            self.grad_clip = grad_clip
+        if kl_target is not None:
+            self.kl_target = kl_target
+        if inner_adaptation_steps is not None:
+            self.inner_adaptation_steps = inner_adaptation_steps
+        if maml_optimizer_steps is not None:
+            self.maml_optimizer_steps = maml_optimizer_steps
+        if inner_lr is not None:
+            self.inner_lr = inner_lr
+        if use_meta_env is not None:
+            self.use_meta_env = use_meta_env
+
+        return self
 
 
 # @mluo: TODO
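The new training() setter above only overrides values that are explicitly passed and returns the config itself, so it chains with the other setters. A sketch of typical use, assembled from the class docstring (CartPole-v1 and the stopping criterion are just the docstring's illustrative values):

from ray import tune
from ray.rllib.algorithms.maml import MAMLConfig

config = (
    MAMLConfig()
    .training(
        inner_adaptation_steps=1,
        maml_optimizer_steps=5,
        inner_lr=0.1,
        grad_clip=tune.grid_search([10.0, 40.0]),
    )
    .environment(env="CartPole-v1")
)

# to_dict() yields the old-style config dict for tune.run().
tune.run("MAML", stop={"episode_reward_mean": 200}, config=config.to_dict())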
@@ -169,7 +256,7 @@ class MAMLTrainer(Trainer):
     @classmethod
     @override(Trainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return MAMLConfig().to_dict()
 
     @override(Trainer)
     def validate_config(self, config: TrainerConfigDict) -> None:
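With this change the trainer's default config is generated from the config object rather than from a hand-maintained dict, so the two views stay in sync by construction. A quick sanity sketch:

from ray.rllib.algorithms.maml.maml import MAMLConfig, MAMLTrainer

# get_default_config() now literally returns MAMLConfig().to_dict(),
# so both views of the defaults expose the same set of keys.
assert set(MAMLTrainer.get_default_config()) == set(MAMLConfig().to_dict())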
@@ -281,3 +368,20 @@ class MAMLTrainer(Trainer):
            )
        )
        return train_op
+
+
+# Deprecated: Use ray.rllib.algorithms.maml.maml.MAMLConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(MAMLConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.algorithms.maml.maml.DEFAULT_CONFIG",
+        new="ray.rllib.algorithms.maml.maml.MAMLConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
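DEFAULT_CONFIG therefore keeps working for old code, but item access goes through the @Deprecated-wrapped __getitem__ (error=False), which logs a deprecation warning pointing at MAMLConfig instead of raising. For example:

from ray.rllib.algorithms.maml import DEFAULT_CONFIG

# Old-style access still returns the values set in MAMLConfig.__init__ above,
# while emitting a deprecation warning on each key lookup.
print(DEFAULT_CONFIG["inner_lr"])  # 0.1
print(DEFAULT_CONFIG["lr"])        # 1e-3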
@@ -20,9 +20,8 @@ class TestMAML(unittest.TestCase):
     def test_maml_compilation(self):
         """Test whether a MAMLTrainer can be built with all frameworks."""
-        config = maml.DEFAULT_CONFIG.copy()
-        config["num_workers"] = 1
-        config["horizon"] = 200
+        config = maml.MAMLConfig().rollouts(num_rollout_workers=1, horizon=200)
 
         num_iterations = 1
 
         # Test for tf framework (torch not implemented yet).
@@ -35,7 +34,7 @@
                     continue
                 print("env={}".format(env))
                 env_ = "ray.rllib.examples.env.{}".format(env)
-                trainer = maml.MAMLTrainer(config=config, env=env_)
+                trainer = config.build(env=env_)
                 for i in range(num_iterations):
                     results = trainer.train()
                     check_train_results(results)
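For reference, a sketch of the new-style test setup shown above (the example meta-env names the test iterates over are not part of this diff, so the build step is left indicative):

from ray.rllib.algorithms.maml import MAMLConfig

# Replaces the old dict mutations (config["num_workers"] = 1; config["horizon"] = 200):
config = MAMLConfig().rollouts(num_rollout_workers=1, horizon=200)

# Old-style dict view of the same settings, if needed for comparison:
print(config.to_dict())

# Building then follows the test: trainer = config.build(env=env_), where env_
# is one of the "ray.rllib.examples.env.*" meta-envs iterated over in the test.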