import numpy as np

from ray.rllib.algorithms.alpha_zero.mcts import Node, RootParentNode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

torch, _ = try_import_torch()


class AlphaZeroPolicy(TorchPolicy):
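    """AlphaZero policy that selects actions by running MCTS on an env copy.

    The policy keeps its own environment instance (built via `env_creator`)
    for MCTS simulations, stores the per-step MCTS policy vectors in
    `episode.user_data`, and turns them into training targets in
    `postprocess_trajectory`.
    """
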
    def __init__(
        self,
        observation_space,
        action_space,
        config,
        model,
        loss,
        action_distribution_class,
        mcts_creator,
        env_creator,
        **kwargs
    ):
        super().__init__(
            observation_space,
            action_space,
            config,
            model=model,
            loss=loss,
            action_distribution_class=action_distribution_class,
        )
        # we maintain an env copy in the policy that is used during mcts
        # simulations
        self.env_creator = env_creator
        self.mcts = mcts_creator()
        self.env = self.env_creator()
        self.env.reset()
        self.obs_space = observation_space

    @override(TorchPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
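        """Wraps the observation batches in an input dict and delegates to
        `compute_actions_from_input_dict`."""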
        input_dict = {"obs": obs_batch}
        if prev_action_batch is not None:
            input_dict["prev_actions"] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict["prev_rewards"] = prev_reward_batch

        return self.compute_actions_from_input_dict(
            input_dict=input_dict,
            episodes=episodes,
            state_batches=state_batches,
        )

    @override(Policy)
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, episodes=None, **kwargs
    ):
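        """Computes one action per episode by running MCTS from the current
        tree root.

        On the first step of an episode the root node is rebuilt from the
        stored initial env state; on later steps the root kept in
        `episode.user_data["tree_node"]` is reused. Returns the usual
        `(actions, state_outs, extra_fetches)` policy output triple.
        """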
        with torch.no_grad():
            actions = []
            for i, episode in enumerate(episodes):
                if episode.length == 0:
                    # if first time step of episode, get initial env state
                    env_state = episode.user_data["initial_state"]
                    # verify if env has been wrapped for ranked rewards
                    if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
                        # r2 env state contains also the rewards buffer state
                        env_state = {"env_state": env_state, "buffer_state": None}
                    # create tree root node
                    obs = self.env.set_state(env_state)
                    tree_node = Node(
                        state=env_state,
                        obs=obs,
                        reward=0,
                        done=False,
                        action=None,
                        parent=RootParentNode(env=self.env),
                        mcts=self.mcts,
                    )
                else:
                    # otherwise get last root node from previous time step
                    tree_node = episode.user_data["tree_node"]

                # run monte carlo simulations to compute the actions
                # and record the tree
                mcts_policy, action, tree_node = self.mcts.compute_action(tree_node)
                # record action
                actions.append(action)
                # store new node
                episode.user_data["tree_node"] = tree_node

                # store mcts policies vectors and current tree root node
                if episode.length == 0:
                    episode.user_data["mcts_policies"] = [mcts_policy]
                else:
                    episode.user_data["mcts_policies"].append(mcts_policy)

            return (
                np.array(actions),
                [],
                self.extra_action_out(
                    input_dict, kwargs.get("state_batches", []), self.model, None
                ),
            )

    @override(Policy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
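        """Adds MCTS policy targets and value labels to the sample batch.

        Every transition in the episode receives the final episode reward as
        its value label; with ranked rewards (r2) enabled, that reward is
        first added to the buffer and normalized.
        """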
        # add mcts policies to sample batch
        sample_batch["mcts_policies"] = np.array(episode.user_data["mcts_policies"])[
            sample_batch["t"]
        ]
        # final episode reward corresponds to the value (if not discounted)
        # for all transitions in episode
        final_reward = sample_batch["rewards"][-1]
        # if r2 is enabled, then add the reward to the buffer and normalize it
        if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
            self.env.r2_buffer.add_reward(final_reward)
            final_reward = self.env.r2_buffer.normalize(final_reward)
        sample_batch["value_label"] = final_reward * np.ones_like(sample_batch["t"])
        return sample_batch

    @override(TorchPolicy)
    def learn_on_batch(self, postprocessed_batch):
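        """Performs one optimizer step on the combined policy/value loss and
        returns the learner stats."""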
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        loss_out, policy_loss, value_loss = self._loss(
            self, self.model, self.dist_class, train_batch
        )
        self._optimizers[0].zero_grad()
        loss_out.backward()

        grad_process_info = self.extra_grad_process(self._optimizers[0], loss_out)
        self._optimizers[0].step()

        grad_info = self.extra_grad_info(train_batch)
        grad_info.update(grad_process_info)
        grad_info.update(
            {
                "total_loss": loss_out.detach().cpu().numpy(),
                "policy_loss": policy_loss.detach().cpu().numpy(),
                "value_loss": value_loss.detach().cpu().numpy(),
            }
        )

        return {LEARNER_STATS_KEY: grad_info}
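
# A minimal wiring sketch for the two creator callables, assuming an MCTS
# implementation exposing `compute_action()` and an env exposing `set_state()`
# (as used above); `MyMCTS`, `MyEnv`, and `mcts_config` are placeholder names:
#
#     policy = AlphaZeroPolicy(
#         obs_space,
#         action_space,
#         config,
#         model=model,
#         loss=loss_fn,
#         action_distribution_class=dist_class,
#         mcts_creator=lambda: MyMCTS(model, mcts_config),
#         env_creator=lambda: MyEnv(),
#     )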