Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] AlphaStar config objects. (#24576)
commit 6d94b2acbe (parent d9b54d8bfa)
5 changed files with 176 additions and 79 deletions
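This change converts AlphaStar's old DEFAULT_CONFIG dict into an AlphaStarConfig class. As a rough before/after sketch (not from the commit itself; CartPole-v1 is only a placeholder environment, and all other names are taken from the diff below):

    from ray.rllib.agents.alpha_star import DEFAULT_CONFIG, AlphaStarConfig, AlphaStarTrainer

    # Old style: copy the default dict and override individual keys.
    old_config = dict(DEFAULT_CONFIG, train_batch_size=512, num_gpus=1)
    trainer = AlphaStarTrainer(config=old_config, env="CartPole-v1")

    # New style: chain builder methods on the config object, then build().
    config = (
        AlphaStarConfig()
        .training(train_batch_size=512)
        .resources(num_gpus=1)
        .rollouts(num_rollout_workers=2)
    )
    trainer = config.build(env="CartPole-v1")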
@@ -1,6 +1,11 @@
-from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarTrainer
+from ray.rllib.agents.alpha_star.alpha_star import (
+    AlphaStarConfig,
+    AlphaStarTrainer,
+    DEFAULT_CONFIG,
+)
 
 __all__ = [
-    "DEFAULT_CONFIG",
+    "AlphaStarConfig",
     "AlphaStarTrainer",
+    "DEFAULT_CONFIG",
 ]

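The package now re-exports the new config class next to the existing names, so (as a quick, illustrative check, not part of the commit) both the old and the new entry points stay importable:

    from ray.rllib.agents.alpha_star import AlphaStarConfig, AlphaStarTrainer, DEFAULT_CONFIG

    # DEFAULT_CONFIG is kept for backward compatibility; AlphaStarConfig is the
    # new entry point. Both describe the same settings (assumed equivalence).
    assert isinstance(DEFAULT_CONFIG, dict)
    assert isinstance(AlphaStarConfig().to_dict(), dict)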
@@ -3,8 +3,8 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO
 ================================================================
 """
-import gym
-from typing import Optional, Type
+import tree
+from typing import Any, Dict, Optional, Type
 
 import ray

@@ -19,6 +19,7 @@ from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,

@@ -42,36 +43,60 @@ from ray.rllib.utils.typing import (
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 from ray.util.timer import _Timer
 
+
+class AlphaStarConfig(appo.APPOConfig):
+    """Defines a configuration class from which an AlphaStarTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
+        ...     .resources(num_gpus=4)\
+        ...     .rollouts(num_rollout_workers=64)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> from ray import tune
+        >>> config = AlphaStarConfig()
+        >>> # Print out some default values.
+        >>> print(config.vtrace)
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
+        >>> # Set the config object's env.
+        >>> config.environment(env="CartPole-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "AlphaStar",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes a AlphaStarConfig instance."""
+        super().__init__(trainer_class=trainer_class or AlphaStarTrainer)
+
         # fmt: off
         # __sphinx_doc_begin__
 
-# Adds the following updates to the `IMPALATrainer` config in
-# rllib/agents/impala/impala.py.
-DEFAULT_CONFIG = Trainer.merge_trainer_configs(
-    appo.DEFAULT_CONFIG,  # See keys in appo.py, which are also supported.
-    {
-        # TODO: Unify the buffer API, then clean up our existing
-        # implementations of different buffers.
-        # This is num batches held at any time for each policy.
-        "replay_buffer_capacity": 20,
-        # e.g. ratio=0.2 -> 20% of samples in each train batch are
-        # old (replayed) ones.
-        "replay_buffer_replay_ratio": 0.5,
-
-        # Timeout to use for `ray.wait()` when waiting for samplers to have placed
-        # new data into the buffers. If no samples are ready within the timeout,
-        # the buffers used for mixin-sampling will return only older samples.
-        "sample_wait_timeout": 0.01,
-        # Timeout to use for `ray.wait()` when waiting for the policy learner actors
-        # to have performed an update and returned learning stats. If no learner
-        # actors have produced any learning results in the meantime, their
-        # learner-stats in the results will be empty for that iteration.
-        "learn_wait_timeout": 0.1,
+        # AlphaStar specific settings:
+        self.replay_buffer_capacity = 20
+        self.replay_buffer_replay_ratio = 0.5
+        self.sample_wait_timeout = 0.01
+        self.learn_wait_timeout = 0.1
 
         # League-building parameters.
         # The LeagueBuilder class to be used for league building logic.
-        "league_builder_config": {
+        self.league_builder_config = {
             # Specify the sub-class of the `LeagueBuilder` API to use.
             "type": AlphaStarLeagueBuilder,
 
             # Any any number of constructor kwargs to pass to this class:
 
             # The number of random policies to add to the league. This must be an
             # even number (including 0) as these will be evenly distributed
             # amongst league- and main- exploiters.

@@ -100,29 +125,81 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
             # Only for ME matches: Prob to play against learning
             # main (vs a snapshot main).
             "prob_main_exploiter_playing_against_learning_main": 0.5,
-        },
-
-        # The maximum number of trainable policies for this Trainer.
-        # Each trainable policy will exist as a independent remote actor, co-located
-        # with a replay buffer. This is besides its existence inside
-        # the RolloutWorkers for training and evaluation.
-        # Set to None for automatically inferring this value from the number of
-        # trainable policies found in the `multiagent` config.
-        "max_num_policies_to_train": None,
-
-        # By default, don't drop last timestep.
-        # TODO: We should do the same for IMPALA and APPO at some point.
-        "vtrace_drop_last_ts": False,
-
-        # Reporting interval.
-        "min_time_s_per_reporting": 2,
-    },
-    _allow_unknown_configs=True,
-)
+        }
+        self.max_num_policies_to_train = None
+
+        # Override some of APPOConfig's default values with AlphaStar-specific
+        # values.
+        self.vtrace_drop_last_ts = False
+        self.min_time_s_per_reporting = 2
+        # __sphinx_doc_end__
+        # fmt: on
+
+    @override(appo.APPOConfig)
+    def training(
+        self,
+        *,
+        replay_buffer_capacity: Optional[int] = None,
+        replay_buffer_replay_ratio: Optional[float] = None,
+        sample_wait_timeout: Optional[float] = None,
+        learn_wait_timeout: Optional[float] = None,
+        league_builder_config: Optional[Dict[str, Any]] = None,
+        max_num_policies_to_train: Optional[int] = None,
+        **kwargs,
+    ) -> "AlphaStarConfig":
+        """Sets the training related configuration.
+
+        Args:
+            replay_buffer_capacity: This is num batches held at any time for each
+                policy.
+            replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in
+                each train batch are old (replayed) ones.
+            sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for
+                samplers to have placed new data into the buffers. If no samples are
+                ready within the timeout, the buffers used for mixin-sampling will
+                return only older samples.
+            learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the
+                policy learner actors to have performed an update and returned learning
+                stats. If no learner actors have produced any learning results in the
+                meantime, their learner-stats in the results will be empty for that
+                iteration.
+            league_builder_config: League-building config dict.
+                The dict Must contain a `type` key indicating the LeagueBuilder class
+                to be used for league building logic. All other keys (that are not
+                `type`) will be used as constructor kwargs on the given class to
+                construct the LeagueBuilder instance. See the
+                `ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
+                (used by default by this algo) as an example.
+            max_num_policies_to_train: The maximum number of trainable policies for this
+                Trainer. Each trainable policy will exist as a independent remote actor,
+                co-located with a replay buffer. This is besides its existence inside
+                the RolloutWorkers for training and evaluation. Set to None for
+                automatically inferring this value from the number of trainable
+                policies found in the `multiagent` config.
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        # TODO: Unify the buffer API, then clean up our existing
+        # implementations of different buffers.
+        if replay_buffer_capacity is not None:
+            self.replay_buffer_capacity = replay_buffer_capacity
+        if replay_buffer_replay_ratio is not None:
+            self.replay_buffer_replay_ratio = replay_buffer_replay_ratio
+        if sample_wait_timeout is not None:
+            self.sample_wait_timeout = sample_wait_timeout
+        if learn_wait_timeout is not None:
+            self.learn_wait_timeout = learn_wait_timeout
+        if league_builder_config is not None:
+            self.league_builder_config = league_builder_config
+        if max_num_policies_to_train is not None:
+            self.max_num_policies_to_train = max_num_policies_to_train
+
+        return self
 
 
 class AlphaStarTrainer(appo.APPOTrainer):
     _allow_unknown_subkeys = appo.APPOTrainer._allow_unknown_subkeys + [

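As a usage illustration (not from the commit), the new training() signature above suggests the AlphaStar-specific knobs can be set as keyword arguments and the result still handed to Tune as a plain dict; the values here are arbitrary and CartPole-v1 is only a placeholder:

    from ray import tune
    from ray.rllib.agents.alpha_star import AlphaStarConfig

    config = (
        AlphaStarConfig()
        .environment(env="CartPole-v1")
        .training(
            replay_buffer_capacity=10,
            replay_buffer_replay_ratio=0.25,
            sample_wait_timeout=0.05,
            learn_wait_timeout=0.2,
            max_num_policies_to_train=None,  # None: infer from the `multiagent` config
        )
    )

    # Tune still consumes the old-style dict via to_dict().
    tune.run("AlphaStar", stop={"training_iteration": 1}, config=config.to_dict())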
@@ -204,7 +281,7 @@ class AlphaStarTrainer(appo.APPOTrainer):
     @classmethod
     @override(appo.APPOTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return AlphaStarConfig().to_dict()
 
     @override(appo.APPOTrainer)
     def validate_config(self, config: TrainerConfigDict):

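One practical effect of the hunk above, sketched here on the assumption that AlphaStarConfig.to_dict() exposes the attributes set in __init__ as top-level keys: the Trainer's default config dict is now generated from the config object.

    from ray.rllib.agents.alpha_star import AlphaStarTrainer

    # get_default_config() is a classmethod (see above); the dict it returns
    # is built from AlphaStarConfig().to_dict().
    defaults = AlphaStarTrainer.get_default_config()
    print(defaults["replay_buffer_capacity"])      # expected: 20 (assumed key name)
    print(defaults["replay_buffer_replay_ratio"])  # expected: 0.5 (assumed key name)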
@@ -499,3 +576,20 @@ class AlphaStarTrainer(appo.APPOTrainer):
         state_copy = state.copy()
         self.league_builder.__setstate__(state.pop("league_builder", {}))
         super().__setstate__(state_copy)
+
+
+# Deprecated: Use ray.rllib.agents.ppo.PPOConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(AlphaStarConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
+        new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()

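A sketch of what the shim above means for existing code (assuming @Deprecated(error=False) only logs a warning rather than raising):

    from ray.rllib.agents.alpha_star import DEFAULT_CONFIG, AlphaStarConfig

    # Old style: key access still works, but goes through the deprecation-decorated
    # __getitem__, which should log a warning pointing at AlphaStarConfig(...).
    lr = DEFAULT_CONFIG["lr"]

    # New style: read the same setting from the config object, no warning.
    lr = AlphaStarConfig().lr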
@@ -80,9 +80,15 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
     ):
         """Initializes a AlphaStarLeagueBuilder instance.
 
+        The following match types are possible:
+        LE: A learning (not snapshot) league_exploiter vs any snapshot policy.
+        ME: A learning (not snapshot) main exploiter vs any main.
+        M: Main self-play (main vs main).
+
         Args:
             trainer: The Trainer object by which this league builder is used.
-                Trainer calls `build_league()` after each training step.
+                Trainer calls `build_league()` after each training step to reconfigure
+                the league structure (e.g. to add/remove policies).
             trainer_config: The (not yet validated) config dict to be
                 used on the Trainer. Child classes of `LeagueBuilder`
                 should preprocess this to add e.g. multiagent settings

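For context, a sketch (not from this commit) of how the match types and kwargs documented above are wired through league_builder_config: the `type` key selects the LeagueBuilder subclass, and every other key becomes a constructor kwarg. The kwargs shown appear elsewhere in this diff; the values are arbitrary.

    from ray.rllib.agents.alpha_star import AlphaStarConfig
    from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder

    config = AlphaStarConfig().training(
        league_builder_config={
            # Sub-class of the `LeagueBuilder` API to use.
            "type": AlphaStarLeagueBuilder,
            # Constructor kwargs for AlphaStarLeagueBuilder.
            "num_random_policies": 2,
            "num_learning_league_exploiters": 1,
            "num_learning_main_exploiters": 1,
            "win_rate_threshold_for_new_snapshot": 0.8,
            "prob_main_exploiter_playing_against_learning_main": 0.5,
        }
    )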
@@ -26,38 +26,33 @@ class TestAlphaStar(unittest.TestCase):
 
     def test_alpha_star_compilation(self):
         """Test whether a AlphaStarTrainer can be built with all frameworks."""
 
-        config = {
-            "env": "connect_four",
-            "gamma": 1.0,
-            "num_workers": 4,
-            "num_envs_per_worker": 5,
-            "model": {
-                "fcnet_hiddens": [256, 256, 256],
-            },
-            "vf_loss_coeff": 0.01,
-            "entropy_coeff": 0.004,
-            "league_builder_config": {
+        config = (
+            alpha_star.AlphaStarConfig()
+            .environment(env="connect_four")
+            .training(
+                gamma=1.0,
+                model={"fcnet_hiddens": [256, 256, 256]},
+                vf_loss_coeff=0.01,
+                entropy_coeff=0.004,
+                league_builder_config={
                     "win_rate_threshold_for_new_snapshot": 0.8,
                     "num_random_policies": 2,
                     "num_learning_league_exploiters": 1,
                     "num_learning_main_exploiters": 1,
                 },
-            "grad_clip": 10.0,
-            "replay_buffer_capacity": 10,
-            "replay_buffer_replay_ratio": 0.0,
-            # Two GPUs -> 2 policies per GPU.
-            "num_gpus": 4,
-            "_fake_gpus": True,
-            # Test with KL loss, just to cover that extra code.
-            "use_kl_loss": True,
-        }
+                grad_clip=10.0,
+                replay_buffer_capacity=10,
+                replay_buffer_replay_ratio=0.0,
+                use_kl_loss=True,
+            )
+            .rollouts(num_rollout_workers=4, num_envs_per_worker=5)
+            .resources(num_gpus=4, _fake_gpus=True)
+        )
 
         num_iterations = 2
 
         for _ in framework_iterator(config, with_eager_tracing=True):
-            _config = config.copy()
-            trainer = alpha_star.AlphaStarTrainer(config=_config)
+            trainer = config.build()
             for i in range(num_iterations):
                 results = trainer.train()
                 print(results)

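As a reading aid for the refactored test above (a sketch, not part of the test), the old dict keys map onto the builder methods roughly as follows; CartPole-v1 stands in for the registered connect_four env:

    from ray.rllib.agents import alpha_star

    config = (
        alpha_star.AlphaStarConfig()
        .environment(env="CartPole-v1")              # was: "env": ...
        .rollouts(num_rollout_workers=4,             # was: "num_workers": 4
                  num_envs_per_worker=5)             # was: "num_envs_per_worker": 5
        .resources(num_gpus=4, _fake_gpus=True)      # was: "num_gpus", "_fake_gpus"
        .training(gamma=1.0, grad_clip=10.0)         # algorithm hyperparameters
    )
    trainer = config.build()                         # was: AlphaStarTrainer(config=...)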
@@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
 
 
 class ImpalaConfig(TrainerConfig):
-    """Defines an ARSTrainer configuration class from which an ImpalaTrainer can be built.
+    """Defines a configuration class from which an ImpalaTrainer can be built.
 
     Example:
         >>> from ray.rllib.agents.impala import ImpalaConfig

@@ -136,13 +136,10 @@ class ImpalaConfig(TrainerConfig):
         self.num_gpus = 1
         self.lr = 0.0005
         self.min_time_s_per_reporting = 10
-        # IMPALA and APPO are not on the new training_iteration API yet.
-        self._disable_execution_plan_api = False
         # __sphinx_doc_end__
         # fmt: on
 
         # Deprecated value.
-        self._disable_execution_plan_api = True
         self.num_data_loader_buffers = DEPRECATED_VALUE
 
     @override(TrainerConfig)

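The ImpalaConfig hunks above only touch the docstring wording and the _disable_execution_plan_api assignments; for completeness, a minimal assumed usage sketch in the same style as the AlphaStar examples (values arbitrary, CartPole-v1 a placeholder):

    from ray.rllib.agents.impala import ImpalaConfig

    config = (
        ImpalaConfig()
        .environment(env="CartPole-v1")
        .training(lr=0.0005)
        .resources(num_gpus=1)
        .rollouts(num_rollout_workers=2)
    )
    trainer = config.build()
    print(trainer.train())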