ray/rllib/algorithms/alpha_zero/alpha_zero.py


import logging
from typing import List, Optional, Type, Union

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
    Replay,
    StoreToReplayBuffer,
    WaitUntilTimestepsElapsed,
)
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    synchronous_parallel_sample,
)
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
    TrainOneStep,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.annotations import Deprecated, override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.utils.typing import ResultDict, AlgorithmConfigDict
from ray.util.iter import LocalIterator

from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy
from ray.rllib.algorithms.alpha_zero.mcts import MCTS
from ray.rllib.algorithms.alpha_zero.ranked_rewards import get_r2_env_wrapper

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class AlphaZeroDefaultCallbacks(DefaultCallbacks):
    """AlphaZero callbacks.

    If you use custom callbacks, you must extend this class and call super()
    for on_episode_start.
    """

    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        # Save the env state when an episode starts.
        env = base_env.get_sub_environments()[0]
        state = env.get_state()
        episode.user_data["initial_state"] = state
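
# A minimal sketch of a custom callback (hypothetical `MyCallbacks` name):
# extend AlphaZeroDefaultCallbacks and call super() so the env-state
# snapshot above is still saved.
#
#     class MyCallbacks(AlphaZeroDefaultCallbacks):
#         def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
#             super().on_episode_start(
#                 worker, base_env, policies, episode, **kwargs
#             )
#             episode.user_data["my_counter"] = 0  # custom bookkeeping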


class AlphaZeroConfig(AlgorithmConfig):
    """Defines a configuration class from which an AlphaZero Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
        >>> config = AlphaZeroConfig().training(sgd_minibatch_size=256)\
        ...     .resources(num_gpus=0)\
        ...     .rollouts(num_workers=4)
        >>> print(config.to_dict())
        >>> # Build an Algorithm object from the config and run 1 training iteration.
        >>> trainer = config.build(env="CartPole-v1")
        >>> trainer.train()

    Example:
        >>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig
        >>> from ray import tune
        >>> config = AlphaZeroConfig()
        >>> # Print out some default values.
        >>> print(config.shuffle_sequences)
        >>> # Update the config object.
        >>> config.training(lr=tune.grid_search([0.001, 0.0001]))
        >>> # Set the config object's env.
        >>> config.environment(env="CartPole-v1")
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.run(
        ...     "AlphaZero",
        ...     stop={"episode_reward_mean": 200},
        ...     config=config.to_dict(),
        ... )
    """
    def __init__(self, algo_class=None):
        """Initializes an AlphaZeroConfig instance."""
        super().__init__(algo_class=algo_class or AlphaZero)

        # fmt: off
        # __sphinx_doc_begin__
        # AlphaZero specific config settings:
        self.sgd_minibatch_size = 128
        self.shuffle_sequences = True
        self.num_sgd_iter = 30
        self.learning_starts = 1000
        self.replay_buffer_config = {
            "type": "ReplayBuffer",
            # Size of the replay buffer in batches (not timesteps!).
            "capacity": 1000,
            # When to start returning samples (in batches, not timesteps!).
            "learning_starts": 500,
            # Choosing `fragments` here makes it so that the buffer stores entire
            # batches, instead of sequences, episodes or timesteps.
            "storage_unit": "fragments",
        }
        self.lr_schedule = None
        self.vf_share_layers = False
        self.mcts_config = {
            "puct_coefficient": 1.0,
            "num_simulations": 30,
            "temperature": 1.5,
            "dirichlet_epsilon": 0.25,
            "dirichlet_noise": 0.03,
            "argmax_tree_policy": False,
            "add_dirichlet_noise": True,
        }
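        # Ranked-rewards ("r2") settings; see the paper link in `training()`'s
        # docstring. Raw episode returns are replaced by a binary
        # better/worse signal computed against a percentile of recently
        # observed returns, emulating a self-play opponent for single-player
        # envs.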
        self.ranked_rewards = {
            "enable": True,
            "percentile": 75,
            "buffer_max_length": 1000,
            # Add rewards obtained from a random policy to
            # "warm start" the buffer.
            "initialize_buffer": True,
            "num_init_rewards": 100,
        }

        # Override some of AlgorithmConfig's default values with AlphaZero-specific
        # values.
        self.framework_str = "torch"
        self.callbacks_class = AlphaZeroDefaultCallbacks
        self.lr = 5e-5
        self.rollout_fragment_length = 200
        self.train_batch_size = 4000
        self.batch_mode = "complete_episodes"
        # Extra configuration that disables exploration at evaluation time:
        # the tree policy becomes greedy (argmax) and no Dirichlet noise is
        # added to the root priors.
        self.evaluation_config = {
            "mcts_config": {
                "argmax_tree_policy": True,
                "add_dirichlet_noise": False,
            },
        }
        # __sphinx_doc_end__
        # fmt: on

        self.buffer_size = DEPRECATED_VALUE
    @override(AlgorithmConfig)
    def training(
        self,
        *,
        sgd_minibatch_size: Optional[int] = None,
        shuffle_sequences: Optional[bool] = None,
        num_sgd_iter: Optional[int] = None,
        replay_buffer_config: Optional[dict] = None,
        lr_schedule: Optional[List[List[Union[int, float]]]] = None,
        vf_share_layers: Optional[bool] = None,
        mcts_config: Optional[dict] = None,
        ranked_rewards: Optional[dict] = None,
        **kwargs,
    ) -> "AlphaZeroConfig":
        """Sets the training related configuration.

        Args:
            sgd_minibatch_size: Total SGD batch size across all devices for SGD.
            shuffle_sequences: Whether to shuffle sequences in the batch when
                training (recommended).
            num_sgd_iter: Number of SGD iterations in each outer loop.
            replay_buffer_config: Replay buffer config.
                Examples:
                {
                    "_enable_replay_buffer_api": True,
                    "type": "MultiAgentReplayBuffer",
                    "learning_starts": 1000,
                    "capacity": 50000,
                    "replay_sequence_length": 1,
                }
                - OR -
                {
                    "_enable_replay_buffer_api": True,
                    "type": "MultiAgentPrioritizedReplayBuffer",
                    "capacity": 50000,
                    "prioritized_replay_alpha": 0.6,
                    "prioritized_replay_beta": 0.4,
                    "prioritized_replay_eps": 1e-6,
                    "replay_sequence_length": 1,
                }
                - Where -
                prioritized_replay_alpha: Alpha parameter controls the degree of
                    prioritization in the buffer. In other words, when a buffer
                    sample has a higher temporal-difference error, with how much
                    more probability should it be drawn for use in updating the
                    parameterized Q-network. 0.0 corresponds to uniform
                    probability. Setting this much above 1.0 may quickly make
                    the sampling distribution very peaked, with low entropy.
                prioritized_replay_beta: Beta parameter controls the degree of
                    importance sampling, which suppresses the influence of
                    gradient updates from samples that have a higher probability
                    of being sampled via the alpha parameter and their
                    temporal-difference error.
                prioritized_replay_eps: Epsilon parameter sets the baseline
                    probability for sampling, so that when the
                    temporal-difference error of a sample is zero, there is
                    still a chance of it being drawn.
            lr_schedule: Learning rate schedule. In the format of
                [[timestep, lr-value], [timestep, lr-value], ...]
                Intermediary timesteps will be assigned linearly interpolated
                learning rate values. A schedule should normally start from
                timestep 0.
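                For example, a schedule of [[0, 5e-5], [1000000, 1e-5]] would
                anneal the learning rate from 5e-5 down to 1e-5 over the first
                1M timesteps (illustrative values, not tuned defaults).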
            vf_share_layers: Share layers for value function. If you set this
                to True, it's important to tune vf_loss_coeff.
            mcts_config: MCTS specific settings.
            ranked_rewards: Settings for the ranked rewards ("r2") algorithm
                from: https://arxiv.org/pdf/1807.01672.pdf

        Returns:
            This updated AlgorithmConfig object.
        """
        # Pass kwargs onto super's `training()` method.
        super().training(**kwargs)

        if sgd_minibatch_size is not None:
            self.sgd_minibatch_size = sgd_minibatch_size
        if shuffle_sequences is not None:
            self.shuffle_sequences = shuffle_sequences
        if num_sgd_iter is not None:
            self.num_sgd_iter = num_sgd_iter
        if replay_buffer_config is not None:
            self.replay_buffer_config = replay_buffer_config
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule
        if vf_share_layers is not None:
            self.vf_share_layers = vf_share_layers
        if mcts_config is not None:
            self.mcts_config = mcts_config
        if ranked_rewards is not None:
            self.ranked_rewards = ranked_rewards

        return self
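

# A minimal usage sketch for the config above (values are illustrative,
# not tuned):
#
#     config = (
#         AlphaZeroConfig()
#         .training(num_sgd_iter=10, lr=1e-4)
#         .rollouts(num_workers=2)
#     )
#     algo = config.build(env="CartPole-v1")
#     print(algo.train())
#
# Note that dict-valued args (e.g. `mcts_config`) replace the defaults
# wholesale rather than being merged (see the setters above), so pass
# complete dicts.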


def alpha_zero_loss(policy, model, dist_class, train_batch):
    # Get the unflattened inputs.
    input_dict = restore_original_dimensions(
        train_batch["obs"], policy.observation_space, "torch"
    )
    # Forward pass through the model.
    model_out = model.forward(input_dict, None, [1])
    logits, _ = model_out
    values = model.value_function()
    logits, values = torch.squeeze(logits), torch.squeeze(values)
    priors = nn.Softmax(dim=-1)(logits)
    # Compute actor (policy) and critic (value) losses.
    policy_loss = torch.mean(
        -torch.sum(train_batch["mcts_policies"] * torch.log(priors), dim=-1)
    )
    value_loss = torch.mean(torch.pow(values - train_batch["value_label"], 2))
    # Compute total loss.
    total_loss = (policy_loss + value_loss) / 2
    return total_loss, policy_loss, value_loss
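
# In math terms, matching `alpha_zero_loss` above: with MCTS visit-count
# targets pi_i and value targets z_i from the train batch,
#
#     L        = (L_policy + L_value) / 2
#     L_policy = -mean_i( sum_a pi_i(a) * log p_i(a) )   # cross-entropy
#     L_value  =  mean_i( (v_i - z_i) ** 2 )             # squared error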


class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
    def __init__(self, obs_space, action_space, config):
        model = ModelCatalog.get_model_v2(
            obs_space, action_space, action_space.n, config["model"], "torch"
        )
        _, env_creator = Algorithm._get_env_id_and_creator(config["env"], config)
        if config["ranked_rewards"]["enable"]:
            # If r2 is enabled, the env is wrapped to include a rewards buffer
            # used to normalize rewards.
            env_cls = get_r2_env_wrapper(env_creator, config["ranked_rewards"])

            # The wrapped env is used only in the MCTS, not in the
            # rollout workers.
            def _env_creator():
                return env_cls(config["env_config"])

        else:

            def _env_creator():
                return env_creator(config["env_config"])

        def mcts_creator():
            return MCTS(model, config["mcts_config"])

        super().__init__(
            obs_space,
            action_space,
            config,
            model,
            alpha_zero_loss,
            TorchCategorical,
            mcts_creator,
            _env_creator,
        )


class AlphaZero(Algorithm):
    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        return AlphaZeroConfig().to_dict()

    def validate_config(self, config: AlgorithmConfigDict) -> None:
        """Checks and updates the config based on settings."""
        # Call super's validation method.
        super().validate_config(config)
        validate_buffer_config(config)

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
        return AlphaZeroPolicyWrapperClass

    @override(Algorithm)
    def training_step(self) -> ResultDict:
        """Single training iteration: sample, store/replay, train, sync weights.

        Returns:
            The results dict from executing the training iteration.
        """
        # Sample n MultiAgentBatches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update sampling step counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer.
            if self.local_replay_buffer is not None:
                self.local_replay_buffer.add(batch)

        if self.local_replay_buffer is not None:
            train_batch = self.local_replay_buffer.sample(
                self.config["train_batch_size"]
            )
        else:
            train_batch = concat_samples(new_sample_batches)

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU).
        train_results = {}
        if train_batch is not None:
            if self.config.get("simple_optimizer") is True:
                train_results = train_one_step(self, train_batch)
            else:
                train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results
    @staticmethod
    @override(Algorithm)
    def execution_plan(
        workers: WorkerSet, config: AlgorithmConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "Alpha zero execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        if config["simple_optimizer"]:
            train_op = rollouts.combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            ).for_each(TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
        else:
            replay_buffer = SimpleReplayBuffer(config["buffer_size"])

            store_op = rollouts.for_each(
                StoreToReplayBuffer(local_buffer=replay_buffer)
            )

            replay_op = (
                Replay(local_buffer=replay_buffer)
                .filter(WaitUntilTimestepsElapsed(config["learning_starts"]))
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
                .for_each(TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
            )

            train_op = Concurrently(
                [store_op, replay_op], mode="round_robin", output_indexes=[1]
            )

        return StandardMetricsReporting(train_op, workers, config)


# Deprecated: Use ray.rllib.algorithms.alpha_zero.AlphaZeroConfig instead!
class _deprecated_default_config(dict):
    def __init__(self):
        super().__init__(AlphaZeroConfig().to_dict())

    @Deprecated(
        old="ray.rllib.algorithms.alpha_zero.alpha_zero.DEFAULT_CONFIG",
        new="ray.rllib.algorithms.alpha_zero.alpha_zero.AlphaZeroConfig(...)",
        error=False,
    )
    def __getitem__(self, item):
        return super().__getitem__(item)


DEFAULT_CONFIG = _deprecated_default_config()
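
# Preferred new-style usage (a minimal sketch mirroring the class docstring
# examples above) instead of reading DEFAULT_CONFIG:
#
#     config = AlphaZeroConfig().environment(env="CartPole-v1")
#     algo = config.build()
#     print(algo.train())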