Mirror of https://github.com/vale981/ray (synced 2025-03-06 10:31:39 -05:00)
[RLlib] AlphaStar config objects. (#24576)
Parent: d9b54d8bfa
Commit: 6d94b2acbe
5 changed files with 176 additions and 79 deletions
rllib/agents/alpha_star/__init__.py
@@ -1,6 +1,11 @@
-from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarTrainer
+from ray.rllib.agents.alpha_star.alpha_star import (
+    AlphaStarConfig,
+    AlphaStarTrainer,
+    DEFAULT_CONFIG,
+)
 
 __all__ = [
-    "DEFAULT_CONFIG",
+    "AlphaStarConfig",
     "AlphaStarTrainer",
+    "DEFAULT_CONFIG",
 ]
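For orientation, a minimal sketch of the widened public API (assuming a Ray build that includes this commit; the attribute and key names come from the hunks that follow):

from ray.rllib.agents.alpha_star import AlphaStarConfig, DEFAULT_CONFIG

# The config object is the new, preferred entry point; the dict export is
# kept only for backward compatibility (it warns on access, see the
# deprecation shim further down in this diff).
config = AlphaStarConfig()
assert isinstance(DEFAULT_CONFIG, dict)
assert config.to_dict()["replay_buffer_capacity"] == DEFAULT_CONFIG["replay_buffer_capacity"]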
rllib/agents/alpha_star/alpha_star.py
@@ -3,8 +3,8 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO
 ================================================================
 """
 import gym
-from typing import Optional, Type
-
 import tree
+from typing import Any, Dict, Optional, Type
+
 
 import ray
@@ -19,6 +19,7 @@ from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentRepla
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import MultiAgentBatch
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils.deprecation import Deprecated
 from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.metrics import (
     LAST_TARGET_UPDATE_TS,
@@ -42,36 +43,60 @@ from ray.rllib.utils.typing import (
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 from ray.util.timer import _Timer
 
-# fmt: off
-# __sphinx_doc_begin__
-
-# Adds the following updates to the `IMPALATrainer` config in
-# rllib/agents/impala/impala.py.
-DEFAULT_CONFIG = Trainer.merge_trainer_configs(
-    appo.DEFAULT_CONFIG,  # See keys in appo.py, which are also supported.
-    {
-        # TODO: Unify the buffer API, then clean up our existing
-        # implementations of different buffers.
-        # This is num batches held at any time for each policy.
-        "replay_buffer_capacity": 20,
-        # e.g. ratio=0.2 -> 20% of samples in each train batch are
-        # old (replayed) ones.
-        "replay_buffer_replay_ratio": 0.5,
-
-        # Timeout to use for `ray.wait()` when waiting for samplers to have placed
-        # new data into the buffers. If no samples are ready within the timeout,
-        # the buffers used for mixin-sampling will return only older samples.
-        "sample_wait_timeout": 0.01,
-        # Timeout to use for `ray.wait()` when waiting for the policy learner actors
-        # to have performed an update and returned learning stats. If no learner
-        # actors have produced any learning results in the meantime, their
-        # learner-stats in the results will be empty for that iteration.
-        "learn_wait_timeout": 0.1,
-
+
+class AlphaStarConfig(appo.APPOConfig):
+    """Defines a configuration class from which an AlphaStarTrainer can be built.
+
+    Example:
+        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> config = AlphaStarConfig().training(lr=0.0003, train_batch_size=512)\
+        ...     .resources(num_gpus=4)\
+        ...     .rollouts(num_rollout_workers=64)
+        >>> print(config.to_dict())
+        >>> # Build a Trainer object from the config and run 1 training iteration.
+        >>> trainer = config.build(env="CartPole-v1")
+        >>> trainer.train()
+
+    Example:
+        >>> from ray.rllib.agents.alpha_star import AlphaStarConfig
+        >>> from ray import tune
+        >>> config = AlphaStarConfig()
+        >>> # Print out some default values.
+        >>> print(config.vtrace)
+        >>> # Update the config object.
+        >>> config.training(lr=tune.grid_search([0.0001, 0.0003]), grad_clip=20.0)
+        >>> # Set the config object's env.
+        >>> config.environment(env="CartPole-v1")
+        >>> # Use to_dict() to get the old-style python config dict
+        >>> # when running with tune.
+        >>> tune.run(
+        ...     "AlphaStar",
+        ...     stop={"episode_reward_mean": 200},
+        ...     config=config.to_dict(),
+        ... )
+    """
+
+    def __init__(self, trainer_class=None):
+        """Initializes an AlphaStarConfig instance."""
+        super().__init__(trainer_class=trainer_class or AlphaStarTrainer)
+
+        # fmt: off
+        # __sphinx_doc_begin__
+
+        # AlphaStar specific settings:
+        self.replay_buffer_capacity = 20
+        self.replay_buffer_replay_ratio = 0.5
+        self.sample_wait_timeout = 0.01
+        self.learn_wait_timeout = 0.1
+
         # League-building parameters.
         # The LeagueBuilder class to be used for league building logic.
-        "league_builder_config": {
+        self.league_builder_config = {
+            # Specify the sub-class of the `LeagueBuilder` API to use.
             "type": AlphaStarLeagueBuilder,
+
+            # Any number of constructor kwargs to pass to this class:
+
             # The number of random policies to add to the league. This must be an
             # even number (including 0) as these will be evenly distributed
             # amongst league- and main- exploiters.
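The four buffer/timeout settings that used to be plain dict keys now live as attributes on the config object, with the same names and defaults. A small sketch of reading and exporting them (assuming this version of RLlib):

from ray.rllib.agents.alpha_star import AlphaStarConfig

config = AlphaStarConfig()
# Defaults carried over from the old DEFAULT_CONFIG dict:
assert config.replay_buffer_capacity == 20
assert config.replay_buffer_replay_ratio == 0.5
assert config.sample_wait_timeout == 0.01
assert config.learn_wait_timeout == 0.1
# to_dict() still produces the legacy dict view, e.g. for tune.run().
assert config.to_dict()["replay_buffer_capacity"] == 20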
@@ -100,28 +125,80 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
             # Only for ME matches: Prob to play against learning
             # main (vs a snapshot main).
             "prob_main_exploiter_playing_against_learning_main": 0.5,
-        },
-
-        # The maximum number of trainable policies for this Trainer.
-        # Each trainable policy will exist as a independent remote actor, co-located
-        # with a replay buffer. This is besides its existence inside
-        # the RolloutWorkers for training and evaluation.
-        # Set to None for automatically inferring this value from the number of
-        # trainable policies found in the `multiagent` config.
-        "max_num_policies_to_train": None,
-
-        # By default, don't drop last timestep.
-        # TODO: We should do the same for IMPALA and APPO at some point.
-        "vtrace_drop_last_ts": False,
-
-        # Reporting interval.
-        "min_time_s_per_reporting": 2,
-    },
-    _allow_unknown_configs=True,
-)
-
-# __sphinx_doc_end__
-# fmt: on
+        }
+        self.max_num_policies_to_train = None
+
+        # Override some of APPOConfig's default values with AlphaStar-specific
+        # values.
+        self.vtrace_drop_last_ts = False
+        self.min_time_s_per_reporting = 2
+        # __sphinx_doc_end__
+        # fmt: on
+
+    @override(appo.APPOConfig)
+    def training(
+        self,
+        *,
+        replay_buffer_capacity: Optional[int] = None,
+        replay_buffer_replay_ratio: Optional[float] = None,
+        sample_wait_timeout: Optional[float] = None,
+        learn_wait_timeout: Optional[float] = None,
+        league_builder_config: Optional[Dict[str, Any]] = None,
+        max_num_policies_to_train: Optional[int] = None,
+        **kwargs,
+    ) -> "AlphaStarConfig":
+        """Sets the training related configuration.
+
+        Args:
+            replay_buffer_capacity: This is num batches held at any time for each
+                policy.
+            replay_buffer_replay_ratio: For example, ratio=0.2 -> 20% of samples in
+                each train batch are old (replayed) ones.
+            sample_wait_timeout: Timeout to use for `ray.wait()` when waiting for
+                samplers to have placed new data into the buffers. If no samples are
+                ready within the timeout, the buffers used for mixin-sampling will
+                return only older samples.
+            learn_wait_timeout: Timeout to use for `ray.wait()` when waiting for the
+                policy learner actors to have performed an update and returned
+                learning stats. If no learner actors have produced any learning
+                results in the meantime, their learner-stats in the results will be
+                empty for that iteration.
+            league_builder_config: League-building config dict.
+                The dict must contain a `type` key indicating the LeagueBuilder class
+                to be used for league building logic. All other keys (that are not
+                `type`) will be used as constructor kwargs on the given class to
+                construct the LeagueBuilder instance. See the
+                `ray.rllib.agents.alpha_star.league_builder::AlphaStarLeagueBuilder`
+                (used by default by this algo) as an example.
+            max_num_policies_to_train: The maximum number of trainable policies for
+                this Trainer. Each trainable policy will exist as an independent
+                remote actor, co-located with a replay buffer. This is besides its
+                existence inside the RolloutWorkers for training and evaluation. Set
+                to None for automatically inferring this value from the number of
+                trainable policies found in the `multiagent` config.
+
+        Returns:
+            This updated TrainerConfig object.
+        """
+        # Pass kwargs onto super's `training()` method.
+        super().training(**kwargs)
+
+        # TODO: Unify the buffer API, then clean up our existing
+        # implementations of different buffers.
+        if replay_buffer_capacity is not None:
+            self.replay_buffer_capacity = replay_buffer_capacity
+        if replay_buffer_replay_ratio is not None:
+            self.replay_buffer_replay_ratio = replay_buffer_replay_ratio
+        if sample_wait_timeout is not None:
+            self.sample_wait_timeout = sample_wait_timeout
+        if learn_wait_timeout is not None:
+            self.learn_wait_timeout = learn_wait_timeout
+        if league_builder_config is not None:
+            self.league_builder_config = league_builder_config
+        if max_num_policies_to_train is not None:
+            self.max_num_policies_to_train = max_num_policies_to_train
+
+        return self
+
+
 class AlphaStarTrainer(appo.APPOTrainer):
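A short sketch of the new `training()` setter added above (assuming this version of RLlib): AlphaStar-specific kwargs are consumed locally, everything else is forwarded to `APPOConfig.training()`, and the method returns `self` so calls can be chained:

from ray.rllib.agents.alpha_star import AlphaStarConfig

config = (
    AlphaStarConfig()
    .training(
        lr=0.0005,                        # forwarded to APPOConfig/TrainerConfig
        replay_buffer_capacity=10,        # handled by AlphaStarConfig.training()
        replay_buffer_replay_ratio=0.25,  # 25% of each train batch is replayed
    )
    .rollouts(num_rollout_workers=2)
)
assert config.replay_buffer_capacity == 10
assert config.lr == 0.0005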
@@ -204,7 +281,7 @@ class AlphaStarTrainer(appo.APPOTrainer):
     @classmethod
     @override(appo.APPOTrainer)
     def get_default_config(cls) -> TrainerConfigDict:
-        return DEFAULT_CONFIG
+        return AlphaStarConfig().to_dict()
 
     @override(appo.APPOTrainer)
     def validate_config(self, config: TrainerConfigDict):
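The effect of the `get_default_config()` change above, sketched (assuming this version of RLlib): the Trainer's defaults are now generated from the config object instead of a hand-maintained dict:

from ray.rllib.agents.alpha_star import AlphaStarConfig, AlphaStarTrainer

defaults = AlphaStarTrainer.get_default_config()
# Dict defaults and config-object attributes now come from the same place.
assert defaults["replay_buffer_capacity"] == AlphaStarConfig().replay_buffer_capacity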
@@ -499,3 +576,20 @@ class AlphaStarTrainer(appo.APPOTrainer):
         state_copy = state.copy()
         self.league_builder.__setstate__(state.pop("league_builder", {}))
         super().__setstate__(state_copy)
+
+
+# Deprecated: Use ray.rllib.agents.alpha_star.AlphaStarConfig instead!
+class _deprecated_default_config(dict):
+    def __init__(self):
+        super().__init__(AlphaStarConfig().to_dict())
+
+    @Deprecated(
+        old="ray.rllib.agents.alpha_star.alpha_star.DEFAULT_CONFIG",
+        new="ray.rllib.agents.alpha_star.alpha_star.AlphaStarConfig(...)",
+        error=False,
+    )
+    def __getitem__(self, item):
+        return super().__getitem__(item)
+
+
+DEFAULT_CONFIG = _deprecated_default_config()
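The deprecation shim above keeps old-style dict access working while nudging users toward the config class. A sketch of the intended behavior (assuming this version of RLlib):

from ray.rllib.agents.alpha_star.alpha_star import DEFAULT_CONFIG, AlphaStarConfig

# Key lookups still work, but each one logs a deprecation warning that
# points at AlphaStarConfig(...).
legacy_capacity = DEFAULT_CONFIG["replay_buffer_capacity"]
assert legacy_capacity == AlphaStarConfig().replay_buffer_capacity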
rllib/agents/alpha_star/league_builder.py
@@ -80,9 +80,15 @@ class AlphaStarLeagueBuilder(LeagueBuilder):
     ):
         """Initializes a AlphaStarLeagueBuilder instance.
 
+        The following match types are possible:
+        LE: A learning (not snapshot) league_exploiter vs any snapshot policy.
+        ME: A learning (not snapshot) main exploiter vs any main.
+        M: Main self-play (main vs main).
+
         Args:
             trainer: The Trainer object by which this league builder is used.
-                Trainer calls `build_league()` after each training step.
+                Trainer calls `build_league()` after each training step to reconfigure
+                the league structure (e.g. to add/remove policies).
             trainer_config: The (not yet validated) config dict to be
                 used on the Trainer. Child classes of `LeagueBuilder`
                 should preprocess this to add e.g. multiagent settings
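The `league_builder_config` contract documented above, a `type` key naming the LeagueBuilder subclass with every other key passed to its constructor, can be sketched like this (assuming this version of RLlib; the kwarg values are illustrative and taken from the test diff that follows):

from ray.rllib.agents.alpha_star import AlphaStarConfig
from ray.rllib.agents.alpha_star.league_builder import AlphaStarLeagueBuilder

config = AlphaStarConfig().training(
    league_builder_config={
        # Which LeagueBuilder subclass to construct.
        "type": AlphaStarLeagueBuilder,
        # All remaining keys become constructor kwargs of that class.
        "win_rate_threshold_for_new_snapshot": 0.8,
        "num_random_policies": 2,
    },
)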
rllib/agents/alpha_star/tests/test_alpha_star.py
@@ -26,38 +26,33 @@ class TestAlphaStar(unittest.TestCase):
 
     def test_alpha_star_compilation(self):
         """Test whether a AlphaStarTrainer can be built with all frameworks."""
-        config = {
-            "env": "connect_four",
-            "gamma": 1.0,
-            "num_workers": 4,
-            "num_envs_per_worker": 5,
-            "model": {
-                "fcnet_hiddens": [256, 256, 256],
-            },
-            "vf_loss_coeff": 0.01,
-            "entropy_coeff": 0.004,
-            "league_builder_config": {
-                "win_rate_threshold_for_new_snapshot": 0.8,
-                "num_random_policies": 2,
-                "num_learning_league_exploiters": 1,
-                "num_learning_main_exploiters": 1,
-            },
-            "grad_clip": 10.0,
-            "replay_buffer_capacity": 10,
-            "replay_buffer_replay_ratio": 0.0,
-            # Two GPUs -> 2 policies per GPU.
-            "num_gpus": 4,
-            "_fake_gpus": True,
-            # Test with KL loss, just to cover that extra code.
-            "use_kl_loss": True,
-        }
+        config = (
+            alpha_star.AlphaStarConfig()
+            .environment(env="connect_four")
+            .training(
+                gamma=1.0,
+                model={"fcnet_hiddens": [256, 256, 256]},
+                vf_loss_coeff=0.01,
+                entropy_coeff=0.004,
+                league_builder_config={
+                    "win_rate_threshold_for_new_snapshot": 0.8,
+                    "num_random_policies": 2,
+                    "num_learning_league_exploiters": 1,
+                    "num_learning_main_exploiters": 1,
+                },
+                grad_clip=10.0,
+                replay_buffer_capacity=10,
+                replay_buffer_replay_ratio=0.0,
+                use_kl_loss=True,
+            )
+            .rollouts(num_rollout_workers=4, num_envs_per_worker=5)
+            .resources(num_gpus=4, _fake_gpus=True)
+        )
 
         num_iterations = 2
 
         for _ in framework_iterator(config, with_eager_tracing=True):
-            _config = config.copy()
-            trainer = alpha_star.AlphaStarTrainer(config=_config)
+            trainer = config.build()
             for i in range(num_iterations):
                 results = trainer.train()
                 print(results)
rllib/agents/impala/impala.py
@@ -59,7 +59,7 @@ logger = logging.getLogger(__name__)
 
 
 class ImpalaConfig(TrainerConfig):
-    """Defines an ARSTrainer configuration class from which an ImpalaTrainer can be built.
+    """Defines a configuration class from which an ImpalaTrainer can be built.
 
     Example:
         >>> from ray.rllib.agents.impala import ImpalaConfig
@@ -136,13 +136,10 @@ class ImpalaConfig(TrainerConfig):
         self.num_gpus = 1
         self.lr = 0.0005
         self.min_time_s_per_reporting = 10
-        # IMPALA and APPO are not on the new training_iteration API yet.
-        self._disable_execution_plan_api = False
         # __sphinx_doc_end__
         # fmt: on
 
         # Deprecated value.
-        self._disable_execution_plan_api = True
         self.num_data_loader_buffers = DEPRECATED_VALUE
 
     @override(TrainerConfig)
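For completeness, a tiny usage sketch for the ImpalaConfig class touched above (assuming this version of RLlib):

from ray.rllib.agents.impala import ImpalaConfig

config = ImpalaConfig().training(lr=0.0005)
assert config.min_time_s_per_reporting == 10  # default shown in the hunk above
print(config.to_dict()["lr"])                 # -> 0.0005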