mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
261 lines
10 KiB
Python
261 lines
10 KiB
Python
import logging
|
|
from typing import Any, Dict, List, Optional, Type, Union
|
|
|
|
from ray.actor import ActorHandle
|
|
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
|
from ray.rllib.agents.trainer import Trainer
|
|
from ray.rllib.agents.trainer_config import TrainerConfig
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
|
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
|
|
from ray.rllib.policy.policy import Policy
|
|
from ray.rllib.utils.annotations import override
|
|
from ray.rllib.utils.deprecation import Deprecated
|
|
from ray.rllib.utils.metrics import (
|
|
APPLY_GRADS_TIMER,
|
|
GRAD_WAIT_TIMER,
|
|
NUM_AGENT_STEPS_SAMPLED,
|
|
NUM_AGENT_STEPS_TRAINED,
|
|
NUM_ENV_STEPS_SAMPLED,
|
|
NUM_ENV_STEPS_TRAINED,
|
|
SYNCH_WORKER_WEIGHTS_TIMER,
|
|
)
|
|
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
|
|
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class A3CConfig(TrainerConfig):
    """Defines an A3CTrainer configuration class from which an A3CTrainer can be built.

    Example:
        >>> from ray import tune
        >>> config = (
        ...     A3CConfig()
        ...     .training(lr=0.01, grad_clip=30.0)
        ...     .resources(num_gpus=0)
        ...     .rollouts(num_rollout_workers=4)
        ... )
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray import tune
        >>> config = A3CConfig()
        >>> # Print out some default values.
        >>> print(config.sample_async)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), use_critic=False)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "A3C",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, trainer_class=None):
        """Initializes an A3CConfig instance.

        Args:
            trainer_class: Optional Trainer class that `build()` should create.
                Falls back to A3CTrainer when not given.
        """
        super().__init__(trainer_class=trainer_class or A3CTrainer)

        # fmt: off
        # __sphinx_doc_begin__
        #
        # A3C specific settings.
        self.use_critic = True
        self.use_gae = True
        self.lambda_ = 1.0
        self.grad_clip = 40.0
        self.lr_schedule = None
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.01
        self.entropy_coeff_schedule = None
        self.sample_async = True

        # Override some of TrainerConfig's default values with A3C-specific values.
        self.rollout_fragment_length = 10
        self.lr = 0.0001
        # Min time (in seconds) per reporting.
        # This causes not every call to `training_iteration` to be reported,
        # but to wait until n seconds have passed and then to summarize the
        # thus far collected results.
        self.min_time_s_per_reporting = 5
        # __sphinx_doc_end__
        # fmt: on

    @override(TrainerConfig)
    def training(
        self,
        *,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        use_critic: Optional[bool] = None,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        grad_clip: Optional[float] = None,
        vf_loss_coeff: Optional[float] = None,
        entropy_coeff: Optional[float] = None,
        entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None,
        sample_async: Optional[bool] = None,
        **kwargs,
    ) -> "A3CConfig":
        """Sets the training related configuration.

        Every argument left as None keeps its current value.

        Args:
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated learning rate
                values. A schedule should normally start from timestep 0.
            use_critic: Should use a critic as a baseline (otherwise don't use value
                baseline; required for using GAE).
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
            lambda_: GAE(lambda) parameter.
            grad_clip: Max global norm for each gradient calculated by worker.
            vf_loss_coeff: Value Function Loss coefficient.
            entropy_coeff: Coefficient of the entropy regularizer.
            entropy_coeff_schedule: Decay schedule for the entropy regularizer.
            sample_async: Whether workers should sample async. Note that this
                increases the effective rollout_fragment_length by up to 5x due
                to async buffering of batches.
            **kwargs: Forwarded unchanged to `TrainerConfig.training()`.

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if use_critic is not None:
            # Previously this wrongly overwrote `self.lr_schedule`.
            self.use_critic = use_critic
        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if vf_loss_coeff is not None:
            self.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not None:
            self.entropy_coeff = entropy_coeff
        if entropy_coeff_schedule is not None:
            self.entropy_coeff_schedule = entropy_coeff_schedule
        if sample_async is not None:
            self.sample_async = sample_async

        return self
