########## # Contribution by the Center on Long-Term Risk: # https://github.com/longtermrisk/marltoolbox ########## from abc import ABC from ray.rllib.examples.env.utils.interfaces import InfoAccumulationInterface class TwoPlayersTwoActionsInfoMixin(InfoAccumulationInterface, ABC): """ Mixin class to add logging capability in a two player discrete game. Logs the frequency of each state. """ def _init_info(self): self.cc_count = [] self.dd_count = [] self.cd_count = [] self.dc_count = [] def _reset_info(self): self.cc_count.clear() self.dd_count.clear() self.cd_count.clear() self.dc_count.clear() def _get_episode_info(self): return { "CC": sum(self.cc_count) / len(self.cc_count), "DD": sum(self.dd_count) / len(self.dd_count), "CD": sum(self.cd_count) / len(self.cd_count), "DC": sum(self.dc_count) / len(self.dc_count), } def _accumulate_info(self, ac0, ac1): self.cc_count.append(ac0 == 0 and ac1 == 0) self.cd_count.append(ac0 == 0 and ac1 == 1) self.dc_count.append(ac0 == 1 and ac1 == 0) self.dd_count.append(ac0 == 1 and ac1 == 1) class NPlayersNDiscreteActionsInfoMixin(InfoAccumulationInterface, ABC): """ Mixin class to add logging capability in N player games with discrete actions. Logs the frequency of action profiles used (action profile: the set of actions used during one step by all players). """ def _init_info(self): self.info_counters = {"n_steps_accumulated": 0} def _reset_info(self): self.info_counters = {"n_steps_accumulated": 0} def _get_episode_info(self): info = {} if self.info_counters["n_steps_accumulated"] > 0: for k, v in self.info_counters.items(): if k != "n_steps_accumulated": info[k] = v / self.info_counters["n_steps_accumulated"] return info def _accumulate_info(self, *actions): id = "_".join([str(a) for a in actions]) if id not in self.info_counters: self.info_counters[id] = 0 self.info_counters[id] += 1 self.info_counters["n_steps_accumulated"] += 1