"""MCTS implementation modified from
https://github.com/brilee/python_uct/blob/master/numpy_impl.py
"""
import collections
import math
import numpy as np
class Node:
    """A node of the search tree.

    Each node owns the per-action statistics of its *children*
    (``child_total_value``, ``child_priors``, ``child_number_visits``); its own
    visit count and total value are therefore read from and written to the
    parent's arrays at index ``self.action``.
    """

    def __init__(self, action, obs, done, reward, state, mcts, parent=None):
        """Create a node reached from ``parent`` by taking ``action``.

        :param action: action used to go to this state (index into the
            parent's per-action statistics).
        :param obs: observation for this state; must contain an
            ``"action_mask"`` array marking the valid actions.
        :param done: whether this state is terminal.
        :param reward: reward received when entering this state.
        :param state: environment state snapshot for this node.
        :param mcts: object providing ``c_puct`` (read in ``best_action``).
        :param parent: parent node. Despite the ``None`` default it is
            required, because the environment is taken from ``parent.env``.
        """
        self.env = parent.env
        self.action = action  # Action used to go to this state

        self.is_expanded = False
        self.parent = parent
        self.children = {}

        self.action_space_size = self.env.action_space.n
        self.child_total_value = np.zeros(
            [self.action_space_size], dtype=np.float32
        )  # Q
        self.child_priors = np.zeros([self.action_space_size], dtype=np.float32)  # P
        self.child_number_visits = np.zeros(
            [self.action_space_size], dtype=np.float32
        )  # N
        # The ``np.bool`` alias was removed from NumPy; the builtin ``bool``
        # yields the same boolean dtype on every NumPy version.
        self.valid_actions = obs["action_mask"].astype(bool)

        self.reward = reward
        self.done = done
        self.state = state
        self.obs = obs

        self.mcts = mcts

    @property
    def number_visits(self):
        """Visit count of this node, stored in the parent's array."""
        return self.parent.child_number_visits[self.action]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.action] = value

    @property
    def total_value(self):
        """Accumulated value of this node, stored in the parent's array."""
        return self.parent.child_total_value[self.action]

    @total_value.setter
    def total_value(self, value):
        self.parent.child_total_value[self.action] = value

    def child_Q(self):
        """Return the per-action mean value; the ``1 +`` avoids dividing by 0."""
        # TODO (weak todo) add "softmax" version of the Q-value
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        """Return the per-action exploration term.

        It is the prior scaled by the square root of this node's visit count
        and divided by ``1 +`` the child's visit count.
        """
        return (
            math.sqrt(self.number_visits)
            * self.child_priors
            / (1 + self.child_number_visits)
        )

    def best_action(self):
        """Pick the valid action maximizing ``Q + c_puct * U``.

        :return: action
        """
        child_score = self.child_Q() + self.mcts.c_puct * self.child_U()
        # ``child_score`` is a freshly computed array, so masking it in place
        # does not touch the stored statistics.
        masked_child_score = child_score
        masked_child_score[~self.valid_actions] = -np.inf
        return np.argmax(masked_child_score)

    def select(self):
        """Descend through expanded nodes and return the first unexpanded one."""
        current_node = self
        while current_node.is_expanded:
            best_action = current_node.best_action()
            current_node = current_node.get_child(best_action)
        return current_node

    def expand(self, child_priors):
        """Mark the node as expanded and store the priors of its children."""
        self.is_expanded = True
        self.child_priors = child_priors

    def get_child(self, action):
        """Return the child reached by ``action``, creating it on first use."""
        if action not in self.children:
            # The environment is shared by every node of the tree, so rewind
            # it to this node's state before stepping.
            # NOTE(review): assumes the env exposes ``set_state`` as the
            # counterpart of ``get_state`` - confirm against the env class.
            self.env.set_state(self.state)
            obs, reward, done, _ = self.env.step(action)
            next_state = self.env.get_state()
            self.children[action] = Node(
                state=next_state,
                action=action,
                parent=self,
                reward=reward,
                done=done,
                obs=obs,
                mcts=self.mcts,
            )
        return self.children[action]

    def backup(self, value):
        """Add one visit and ``value`` to this node and each of its ancestors.

        The walk stops at the first object whose ``parent`` is ``None``; that
        object itself is not updated.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += value
            current = current.parent
# Sentinel parent used above the root of the search tree.
class RootParentNode:
    """Placeholder parent for the root of the search tree.

    It offers the attributes a node reads from its parent: ``env``, plus
    ``child_total_value`` and ``child_number_visits`` mappings that default
    to ``0.0`` for any key. Its own ``parent`` is ``None``.
    """

    def __init__(self, env):
        """Store ``env`` and create empty, zero-defaulting statistics.

        :param env: environment object handed down to the nodes of the tree.
        """
        self.parent = None
        self.child_total_value = collections.defaultdict(float)
        self.child_number_visits = collections.defaultdict(float)
        self.env = env
class MCTS:
    """Monte Carlo tree search driven by a model that supplies priors and values."""

    def __init__(self, model, mcts_param):
        """Store the model and read the search settings from ``mcts_param``.

        :param model: object exposing ``compute_priors_and_value(obs)``, which
            returns a ``(child_priors, value)`` pair.
        :param mcts_param: mapping with the keys ``"temperature"``,
            ``"dirichlet_epsilon"``, ``"dirichlet_noise"``,
            ``"num_simulations"``, ``"argmax_tree_policy"``,
            ``"add_dirichlet_noise"`` and ``"puct_coefficient"``.
        """
        self.model = model
        self.temperature = mcts_param["temperature"]
        self.dir_epsilon = mcts_param["dirichlet_epsilon"]
        self.dir_noise = mcts_param["dirichlet_noise"]
        self.num_sims = mcts_param["num_simulations"]
        self.exploit = mcts_param["argmax_tree_policy"]
        self.add_dirichlet_noise = mcts_param["add_dirichlet_noise"]
        self.c_puct = mcts_param["puct_coefficient"]

    def compute_action(self, node):
        """Run ``num_sims`` simulations from ``node`` and choose an action.

        :param node: node to search from; it must provide ``select()``,
            ``child_number_visits``, ``number_visits``, ``action_space_size``
            and ``children``.
        :return: tuple ``(tree_policy, action, node.children[action])``.
        """
        for _ in range(self.num_sims):
            leaf = node.select()
            if leaf.done:
                # Terminal leaf: its reward is the value, nothing to expand.
                value = leaf.reward
            else:
                child_priors, value = self.model.compute_priors_and_value(leaf.obs)
                if self.add_dirichlet_noise:
                    # The multiplication allocates a new array, so the in-place
                    # addition below does not modify the model's output.
                    child_priors = (1 - self.dir_epsilon) * child_priors
                    child_priors += self.dir_epsilon * np.random.dirichlet(
                        [self.dir_noise] * child_priors.size
                    )

                leaf.expand(child_priors)
            leaf.backup(value)

        # Tree policy target (TPT)
        tree_policy = node.child_number_visits / node.number_visits
        tree_policy = tree_policy / np.max(
            tree_policy
        )  # to avoid overflows when computing softmax
        tree_policy = np.power(tree_policy, self.temperature)
        tree_policy = tree_policy / np.sum(tree_policy)
        if self.exploit:
            # if exploit then choose action that has the maximum
            # tree policy probability
            action = np.argmax(tree_policy)
        else:
            # otherwise sample an action according to tree policy probabilities
            action = np.random.choice(np.arange(node.action_space_size), p=tree_policy)
        return tree_policy, action, node.children[action]