|
|
|
|
|
|
class A3CTrainer(Trainer):
    """Asynchronous Advantage Actor-Critic (A3C) Trainer.

    Remote rollout workers sample and compute gradients. The local worker
    applies each returned gradient in turn and sends its updated weights back
    to the worker that produced that gradient.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Returns the default A3C settings as an old-style config dict."""
        defaults = A3CConfig()
        return defaults.to_dict()

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Runs the generic Trainer checks, then the A3C-specific ones.

        Raises:
            ValueError: If `entropy_coeff` is negative, or if `sample_async`
                is enabled while `num_workers` is not positive.
        """
        # Generic Trainer validation comes first.
        super().validate_config(config)

        entropy_coeff = config["entropy_coeff"]
        if entropy_coeff < 0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")

        lacks_remote_workers = config["num_workers"] <= 0
        if lacks_remote_workers and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        """Returns the A3C policy class matching `config["framework"]`."""
        if config["framework"] != "torch":
            # Every non-"torch" framework value maps to the TF policy.
            return A3CTFPolicy

        # Deferred import: the torch policy module is only loaded when the
        # torch framework is selected.
        from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy

        return A3CTorchPolicy

    def training_iteration(self) -> ResultDict:
        """Runs one A3C iteration over whatever worker results are ready.

        Every gradient result that has arrived from a remote rollout worker is
        applied to the local worker. The updated weights (plus the current
        global vars) are then sent back to that same worker.

        Returns:
            The finalized learner info assembled from all applied results.
        """
        local = self.workers.local_worker()

        # Executed on each RolloutWorker: collect a batch, compute gradients
        # on it, and report the gradients together with step counts.
        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() remotely on workers."""
            batch = worker.sample()
            gradients, grad_infos = worker.compute_gradients(batch)
            return {
                "grads": gradients,
                "infos": grad_infos,
                "agent_steps": batch.agent_steps(),
                "env_steps": batch.env_steps(),
            }

        with self._timers[GRAD_WAIT_TIMER]:
            # Maps each remote worker's actor handle to the results it has
            # returned. With a zero `ray_wait_timeout_s`, this presumably only
            # collects requests that have already finished.
            async_results: Dict[ActorHandle, List[Dict[str, Any]]] = (
                asynchronous_parallel_requests(
                    remote_requests_in_flight=self.remote_requests_in_flight,
                    actors=self.workers.remote_workers(),
                    ray_wait_timeout_s=0.0,
                    max_remote_requests_in_flight_per_actor=1,
                    remote_fn=sample_and_compute_grads,
                )
            )

        global_vars = None
        info_builder = LearnerInfoBuilder(num_devices=1)

        for remote_worker, worker_results in async_results.items():
            for res in worker_results:
                # Apply this result's gradients to the local worker.
                with self._timers[APPLY_GRADS_TIMER]:
                    local.apply_gradients(res["grads"])
                agent_steps = res["agent_steps"]
                env_steps = res["env_steps"]
                self._timers[APPLY_GRADS_TIMER].push_units_processed(agent_steps)

                # Each result advances the sampled and trained counters alike.
                for counter_key, amount in (
                    (NUM_AGENT_STEPS_SAMPLED, agent_steps),
                    (NUM_ENV_STEPS_SAMPLED, env_steps),
                    (NUM_AGENT_STEPS_TRAINED, agent_steps),
                    (NUM_ENV_STEPS_TRAINED, env_steps),
                ):
                    self._counters[counter_key] += amount

                info_builder.add_learn_on_batch_results_multi_agent(res["infos"])

                global_vars = {
                    "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
                }

                # Send the just-updated weights back to the contributing worker.
                with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                    trainable_ids = local.get_policies_to_train()
                    fresh_weights = local.get_weights(trainable_ids)
                    remote_worker.set_weights.remote(fresh_weights, global_vars)

        # Only set when at least one result was applied above.
        if global_vars is not None:
            local.set_global_vars(global_vars)

        return info_builder.finalize()
|
|
|
|
|
|
# Deprecated: Use ray.rllib.agents.a3c.A3CConfig instead!
class _deprecated_default_config(dict):
    """Backward-compatible dict holding the default A3C config.

    Item access goes through the `Deprecated` decorator with `error=False`,
    which presumably only emits a deprecation notice pointing users at
    `A3CConfig` rather than raising.
    """

    def __init__(self):
        super().__init__(A3CConfig().to_dict())

    # The old/new paths previously named the PPO module/config (copy-paste
    # error); they now refer to this A3C module.
    @Deprecated(
        old="ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG",
        new="ray.rllib.agents.a3c.a3c.A3CConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
|