"""
|
|
Asynchronous Proximal Policy Optimization (APPO)
|
|
================================================
|
|
|
|
This file defines the distributed Algorithm class for the asynchronous version
|
|
of proximal policy optimization (APPO).
|
|
See `appo_[tf|torch]_policy.py` for the definition of the policy loss.
|
|
|
|
Detailed documentation:
|
|
https://docs.ray.io/en/master/rllib-algorithms.html#appo
|
|
"""
|
|
from typing import Optional, Type
|
|
import logging
|
|
|
|
from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig
|
|
from ray.rllib.algorithms.ppo.ppo import UpdateKL
|
|
from ray.rllib.execution.common import _get_shared_metrics, STEPS_SAMPLED_COUNTER
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.deprecation import Deprecated
|
|
from ray.rllib.utils.metrics import (
|
|
LAST_TARGET_UPDATE_TS,
|
|
NUM_AGENT_STEPS_SAMPLED,
|
|
NUM_ENV_STEPS_SAMPLED,
|
|
NUM_TARGET_UPDATES,
|
|
)
|
|
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
|
|
from ray.rllib.utils.typing import (
|
|
PartialAlgorithmConfigDict,
|
|
ResultDict,
|
|
AlgorithmConfigDict,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class APPOConfig(ImpalaConfig):
    """Defines a configuration class from which an APPO Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.appo import APPOConfig
        >>> config = APPOConfig().training(lr=0.01, grad_clip=30.0)\
        ...     .resources(num_gpus=1)\
        ...     .rollouts(num_rollout_workers=16)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.appo import APPOConfig
        >>> from ray import tune
        >>> config = APPOConfig()
        >>> # Print out some default values.
        >>> print(config.sample_async)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "APPO",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, algo_class=None):
        """Initializes an APPOConfig instance."""
        super().__init__(algo_class=algo_class or APPO)

        # fmt: off
        # __sphinx_doc_begin__

        # APPO specific settings:
        self.vtrace = True
        self.use_critic = True
        self.use_gae = True
        self.lambda_ = 1.0
        self.clip_param = 0.4
        self.use_kl_loss = False
        self.kl_coeff = 1.0
        self.kl_target = 0.01

        # Override some of ImpalaConfig's default values with APPO-specific values.
        self.rollout_fragment_length = 50
        self.train_batch_size = 500
        self.min_time_s_per_iteration = 10
        self.num_workers = 2
        self.num_gpus = 0
        self.num_multi_gpu_tower_stacks = 1
        self.minibatch_buffer_size = 1
        self.num_sgd_iter = 1
        self.replay_proportion = 0.0
        self.replay_buffer_num_slots = 100
        self.learner_queue_size = 16
        self.learner_queue_timeout = 300
        self.max_sample_requests_in_flight_per_worker = 2
        self.broadcast_interval = 1
        self.grad_clip = 40.0
        self.opt_type = "adam"
        self.lr = 0.0005
        self.lr_schedule = None
        self.decay = 0.99
        self.momentum = 0.0
        self.epsilon = 0.1
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.01
        self.entropy_coeff_schedule = None
        # __sphinx_doc_end__
        # fmt: on

    @override(ImpalaConfig)
    def training(
        self,
        *,
        vtrace: Optional[bool] = None,
        use_critic: Optional[bool] = None,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        clip_param: Optional[float] = None,
        use_kl_loss: Optional[bool] = None,
        kl_coeff: Optional[float] = None,
        kl_target: Optional[float] = None,
        **kwargs,
    ) -> "APPOConfig":
        """Sets the training related configuration.

        Args:
            vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
                advantages will be used instead.
            use_critic: Whether to use a critic as a baseline (otherwise, don't use a
                value baseline; required for using GAE). Only applies if vtrace=False.
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
                Only applies if vtrace=False.
            lambda_: GAE (lambda) parameter.
            clip_param: PPO surrogate clipping parameter.
            use_kl_loss: Whether to use the KL-term in the loss function.
            kl_coeff: Coefficient for weighting the KL-loss term.
            kl_target: Target term for the KL-term to reach (via adjusting the
                `kl_coeff` automatically).

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if vtrace is not None:
            self.vtrace = vtrace
        if use_critic is not None:
            self.use_critic = use_critic
        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if clip_param is not None:
            self.clip_param = clip_param
        if use_kl_loss is not None:
            self.use_kl_loss = use_kl_loss
        if kl_coeff is not None:
            self.kl_coeff = kl_coeff
        if kl_target is not None:
            self.kl_target = kl_target

        return self


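# Hook class used by the (legacy) execution-plan based setup inherited from
# IMPALA: `APPO.setup()` registers this callable as the `after_train_step` hook,
# so it runs after each learner step to update the target network and, if
# `use_kl_loss` is enabled, the KL coefficient.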
class UpdateTargetAndKL:
    def __init__(self, workers, config):
        self.workers = workers
        self.config = config
        self.update_kl = UpdateKL(workers)
        self.target_update_freq = (
            config["num_sgd_iter"] * config["minibatch_buffer_size"]
        )

    def __call__(self, fetches):
        metrics = _get_shared_metrics()
        cur_ts = metrics.counters[STEPS_SAMPLED_COUNTER]
        last_update = metrics.counters[LAST_TARGET_UPDATE_TS]
        if cur_ts - last_update > self.target_update_freq:
            metrics.counters[NUM_TARGET_UPDATES] += 1
            metrics.counters[LAST_TARGET_UPDATE_TS] = cur_ts
            # Update the target network.
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, _: p.update_target()
            )
            # Also update the KL coefficient.
            if self.config["use_kl_loss"]:
                self.update_kl(fetches)


class APPO(Impala):
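    """APPO Algorithm built on top of IMPALA.

    Adds periodic target-network updates and (optional) KL-coefficient updates
    after each training step (see `after_train_step`).
    """
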
    def __init__(self, config, *args, **kwargs):
        """Initializes an APPO instance."""
        super().__init__(config, *args, **kwargs)

        # After init: Initialize target net.
        self.workers.local_worker().foreach_policy_to_train(
            lambda p, _: p.update_target()
        )

    @override(Impala)
    def setup(self, config: PartialAlgorithmConfigDict):
        # Before init: Add the update target and KL hook.
        # This hook is called explicitly after each learner step in the
        # execution setup for IMPALA.
        if config.get("_disable_execution_plan_api", False) is False:
            config["after_train_step"] = UpdateTargetAndKL

        super().setup(config)

        if self.config["_disable_execution_plan_api"] is True:
            self.update_kl = UpdateKL(self.workers)

    def after_train_step(self, train_results: ResultDict) -> None:
        """Updates the target network and the KL coefficient for the APPO-loss.

        This method is called from within the `training_iteration` method after each
        train update.

        The target network update frequency is calculated automatically as the product
        of the `num_sgd_iter` setting (usually 1 for APPO) and `minibatch_buffer_size`.

        Args:
            train_results: The results dict collected during the most recent
                training step.
        """
        cur_ts = self._counters[
            NUM_AGENT_STEPS_SAMPLED if self._by_agent_steps else NUM_ENV_STEPS_SAMPLED
        ]
        last_update = self._counters[LAST_TARGET_UPDATE_TS]
        target_update_freq = (
            self.config["num_sgd_iter"] * self.config["minibatch_buffer_size"]
        )
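        # Note: With the APPO defaults set in `APPOConfig.__init__`
        # (num_sgd_iter=1, minibatch_buffer_size=1), this product is 1, so the
        # target network is refreshed on nearly every training step.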
        if cur_ts - last_update > target_update_freq:
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = cur_ts

            # Update our target network.
            self.workers.local_worker().foreach_policy_to_train(
                lambda p, _: p.update_target()
            )

            # Also update the KL-coefficient for the APPO loss, if necessary.
            if self.config["use_kl_loss"]:

                def update(pi, pi_id):
                    assert LEARNER_STATS_KEY not in train_results, (
                        "{} should be nested under policy id key".format(
                            LEARNER_STATS_KEY
                        ),
                        train_results,
                    )
                    if pi_id in train_results:
                        kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
                        assert kl is not None, (train_results, pi_id)
                        # Make the actual `Policy.update_kl()` call.
                        pi.update_kl(kl)
                    else:
                        logger.warning("No data for {}, not updating kl".format(pi_id))

                # Update KL on all trainable policies within the local (trainer)
                # Worker.
                self.workers.local_worker().foreach_policy_to_train(update)

    @override(Impala)
    def training_step(self) -> ResultDict:
        train_results = super().training_step()

        # Update KL, target network periodically.
        self.after_train_step(train_results)

        return train_results

    @classmethod
    @override(Impala)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return APPOConfig().to_dict()

    @override(Impala)
    def get_default_policy_class(
        self, config: PartialAlgorithmConfigDict
    ) -> Optional[Type[Policy]]:
        if config["framework"] == "torch":
            from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy

            return APPOTorchPolicy
        elif config["framework"] == "tf":
            from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy

            return APPOTF1Policy
        else:
            from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy

            return APPOTF2Policy


# Deprecated: Use ray.rllib.algorithms.appo.APPOConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(APPOConfig().to_dict())

    @Deprecated(
        old="ray.rllib.agents.ppo.appo::DEFAULT_CONFIG",
        new="ray.rllib.algorithms.appo.appo::APPOConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()