ray/rllib/algorithms/alpha_zero/alpha_zero_policy.py


import numpy as np

from ray.rllib.algorithms.alpha_zero.mcts import Node, RootParentNode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

torch, _ = try_import_torch()


class AlphaZeroPolicy(TorchPolicy):
    def __init__(
        self,
        observation_space,
        action_space,
        config,
        model,
        loss,
        action_distribution_class,
        mcts_creator,
        env_creator,
        **kwargs
    ):
        super().__init__(
            observation_space,
            action_space,
            config,
            model=model,
            loss=loss,
            action_distribution_class=action_distribution_class,
        )
        # we maintain an env copy in the policy that is used during mcts
        # simulations
        self.env_creator = env_creator
        self.mcts = mcts_creator()
        self.env = self.env_creator()
        self.env.reset()
        self.obs_space = observation_space
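
    # A minimal sketch (not executed) of how this constructor is expected to be
    # wired up, assuming `mcts_creator` and `env_creator` are zero-argument
    # callables (e.g. partials the algorithm builds from its config). The
    # concrete names below (`MyBoardGameEnv`, `mcts_config`, `env_config`) are
    # hypothetical placeholders, not part of this module:
    #
    #     mcts_creator = lambda: MCTS(model, mcts_config)   # assumed MCTS ctor
    #     env_creator = lambda: MyBoardGameEnv(env_config)
    #     policy = AlphaZeroPolicy(
    #         obs_space, act_space, config, model, loss_fn,
    #         action_dist_class, mcts_creator, env_creator,
    #     )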

    @override(TorchPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        input_dict = {"obs": obs_batch}
        if prev_action_batch is not None:
            input_dict["prev_actions"] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict["prev_rewards"] = prev_reward_batch

        return self.compute_actions_from_input_dict(
            input_dict=input_dict,
            episodes=episodes,
            state_batches=state_batches,
        )
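
    # Note (illustrative, not executed): `compute_actions` only repackages the
    # batched arguments into an input dict and delegates to
    # `compute_actions_from_input_dict`, where the MCTS search actually runs.
    # A hypothetical single-episode call might look like:
    #
    #     actions, _, extra_out = policy.compute_actions(
    #         obs_batch=np.stack([obs]), episodes=[episode]
    #     )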

    @override(Policy)
    def compute_actions_from_input_dict(
        self, input_dict, explore=None, timestep=None, episodes=None, **kwargs
    ):
        with torch.no_grad():
            actions = []
            for i, episode in enumerate(episodes):
                if episode.length == 0:
                    # if first time step of episode, get initial env state
                    env_state = episode.user_data["initial_state"]
                    # verify if env has been wrapped for ranked rewards
                    if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
                        # r2 env state contains also the rewards buffer state
                        env_state = {"env_state": env_state, "buffer_state": None}
                    # create tree root node
                    obs = self.env.set_state(env_state)
                    tree_node = Node(
                        state=env_state,
                        obs=obs,
                        reward=0,
                        done=False,
                        action=None,
                        parent=RootParentNode(env=self.env),
                        mcts=self.mcts,
                    )
                else:
                    # otherwise get last root node from previous time step
                    tree_node = episode.user_data["tree_node"]

                # run monte carlo simulations to compute the actions
                # and record the tree
                mcts_policy, action, tree_node = self.mcts.compute_action(tree_node)
                # record action
                actions.append(action)
                # store new node
                episode.user_data["tree_node"] = tree_node

                # store mcts policies vectors and current tree root node
                if episode.length == 0:
                    episode.user_data["mcts_policies"] = [mcts_policy]
                else:
                    episode.user_data["mcts_policies"].append(mcts_policy)

            return (
                np.array(actions),
                [],
                self.extra_action_out(
                    input_dict, kwargs.get("state_batches", []), self.model, None
                ),
            )
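
    # A minimal sketch of the MCTS contract the loop above relies on:
    # `self.mcts.compute_action(node)` is assumed to return a tuple
    # (tree_policy, action, next_root), where `tree_policy` is a probability
    # vector derived from visit counts and `next_root` is the child node kept
    # for tree reuse at the next timestep. A stand-in honoring that contract,
    # with hypothetical attribute names, could look like:
    #
    #     class UniformMCTS:
    #         def compute_action(self, node):
    #             n = node.action_space_size                  # assumed attribute
    #             tree_policy = np.full(n, 1.0 / n)
    #             action = int(np.random.choice(n, p=tree_policy))
    #             return tree_policy, action, node.get_child(action)  # assumed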

    @override(Policy)
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # add mcts policies to sample batch
        sample_batch["mcts_policies"] = np.array(episode.user_data["mcts_policies"])[
            sample_batch["t"]
        ]
        # final episode reward corresponds to the value (if not discounted)
        # for all transitions in episode
        final_reward = sample_batch["rewards"][-1]
        # if r2 is enabled, then add the reward to the buffer and normalize it
        if self.env.__class__.__name__ == "RankedRewardsEnvWrapper":
            self.env.r2_buffer.add_reward(final_reward)
            final_reward = self.env.r2_buffer.normalize(final_reward)
        sample_batch["value_label"] = final_reward * np.ones_like(sample_batch["t"])
        return sample_batch
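
    # Worked sketch (not executed): with an episode of length 3 and
    # sample_batch["t"] == array([0, 1, 2]), indexing the stacked MCTS policies
    # by "t" aligns one policy vector per transition, and the scalar final
    # reward is broadcast as the value target for every step:
    #
    #     mcts_policies = np.array([[0.7, 0.3], [0.4, 0.6], [0.9, 0.1]])
    #     t = np.array([0, 1, 2])
    #     batch_policies = mcts_policies[t]            # shape (3, 2)
    #     value_label = 1.0 * np.ones_like(t)          # -> array([1., 1., 1.])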

    @override(TorchPolicy)
    def learn_on_batch(self, postprocessed_batch):
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        loss_out, policy_loss, value_loss = self._loss(
            self, self.model, self.dist_class, train_batch
        )
        self._optimizers[0].zero_grad()
        loss_out.backward()

        grad_process_info = self.extra_grad_process(self._optimizers[0], loss_out)
        self._optimizers[0].step()

        grad_info = self.extra_grad_info(train_batch)
        grad_info.update(grad_process_info)
        grad_info.update(
            {
                "total_loss": loss_out.detach().cpu().numpy(),
                "policy_loss": policy_loss.detach().cpu().numpy(),
                "value_loss": value_loss.detach().cpu().numpy(),
            }
        )

        return {LEARNER_STATS_KEY: grad_info}
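

# A minimal sketch of the loss contract assumed by `learn_on_batch`: the `loss`
# callable passed to the constructor (stored by TorchPolicy as `self._loss`) is
# expected to have the signature loss(policy, model, dist_class, train_batch)
# and to return a (total_loss, policy_loss, value_loss) tuple of scalar
# tensors. A hypothetical stand-in under those assumptions, not the actual
# AlphaZero loss shipped with RLlib:
#
#     def sketch_alpha_zero_loss(policy, model, dist_class, train_batch):
#         logits, _ = model(train_batch)
#         values = model.value_function()
#         log_probs = torch.log_softmax(logits, dim=-1)
#         policy_loss = -torch.mean(
#             torch.sum(train_batch["mcts_policies"] * log_probs, dim=-1)
#         )
#         value_loss = torch.mean((values - train_batch["value_label"]) ** 2)
#         return policy_loss + value_loss, policy_loss, value_loss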