# ray/rllib/algorithms/alpha_zero/alpha_zero_policy.py


import numpy as np

from ray.rllib.algorithms.alpha_zero.mcts import Node, RootParentNode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

torch, _ = try_import_torch()


class AlphaZeroPolicy(TorchPolicy):
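    """A TorchPolicy that plans with MCTS over an internal env copy.

    At rollout time, actions are produced by Monte Carlo tree search run by
    `self.mcts` on `self.env`; the resulting MCTS policy vectors and final
    episode rewards are cached on the episode and later used as training
    targets for the policy and value heads.
    """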
    def __init__(
        self,
        observation_space,
        action_space,
        config,
        model,
        loss,
        action_distribution_class,
        mcts_creator,
        env_creator,
        **kwargs
    ):
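        """Builds the policy together with its own env copy for MCTS.

        `mcts_creator` and `env_creator` are zero-argument callables that
        construct the MCTS search object and the environment used during
        tree-search simulations, respectively.
        """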
        super().__init__(
            observation_space,
            action_space,
            config,
            model=model,
            loss=loss,
            action_distribution_class=action_distribution_class,
        )
        # We maintain an env copy in the policy that is used during MCTS
        # simulations.
        self.env_creator = env_creator
        self.mcts = mcts_creator()
        self.env = self.env_creator()
        self.env.reset()
        self.obs_space = observation_space

    @override(TorchPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
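        """Packs the batches into an input dict and delegates to
        `compute_actions_from_input_dict`."""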
        input_dict = {"obs": obs_batch}
        if prev_action_batch is not None:
            input_dict["prev_actions"] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict["prev_rewards"] = prev_reward_batch

        return self.compute_actions_from_input_dict(
            input_dict=input_dict,
            episodes=episodes,
            state_batches=state_batches,
        )

    @override(Policy)
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, episodes=None, **kwargs
    ):
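        """Computes one action per ongoing episode by running MCTS.

        The current tree root and the MCTS policy vectors are cached in
        `episode.user_data`, so the search tree is carried over between
        time steps of the same episode.
        """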
        with torch.no_grad():
            actions = []
            for i, episode in enumerate(episodes):
                if episode.length == 0:
                    # If this is the first time step of the episode, get the
                    # initial env state.
                    env_state = episode.user_data["initial_state"]
                    # Check whether the env has been wrapped for ranked rewards.
                    if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
                        # The ranked-rewards env state also contains the
                        # rewards buffer state.
                        env_state = {"env_state": env_state, "buffer_state": None}
                    # Create the tree root node.
                    obs = self.env.set_state(env_state)
                    tree_node = Node(
                        state=env_state,
                        obs=obs,
                        reward=0,
                        done=False,
                        action=None,
                        parent=RootParentNode(env=self.env),
                        mcts=self.mcts,
                    )
                else:
                    # Otherwise, reuse the root node kept from the previous
                    # time step.
                    tree_node = episode.user_data["tree_node"]

                # Run Monte Carlo simulations to compute the action and
                # expand the tree.
                mcts_policy, action, tree_node = self.mcts.compute_action(tree_node)
                # Record the action.
                actions.append(action)
                # Store the new root node.
                episode.user_data["tree_node"] = tree_node

                # Store the MCTS policy vector for this time step.
                if episode.length == 0:
                    episode.user_data["mcts_policies"] = [mcts_policy]
                else:
                    episode.user_data["mcts_policies"].append(mcts_policy)

            return (
                np.array(actions),
                [],
                self.extra_action_out(
                    input_dict, kwargs.get("state_batches", []), self.model, None
                ),
            )

    @override(Policy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
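        """Adds MCTS policy targets and value targets to the sample batch.

        The final (optionally ranked-rewards normalized) episode reward is
        broadcast to all transitions of the episode as `value_label`.
        """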
        # Add the MCTS policies to the sample batch.
        sample_batch["mcts_policies"] = np.array(episode.user_data["mcts_policies"])[
            sample_batch["t"]
        ]
        # The final episode reward is used as the (undiscounted) value target
        # for all transitions in the episode.
        final_reward = sample_batch["rewards"][-1]
        # If ranked rewards (r2) are enabled, add the reward to the buffer and
        # normalize it.
        if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
            self.env.r2_buffer.add_reward(final_reward)
            final_reward = self.env.r2_buffer.normalize(final_reward)
        sample_batch["value_label"] = final_reward * np.ones_like(sample_batch["t"])
        return sample_batch

    @override(TorchPolicy)
    def learn_on_batch(self, postprocessed_batch):
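        """Performs one gradient update on the policy + value loss and
        returns the learner stats."""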
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # Compute the combined policy and value loss.
        loss_out, policy_loss, value_loss = self._loss(
            self, self.model, self.dist_class, train_batch
        )
        self._optimizers[0].zero_grad()
        loss_out.backward()

        # Optionally post-process gradients (e.g. clipping), then apply them.
        grad_process_info = self.extra_grad_process(self._optimizers[0], loss_out)
        self._optimizers[0].step()

        # Collect learner stats.
        grad_info = self.extra_grad_info(train_batch)
        grad_info.update(grad_process_info)
        grad_info.update(
            {
                "total_loss": loss_out.detach().cpu().numpy(),
                "policy_loss": policy_loss.detach().cpu().numpy(),
                "value_loss": value_loss.detach().cpu().numpy(),
            }
        )

        return {LEARNER_STATS_KEY: grad_info}