import logging
from typing import Any, Dict, List, Optional, Type, Union

from ray.actor import ActorHandle
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics import (
    APPLY_GRADS_TIMER,
    GRAD_WAIT_TIMER,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_TRAINED,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_TRAINED,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict

logger = logging.getLogger(__name__)


class A3CConfig(TrainerConfig):
    """Defines an A3CConfig class from which an A3CTrainer can be built.

    Example:
        >>> from ray import tune
        >>> config = A3CConfig().training(lr=0.01, grad_clip=30.0)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_rollout_workers=4)
        >>> print(config.to_dict())
        >>> # Build a Trainer object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> config = A3CConfig()
        >>> # Print out some default values.
        >>> print(config.sample_async)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]), use_critic=False)
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "A3C",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """

    def __init__(self, trainer_class=None):
        """Initializes an A3CConfig instance."""
        super().__init__(trainer_class=trainer_class or A3CTrainer)

        # fmt: off
        # __sphinx_doc_begin__
        #
        # A3C specific settings.
        self.use_critic = True
        self.use_gae = True
        self.lambda_ = 1.0
        self.grad_clip = 40.0
        self.lr_schedule = None
        self.vf_loss_coeff = 0.5
        self.entropy_coeff = 0.01
        self.entropy_coeff_schedule = None
        self.sample_async = True

        # Override some of TrainerConfig's default values with A3C-specific values.
        self.rollout_fragment_length = 10
        self.lr = 0.0001
        # Min time (in seconds) per reporting.
        # This causes not every call to `training_iteration` to be reported,
        # but to wait until n seconds have passed and then to summarize the
        # thus far collected results.
        self.min_time_s_per_reporting = 5
        # __sphinx_doc_end__
        # fmt: on

    @override(TrainerConfig)
    def training(
        self,
        *,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        use_critic: Optional[bool] = None,
        use_gae: Optional[bool] = None,
        lambda_: Optional[float] = None,
        grad_clip: Optional[float] = None,
        vf_loss_coeff: Optional[float] = None,
        entropy_coeff: Optional[float] = None,
        entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None,
        sample_async: Optional[bool] = None,
        **kwargs,
    ) -> "A3CConfig":
        """Sets the training related configuration.

        Args:
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned to interpolated
                learning rate values. A schedule should normally start from
                timestep 0.
            use_critic: Whether to use a critic as a baseline (otherwise,
                don't use any value baseline; required for using GAE).
            use_gae: If true, use the Generalized Advantage Estimator (GAE)
                with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
            lambda_: GAE(gamma) parameter.
            grad_clip: Max global norm for each gradient calculated by worker.
            vf_loss_coeff: Value function loss coefficient.
            entropy_coeff: Coefficient of the entropy regularizer.
            entropy_coeff_schedule: Decay schedule for the entropy regularizer.
            sample_async: Whether workers should sample asynchronously. Note
                that this increases the effective rollout_fragment_length by
                up to 5x due to async buffering of batches.

        Returns:
            This updated TrainerConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if use_critic is not None:
            self.use_critic = use_critic
        if use_gae is not None:
            self.use_gae = use_gae
        if lambda_ is not None:
            self.lambda_ = lambda_
        if grad_clip is not None:
            self.grad_clip = grad_clip
        if vf_loss_coeff is not None:
            self.vf_loss_coeff = vf_loss_coeff
        if entropy_coeff is not None:
            self.entropy_coeff = entropy_coeff
        if entropy_coeff_schedule is not None:
            self.entropy_coeff_schedule = entropy_coeff_schedule
        if sample_async is not None:
            self.sample_async = sample_async

        return self


class A3CTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return A3CConfig().to_dict()

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["entropy_coeff"] < 0:
            raise ValueError("`entropy_coeff` must be >= 0.0!")
        if config["num_workers"] <= 0 and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.agents.a3c.a3c_torch_policy import A3CTorchPolicy

            return A3CTorchPolicy
        else:
            return A3CTFPolicy

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Shortcut.
        local_worker = self.workers.local_worker()

        # Define the function executed in parallel by all RolloutWorkers to
        # collect samples + compute and return gradients (and other
        # information).
        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() remotely on workers."""
            samples = worker.sample()
            grads, infos = worker.compute_gradients(samples)
            return {
                "grads": grads,
                "infos": infos,
                "agent_steps": samples.agent_steps(),
                "env_steps": samples.env_steps(),
            }

        # Perform rollouts and gradient calculations asynchronously.
        with self._timers[GRAD_WAIT_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to their
            # returned gradient calculation results.
            async_results: Dict[ActorHandle, Dict] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=self.workers.remote_workers(),
                ray_wait_timeout_s=0.0,
                max_remote_requests_in_flight_per_actor=1,
                remote_fn=sample_and_compute_grads,
            )

        # Loop through all fetched worker-computed gradients (if any)
        # and apply them - one by one - to the local worker's model.
        # After each apply step (one step per worker that returned some
        # gradients), update that particular worker's weights.
        global_vars = None
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        for worker, results in async_results.items():
            for result in results:
                # Apply gradients to the local worker.
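                # Note: Each worker's gradients are applied to the central
                # (local) model as soon as they arrive, without waiting for
                # the other workers; only that worker then receives the fresh
                # weights back. This per-worker, lock-free update is what
                # makes A3C asynchronous, in contrast to synchronous
                # all-worker gradient averaging (as in A2C).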
                with self._timers[APPLY_GRADS_TIMER]:
                    local_worker.apply_gradients(result["grads"])
                    self._timers[APPLY_GRADS_TIMER].push_units_processed(
                        result["agent_steps"]
                    )

                # Update all step counters.
                self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
                self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

                learner_info_builder.add_learn_on_batch_results_multi_agent(
                    result["infos"]
                )

            # Create current global vars.
            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }

            # Synch updated weights back to the particular worker.
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = local_worker.get_weights(
                    local_worker.get_policies_to_train()
                )
                worker.set_weights.remote(weights, global_vars)

        # Update global vars of the local worker.
        if global_vars:
            local_worker.set_global_vars(global_vars)

        return learner_info_builder.finalize()


# Deprecated: Use ray.rllib.agents.a3c.A3CConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(A3CConfig().to_dict())

    @Deprecated(
        old="ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG",
        new="ray.rllib.agents.a3c.a3c.A3CConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
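

# Usage sketch (hypothetical; not part of the original module): mirrors the
# first docstring example above. Assumes Ray is installed and the Gym
# "CartPole-v1" environment is available locally.
if __name__ == "__main__":
    import ray

    ray.init()
    config = (
        A3CConfig()
        .training(lr=0.01, grad_clip=30.0)
        .resources(num_gpus=0)
        .rollouts(num_rollout_workers=4)
    )
    trainer = config.build(env="CartPole-v1")
    # One call to train() runs `training_iteration` (repeatedly, until at
    # least `min_time_s_per_reporting` seconds have passed) and returns a
    # ResultDict with sampled/trained step counters and learner stats.
    print(trainer.train())
    ray.shutdown()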