[RLlib] AlphaStar config objects. (#24576)

This commit is contained in:
Sven Mika 2022-05-10 14:01:00 +02:00 committed by GitHub
parent d9b54d8bfa
commit 6d94b2acbe
5 changed files with 176 additions and 79 deletions

rllib/agents/alpha_star/__init__.py

@@ -1,6 +1,11 @@
from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarTrainer
from ray.rllib.agents.alpha_star.alpha_star import (
AlphaStarConfig,
AlphaStarTrainer,
DEFAULT_CONFIG,
)
__all__ = [
"DEFAULT_CONFIG",
"AlphaStarConfig",
"AlphaStarTrainer",
"DEFAULT_CONFIG",
]

rllib/agents/alpha_star/alpha_star.py

@@ -3,8 +3,8 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO
================================================================
"""
import gym
from typing import Optional, Type
import tree
from typing import Any, Dict, Optional, Type
import ray
@@ -19,6 +19,7 @@ from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
@@ -42,36 +43,60 @@ from ray.rllib.utils.typing import (
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.timer import _Timer
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the `IMPALATrainer` config in
# rllib/agents/impala/impala.py.
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
appo.DEFAULT_CONFIG, # See keys in appo.py, which are also supported.
{
# TODO: Unify the buffer API, then clean up our existing
# implementations of different buffers.
# This is num batches held at any time for each policy.
"replay_buffer_capacity": 20,
# e.g. ratio=0.2 -> 20% of samples in each train batch are
# old (replayed) ones.
"replay_buffer_replay_ratio": 0.5,
class AlphaStarConfig(appo.APPOConfig):
"""Defines a configuration class from which an AlphaStarTrainer can be built.
# Timeout to use for `ray.wait()` when waiting for samplers to have placed
# new data into the buffers. If no samples are ready within the timeout,
# the buffers used for mixin-sampling will return only older samples.
"sample_wait_timeout": 0.01,
# Timeout to use for `ray.wait()` when waiting for the policy learner actors
# to have performed an update and returned learning stats. If no learner
# actors have produced any learning results in the meantime, their
# learner-stats in the results will be empty for that iteration.
"learn_wait_timeout": 0.1,
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
... .resources(num_gpus=4)\
... .rollouts(num_rollout_workers=64)
>>> print(config.to_dict())
>>> # Build a Trainer object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1")
>>> trainer.train()
Example:
>>> from ray.rllib.agents.alpha_star import AlphaStarConfig
>>> from ray import tune
>>> config = AlphaStarConfig()
>>> # Print out some default values.
>>> print(config.vtrace)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
>>> # Set the config object's env.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.run(
... "AlphaStar",
... stop={"episode_reward_mean": 200},
... config=config.to_dict(),
... )
"""
def __init__(self, trainer_class=None):
"""Initializes a AlphaStarConfig instance."""
super().__init__(trainer_class=trainer_class or AlphaStarTrainer)
# fmt: off
# __sphinx_doc_begin__
# AlphaStar specific settings:
self.replay_buffer_capacity = 20
self.replay_buffer_replay_ratio = 0.5
self.sample_wait_timeout = 0.01
self.learn_wait_timeout = 0.1
# League-building parameters.
# The LeagueBuilder class to be used for league building logic.
"league_builder_config": {
self.league_builder_config = {
# Specify the sub-class of the `LeagueBuilder` API to use.
"type": AlphaStarLeagueBuilder,
# Any number of constructor kwargs to pass to this class:
# The number of random policies to add to the league. This must be an
# even number (including 0) as these will be evenly distributed
# amongst league- and main- exploiters.
@@ -100,28 +125,80 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
# Only for ME matches: Prob to play against learning
# main (vs a snapshot main).
"prob_main_exploiter_playing_against_learning_main": 0.5,
},
}
self.max_num_policies_to_train = None
# The maximum number of trainable policies for this Trainer.
# Each trainable policy will exist as an independent remote actor, co-located
# with a replay buffer. This is in addition to its existence inside
# the RolloutWorkers for training and evaluation.
# Set to None to automatically infer this value from the number of
# trainable policies found in the `multiagent` config.
"max_num_policies_to_train": None,
# Override some of APPOConfig's default values with AlphaStar-specific
# values.
self.vtrace_drop_last_ts = False
self.min_time_s_per_reporting = 2
# __sphinx_doc_end__
# fmt: on
# By default, don't drop last timestep.
# TODO: We should do the same for IMPALA and APPO at some point.
"vtrace_drop_last_ts": False,
@override(appo.APPOConfig)
def training(
self,
*,
replay_buffer_capacity: Optional[int] = None,
replay_buffer_replay_ratio: Optional[float] = None,
sample_wait_timeout: Optional[float] = None,
learn_wait_timeout: Optional[float] = None,
league_builder_config: Optional[Dict[str, Any]] = None,
max_num_policies_to_train: Optional[int] = None,
**kwargs,
) -> "AlphaStarConfig":
"""Sets the training related configuration.
# Reporting interval.
"min_time_s_per_reporting": 2,
},
_allow_unknown_configs=True,
)
Args:
replay_buffer_capacity: The number of sample batches held at any time
for each policy.
replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in
each train batch are old (replayed) ones.
sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for
samplers to have placed new data into the buffers. If no samples are
ready within the timeout, the buffers used for mixin-sampling will
return only older samples.
learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the
policy learner actors to have performed an update and returned learning
stats. If no learner actors have produced any learning results in the
meantime, their learner-stats in the results will be empty for that
iteration.
league_builder_config: League-building config dict.
The dict must contain a `type` key indicating the LeagueBuilder class
to be used for league building logic. All other keys (that are not
`type`) will be used as constructor kwargs on the given class to
construct the LeagueBuilder instance. See the
`ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
(used by default by this algo) as an example.
max_num_policies_to_train: The maximum number of trainable policies for this
Trainer. Each trainable policy will exist as an independent remote actor,
co-located with a replay buffer. This is in addition to its existence inside
the RolloutWorkers for training and evaluation. Set to None to
automatically infer this value from the number of trainable
policies found in the `multiagent` config.
# __sphinx_doc_end__
# fmt: on
Returns:
This updated TrainerConfig object.
"""
# Pass kwargs onto super's `training()` method.
super().training(**kwargs)
# TODO: Unify the buffer API, then clean up our existing
# implementations of different buffers.
if replay_buffer_capacity is not None:
self.replay_buffer_capacity = replay_buffer_capacity
if replay_buffer_replay_ratio is not None:
self.replay_buffer_replay_ratio = replay_buffer_replay_ratio
if sample_wait_timeout is not None:
self.sample_wait_timeout = sample_wait_timeout
if learn_wait_timeout is not None:
self.learn_wait_timeout = learn_wait_timeout
if league_builder_config is not None:
self.league_builder_config = league_builder_config
if max_num_policies_to_train is not None:
self.max_num_policies_to_train = max_num_policies_to_train
return self
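For orientation, a minimal usage sketch of the `training()` overload above (not part of the diff; the environment name and all parameter values are illustrative only):

from ray.rllib.agents.alpha_star import AlphaStarConfig

# Set the AlphaStar-specific knobs introduced above via `.training()`.
config = (
    AlphaStarConfig()
    .training(
        replay_buffer_capacity=10,        # batches kept per policy
        replay_buffer_replay_ratio=0.25,  # 25% of each train batch is replayed
        sample_wait_timeout=0.05,         # `ray.wait()` timeout for sampler data
        learn_wait_timeout=0.2,           # `ray.wait()` timeout for learner stats
        max_num_policies_to_train=None,   # infer from the `multiagent` config
    )
    .rollouts(num_rollout_workers=2)
)
# Build an AlphaStarTrainer, exactly as in the class docstring example.
trainer = config.build(env="CartPole-v1")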
class AlphaStarTrainer(appo.APPOTrainer):
@@ -204,7 +281,7 @@ class AlphaStarTrainer(appo.APPOTrainer):
@classmethod
@override(appo.APPOTrainer)
def get_default_config(cls) -> TrainerConfigDict:
return DEFAULT_CONFIG
return AlphaStarConfig().to_dict()
@override(appo.APPOTrainer)
def validate_config(self, config: TrainerConfigDict):
@@ -499,3 +576,20 @@ class AlphaStarTrainer(appo.APPOTrainer):
state_copy = state.copy()
self.league_builder.__setstate__(state.pop("league_builder", {}))
super().__setstate__(state_copy)
# Deprecated: Use ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig instead!
class _deprecated_default_config(dict):
def __init__(self):
super().__init__(AlphaStarConfig().to_dict())
@Deprecated(
old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
error=False,
)
def __getitem__(self, item):
return super().__getitem__(item)
DEFAULT_CONFIG = _deprecated_default_config()
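A brief sketch (not part of the diff) of what the back-compat shim above means in practice: old-style dict access still works, but each key lookup goes through the deprecated `__getitem__` and logs a deprecation warning (`error=False`), while the config object is the intended replacement.

from ray.rllib.agents.alpha_star import DEFAULT_CONFIG, AlphaStarConfig

# Old style: still returns the value, but warns via the @Deprecated decorator.
capacity = DEFAULT_CONFIG["replay_buffer_capacity"]

# New style: read (or override) the same setting on the config object.
capacity = AlphaStarConfig().replay_buffer_capacity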

rllib/agents/alpha_star/league_builder.py

@@ -80,9 +80,15 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
):
"""Initializes a AlphaStarLeagueBuilder instance.
The following match types are possible:
LE: A learning (not snapshot) league_exploiter vs any snapshot policy.
ME: A learning (not snapshot) main exploiter vs any main.
M: Main self-play (main vs main).
Args:
trainer: The Trainer object by which this league builder is used.
Trainer calls `build_league()` after each training step.
Trainer calls `build_league()` after each training step to reconfigure
the league structure (e.g. to add/remove policies).
trainer_config: The (not yet validated) config dict to be
used on the Trainer. Child classes of `LeagueBuilder`
should preprocess this to add e.g. multiagent settings
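To make the `league_builder_config` contract concrete, here is a minimal sketch of plugging in a custom `LeagueBuilder` subclass (not part of the commit; it assumes the abstract base class is importable from `ray.rllib.agents.alpha_star.league_builder` and that only `build_league()` must be overridden):

from ray.rllib.agents.alpha_star import AlphaStarConfig
from ray.rllib.agents.alpha_star.league_builder import LeagueBuilder


class NoOpLeagueBuilder(LeagueBuilder):
    # Called by the Trainer after each training step; this sketch leaves the
    # league unchanged (no policies are ever added or removed).
    def build_league(self, result):
        pass


# "type" selects the LeagueBuilder class; any other keys would be passed to
# its constructor as kwargs.
config = AlphaStarConfig().training(
    league_builder_config={"type": NoOpLeagueBuilder},
)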

rllib/agents/alpha_star/tests/test_alpha_star.py

@@ -26,38 +26,33 @@ class TestAlphaStar(unittest.TestCase):
def test_alpha_star_compilation(self):
"""Test whether a AlphaStarTrainer can be built with all frameworks."""
config = {
"env": "connect_four",
"gamma": 1.0,
"num_workers": 4,
"num_envs_per_worker": 5,
"model": {
"fcnet_hiddens": [256, 256, 256],
},
"vf_loss_coeff": 0.01,
"entropy_coeff": 0.004,
"league_builder_config": {
config = (
alpha_star.AlphaStarConfig()
.environment(env="connect_four")
.training(
gamma=1.0,
model={"fcnet_hiddens": [256, 256, 256]},
vf_loss_coeff=0.01,
entropy_coeff=0.004,
league_builder_config={
"win_rate_threshold_for_new_snapshot": 0.8,
"num_random_policies": 2,
"num_learning_league_exploiters": 1,
"num_learning_main_exploiters": 1,
},
"grad_clip": 10.0,
"replay_buffer_capacity": 10,
"replay_buffer_replay_ratio": 0.0,
# Two GPUs -> 2 policies per GPU.
"num_gpus": 4,
"_fake_gpus": True,
# Test with KL loss, just to cover that extra code.
"use_kl_loss": True,
}
grad_clip=10.0,
replay_buffer_capacity=10,
replay_buffer_replay_ratio=0.0,
use_kl_loss=True,
)
.rollouts(num_rollout_workers=4, num_envs_per_worker=5)
.resources(num_gpus=4, _fake_gpus=True)
)
num_iterations = 2
for _ in framework_iterator(config, with_eager_tracing=True):
_config = config.copy()
trainer = alpha_star.AlphaStarTrainer(config=_config)
trainer = config.build()
for i in range(num_iterations):
results = trainer.train()
print(results)

rllib/agents/impala/impala.py

@@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
class ImpalaConfig(TrainerConfig):
"""Defines an ARSTrainer configuration class from which an ImpalaTrainer can be built.
"""Defines a configuration class from which an ImpalaTrainer can be built.
Example:
>>> from ray.rllib.agents.impala import ImpalaConfig
@@ -136,13 +136,10 @@ class ImpalaConfig(TrainerConfig):
self.num_gpus = 1
self.lr = 0.0005
self.min_time_s_per_reporting = 10
# IMPALA and APPO are not on the new training_iteration API yet.
self._disable_execution_plan_api = False
# __sphinx_doc_end__
# fmt: on
# Deprecated value.
self._disable_execution_plan_api = True
self.num_data_loader_buffers = DEPRECATED_VALUE
@override(TrainerConfig)
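The ImpalaConfig touched above follows the same config-object pattern; a minimal sketch (not part of the diff; the values are illustrative rather than the defaults shown in the hunk):

from ray.rllib.agents.impala import ImpalaConfig

config = (
    ImpalaConfig()
    .training(lr=0.0005, train_batch_size=512)
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=2)
)
# Builds an ImpalaTrainer and runs a single training iteration.
trainer = config.build(env="CartPole-v1")
print(trainer.train())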