""" Mcts implementation modified from https://github.com/brilee/python_uct/blob/master/numpy_impl.py """ import collections import math import numpy as np class Node: def __init__(self, action, obs, done, reward, state, mcts, parent=None): self.env = parent.env self.action = action # Action used to go to this state self.is_expanded = False self.parent = parent self.children = {} self.action_space_size = self.env.action_space.n self.child_total_value = np.zeros( [self.action_space_size], dtype=np.float32 ) # Q self.child_priors = np.zeros([self.action_space_size], dtype=np.float32) # P self.child_number_visits = np.zeros( [self.action_space_size], dtype=np.float32 ) # N self.valid_actions = obs["action_mask"].astype(np.bool) self.reward = reward self.done = done self.state = state self.obs = obs self.mcts = mcts @property def number_visits(self): return self.parent.child_number_visits[self.action] @number_visits.setter def number_visits(self, value): self.parent.child_number_visits[self.action] = value @property def total_value(self): return self.parent.child_total_value[self.action] @total_value.setter def total_value(self, value): self.parent.child_total_value[self.action] = value def child_Q(self): # TODO (weak todo) add "softmax" version of the Q-value return self.child_total_value / (1 + self.child_number_visits) def child_U(self): return ( math.sqrt(self.number_visits) * self.child_priors / (1 + self.child_number_visits) ) def best_action(self): """ :return: action """ child_score = self.child_Q() + self.mcts.c_puct * self.child_U() masked_child_score = child_score masked_child_score[~self.valid_actions] = -np.inf return np.argmax(masked_child_score) def select(self): current_node = self while current_node.is_expanded: best_action = current_node.best_action() current_node = current_node.get_child(best_action) return current_node def expand(self, child_priors): self.is_expanded = True self.child_priors = child_priors def get_child(self, action): if action not in self.children: self.env.set_state(self.state) obs, reward, done, _ = self.env.step(action) next_state = self.env.get_state() self.children[action] = Node( state=next_state, action=action, parent=self, reward=reward, done=done, obs=obs, mcts=self.mcts, ) return self.children[action] def backup(self, value): current = self while current.parent is not None: current.number_visits += 1 current.total_value += value current = current.parent class RootParentNode: def __init__(self, env): self.parent = None self.child_total_value = collections.defaultdict(float) self.child_number_visits = collections.defaultdict(float) self.env = env class MCTS: def __init__(self, model, mcts_param): self.model = model self.temperature = mcts_param["temperature"] self.dir_epsilon = mcts_param["dirichlet_epsilon"] self.dir_noise = mcts_param["dirichlet_noise"] self.num_sims = mcts_param["num_simulations"] self.exploit = mcts_param["argmax_tree_policy"] self.add_dirichlet_noise = mcts_param["add_dirichlet_noise"] self.c_puct = mcts_param["puct_coefficient"] def compute_action(self, node): for _ in range(self.num_sims): leaf = node.select() if leaf.done: value = leaf.reward else: child_priors, value = self.model.compute_priors_and_value(leaf.obs) if self.add_dirichlet_noise: child_priors = (1 - self.dir_epsilon) * child_priors child_priors += self.dir_epsilon * np.random.dirichlet( [self.dir_noise] * child_priors.size ) leaf.expand(child_priors) leaf.backup(value) # Tree policy target (TPT) tree_policy = node.child_number_visits / node.number_visits 
tree_policy = tree_policy / np.max( tree_policy ) # to avoid overflows when computing softmax tree_policy = np.power(tree_policy, self.temperature) tree_policy = tree_policy / np.sum(tree_policy) if self.exploit: # if exploit then choose action that has the maximum # tree policy probability action = np.argmax(tree_policy) else: # otherwise sample an action according to tree policy probabilities action = np.random.choice(np.arange(node.action_space_size), p=tree_policy) return tree_policy, action, node.children[action]
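

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). DummyEnv and
# UniformModel are hypothetical stand-ins that only show the interface the
# classes above expect: an env exposing get_state()/set_state()/step() and
# observations carrying an "action_mask", and a model exposing
# compute_priors_and_value(obs).
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyEnv:
        """Toy chain environment: reach position 4 within 10 steps."""

        class _Space:
            n = 2  # actions: 0 = move left, 1 = move right

        action_space = _Space()

        def __init__(self):
            self.pos = 0
            self.t = 0

        def _obs(self):
            return {
                "obs": np.array([self.pos], dtype=np.float32),
                "action_mask": np.ones(2, dtype=np.float32),
            }

        def reset(self):
            self.pos, self.t = 0, 0
            return self._obs()

        def get_state(self):
            return (self.pos, self.t)

        def set_state(self, state):
            self.pos, self.t = state
            return self._obs()

        def step(self, action):
            self.pos += 1 if action == 1 else -1
            self.t += 1
            done = self.pos == 4 or self.t >= 10
            reward = 1.0 if self.pos == 4 else 0.0
            return self._obs(), reward, done, {}

    class UniformModel:
        """Model stub returning uniform priors and a zero value estimate."""

        def compute_priors_and_value(self, obs):
            n = obs["action_mask"].size
            return np.ones(n, dtype=np.float32) / n, 0.0

    mcts_param = {
        "temperature": 1.0,
        "dirichlet_epsilon": 0.25,
        "dirichlet_noise": 0.03,
        "num_simulations": 50,
        "argmax_tree_policy": True,
        "add_dirichlet_noise": False,
        "puct_coefficient": 1.0,
    }
    env = DummyEnv()
    mcts = MCTS(UniformModel(), mcts_param)
    root = Node(
        action=None,
        obs=env.reset(),
        done=False,
        reward=0.0,
        state=env.get_state(),
        mcts=mcts,
        parent=RootParentNode(env=env),
    )
    tree_policy, action, next_root = mcts.compute_action(root)
    print("tree policy:", tree_policy, "chosen action:", action)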