# ray/rllib/algorithms/alpha_zero/alpha_zero.py

import logging
from typing import Type
from ray.rllib.agents import with_common_config
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
    Replay,
    StoreToReplayBuffer,
    WaitUntilTimestepsElapsed,
)
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    synchronous_parallel_sample,
)
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
    TrainOneStep,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.algorithms.alpha_zero.alpha_zero_policy import AlphaZeroPolicy
from ray.rllib.algorithms.alpha_zero.mcts import MCTS
from ray.rllib.algorithms.alpha_zero.ranked_rewards import get_r2_env_wrapper

torch, nn = try_import_torch()

logger = logging.getLogger(__name__)


class AlphaZeroDefaultCallbacks(DefaultCallbacks):
    """AlphaZero callbacks.

    If you use custom callbacks, you must extend this class and call super()
    for on_episode_start.
    """

    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        # save env state when an episode starts
        env = base_env.get_sub_environments()[0]
        state = env.get_state()
        episode.user_data["initial_state"] = state
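

# A minimal sketch of a user-defined callbacks class, assuming you want your
# own episode hooks: extend AlphaZeroDefaultCallbacks and call super() in
# on_episode_start so the env state keeps being recorded for MCTS. The extra
# metric below is purely illustrative.
class ExampleCustomCallbacks(AlphaZeroDefaultCallbacks):
    def on_episode_start(self, worker, base_env, policies, episode, **kwargs):
        # Keep the initial-state bookkeeping from the default callbacks.
        super().on_episode_start(worker, base_env, policies, episode, **kwargs)

    def on_episode_end(self, worker, base_env, policies, episode, **kwargs):
        # Illustrative custom metric; replace with your own logic.
        episode.custom_metrics["episode_length"] = episode.length

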
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Size of batches collected from each worker
    "rollout_fragment_length": 200,
    # Number of timesteps collected for each SGD round
    "train_batch_size": 4000,
    # Total SGD batch size across all devices for SGD
    "sgd_minibatch_size": 128,
    # Whether to shuffle sequences in the batch when training (recommended)
    "shuffle_sequences": True,
    # Number of SGD iterations in each outer loop
    "num_sgd_iter": 30,
    # In case a buffer optimizer is used
    "learning_starts": 1000,
    # Size of the replay buffer in batches (not timesteps!).
    "buffer_size": DEPRECATED_VALUE,
    "replay_buffer_config": {
        "_enable_replay_buffer_api": True,
        "type": "SimpleReplayBuffer",
        # Size of the replay buffer in batches (not timesteps!).
        "capacity": 1000,
        # When to start returning samples (in batches, not timesteps!).
        "learning_starts": 500,
    },
    # Stepsize of SGD
    "lr": 5e-5,
    # Learning rate schedule
    "lr_schedule": None,
    # Share layers for value function. If you set this to True, it's important
    # to tune vf_loss_coeff.
    "vf_share_layers": False,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "complete_episodes",
    # Which observation filter to apply to the observation
    "observation_filter": "NoFilter",
    # === MCTS ===
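    # The keys below parameterize the search in mcts.py: "puct_coefficient" is
    # the exploration constant in the PUCT action-selection formula,
    # "num_simulations" is the number of tree-search simulations run per
    # environment step, "temperature" is the exponent applied to the
    # (normalized) visit counts when forming the final action distribution,
    # "dirichlet_epsilon" and "dirichlet_noise" are the mixing weight and
    # concentration of the Dirichlet noise added to the root priors,
    # "argmax_tree_policy" selects the most-visited action instead of
    # sampling, and "add_dirichlet_noise" toggles the root noise entirely.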
"mcts_config": {
"puct_coefficient": 1.0,
"num_simulations": 30,
"temperature": 1.5,
"dirichlet_epsilon": 0.25,
"dirichlet_noise": 0.03,
"argmax_tree_policy": False,
"add_dirichlet_noise": True,
},
# === Ranked Rewards ===
# implement the ranked reward (r2) algorithm
# from: https://arxiv.org/pdf/1807.01672.pdf
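    # In short, the r2 wrapper replaces an episode's raw return with a binary
    # reward: +1 if the return beats a given percentile of recently observed
    # returns and -1 otherwise, so a single-agent task becomes a game of
    # improving over the agent's own recent performance.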
"ranked_rewards": {
"enable": True,
"percentile": 75,
"buffer_max_length": 1000,
# add rewards obtained from random policy to
# "warm start" the buffer
"initialize_buffer": True,
"num_init_rewards": 100,
},
# === Evaluation ===
# Extra configuration that disables exploration.
"evaluation_config": {
"mcts_config": {
"argmax_tree_policy": True,
"add_dirichlet_noise": False,
},
},
# === Callbacks ===
"callbacks": AlphaZeroDefaultCallbacks,
"framework": "torch", # Only PyTorch supported so far.
})
# __sphinx_doc_end__
# fmt: on
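

# A minimal usage sketch, assuming `env_cls` is a user-provided env class that
# implements get_state()/set_state() (required by the MCTS rollouts); the
# override values are illustrative, not tuned. Configs at this point are plain
# dicts, so a run config is usually DEFAULT_CONFIG with a few keys overridden.
def example_alpha_zero_config(env_cls):
    import copy

    config = copy.deepcopy(DEFAULT_CONFIG)
    config.update(
        {
            "env": env_cls,
            "num_workers": 2,
            # Deeper search than the default 30 simulations per move.
            "mcts_config": dict(config["mcts_config"], num_simulations=50),
        }
    )
    return config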


def alpha_zero_loss(policy, model, dist_class, train_batch):
    # get the unflattened inputs
    input_dict = restore_original_dimensions(
        train_batch["obs"], policy.observation_space, "torch"
    )
    # forward pass in model
    model_out = model.forward(input_dict, None, [1])
    logits, _ = model_out
    values = model.value_function()
    logits, values = torch.squeeze(logits), torch.squeeze(values)
    priors = nn.Softmax(dim=-1)(logits)
    # compute actor and critic losses
    policy_loss = torch.mean(
        -torch.sum(train_batch["mcts_policies"] * torch.log(priors), dim=-1)
    )
    value_loss = torch.mean(torch.pow(values - train_batch["value_label"], 2))
    # compute total loss
    total_loss = (policy_loss + value_loss) / 2
    return total_loss, policy_loss, value_loss
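

# A worked toy example of the loss above (illustrative shapes only, not used
# anywhere in training): the total loss averages a cross-entropy term between
# the MCTS visit-count policy and the network's priors with an MSE term on the
# value head.
def example_alpha_zero_loss_arithmetic():
    logits = torch.tensor([[2.0, 0.5, -1.0]])
    mcts_policies = torch.tensor([[0.6, 0.3, 0.1]])
    values = torch.tensor([0.2])
    value_labels = torch.tensor([1.0])
    priors = nn.Softmax(dim=-1)(logits)
    policy_loss = torch.mean(-torch.sum(mcts_policies * torch.log(priors), dim=-1))
    value_loss = torch.mean(torch.pow(values - value_labels, 2))
    return (policy_loss + value_loss) / 2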


class AlphaZeroPolicyWrapperClass(AlphaZeroPolicy):
    def __init__(self, obs_space, action_space, config):
        model = ModelCatalog.get_model_v2(
            obs_space, action_space, action_space.n, config["model"], "torch"
        )
        _, env_creator = Trainer._get_env_id_and_creator(config["env"], config)
        if config["ranked_rewards"]["enable"]:
            # if r2 is enabled, the env is wrapped to include a rewards buffer
            # used to normalize rewards
            env_cls = get_r2_env_wrapper(env_creator, config["ranked_rewards"])

            # the wrapped env is used only in the mcts, not in the
            # rollout workers
            def _env_creator():
                return env_cls(config["env_config"])

        else:

            def _env_creator():
                return env_creator(config["env_config"])

        def mcts_creator():
            return MCTS(model, config["mcts_config"])

        super().__init__(
            obs_space,
            action_space,
            config,
            model,
            alpha_zero_loss,
            TorchCategorical,
            mcts_creator,
            _env_creator,
        )


class AlphaZeroTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    def validate_config(self, config: TrainerConfigDict) -> None:
        """Checks and updates the config based on settings."""
        # Call super's validation method.
        super().validate_config(config)
        validate_buffer_config(config)

    @override(Trainer)
    def get_default_policy_class(self, config: TrainerConfigDict) -> Type[Policy]:
        return AlphaZeroPolicyWrapperClass

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        """Runs a single training iteration.

        Samples episodes from the rollout workers, stores them in the replay
        buffer (if one is configured), trains on a sampled batch, and then
        syncs the updated weights to all remote workers.

        Returns:
            The results dict from executing the training iteration.
        """
        # Sample n MultiAgentBatches from n workers.
        new_sample_batches = synchronous_parallel_sample(
            worker_set=self.workers, concat=False
        )

        for batch in new_sample_batches:
            # Update sampling step counters.
            self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
            self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
            # Store new samples in the replay buffer
            # Use deprecated add_batch() to support old replay buffers for now
            if self.local_replay_buffer is not None:
                self.local_replay_buffer.add(batch)

        if self.local_replay_buffer is not None:
            train_batch = self.local_replay_buffer.sample(
                self.config["train_batch_size"]
            )
        else:
            train_batch = SampleBatch.concat_samples(new_sample_batches)

        # Learn on the training batch.
        # Use simple optimizer (only for multi-agent or tf-eager; all other
        # cases should use the multi-GPU optimizer, even if only using 1 GPU)
        train_results = {}
        if train_batch is not None:
            if self.config.get("simple_optimizer") is True:
                train_results = train_one_step(self, train_batch)
            else:
                train_results = multi_gpu_train_one_step(self, train_batch)

        # TODO: Move training steps counter update outside of `train_one_step()` method.
        # # Update train step counters.
        # self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()
        # self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

        # Update weights and global_vars - after learning on the local worker - on all
        # remote workers.
        global_vars = {
            "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
        }
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            self.workers.sync_weights(global_vars=global_vars)

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "Alpha zero execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        if config["simple_optimizer"]:
            train_op = rollouts.combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            ).for_each(TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
        else:
            replay_buffer = SimpleReplayBuffer(config["buffer_size"])

            store_op = rollouts.for_each(
                StoreToReplayBuffer(local_buffer=replay_buffer)
            )

            replay_op = (
                Replay(local_buffer=replay_buffer)
                .filter(WaitUntilTimestepsElapsed(config["learning_starts"]))
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
                .for_each(TrainOneStep(workers, num_sgd_iter=config["num_sgd_iter"]))
            )

            train_op = Concurrently(
                [store_op, replay_op], mode="round_robin", output_indexes=[1]
            )

        return StandardMetricsReporting(train_op, workers, config)
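

# A minimal end-to-end usage sketch (assumptions: `config` is a dict such as
# the one returned by example_alpha_zero_config() above, with a valid "env"
# entry; the number of iterations and the printed metric are arbitrary).
def example_train(config, num_iterations=10):
    import ray

    ray.init(ignore_reinit_error=True)
    trainer = AlphaZeroTrainer(config=config)
    for _ in range(num_iterations):
        results = trainer.train()
        print(results["episode_reward_mean"])
    trainer.stop()
    ray.shutdown()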