mirror of
https://github.com/vale981/ray
synced 2025-03-09 04:46:38 -04:00
71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
##########
|
|
# Contribution by the Center on Long-Term Risk:
|
|
# https://github.com/longtermrisk/marltoolbox
|
|
##########
|
|
from abc import ABC
|
|
|
|
from ray.rllib.examples.env.utils.interfaces import InfoAccumulationInterface
|
|
|
|
|
|
class TwoPlayersTwoActionsInfoMixin(InfoAccumulationInterface, ABC):
    """
    Mixin class to add logging capability in a two player discrete game.
    Logs the frequency of each state.

    A "state" here is the pair of actions played during one step. Action
    ``0`` is labelled "C" and action ``1`` is labelled "D" (presumably
    cooperate / defect -- confirm against the environments using this
    mixin), which gives the four keys "CC", "DD", "CD" and "DC".
    """

    def _init_info(self):
        """Create the four per-step indicator lists (one per action pair)."""
        self.cc_count = []
        self.dd_count = []
        self.cd_count = []
        self.dc_count = []

    def _reset_info(self):
        """Empty the indicator lists in place.

        Requires ``_init_info`` to have been called beforehand, since the
        lists are cleared rather than re-created.
        """
        self.cc_count.clear()
        self.dd_count.clear()
        self.cd_count.clear()
        self.dc_count.clear()

    def _get_episode_info(self):
        """Return the frequency of each action pair since the last reset.

        Returns:
            A dict with the keys "CC", "DD", "CD" and "DC", each mapped to
            the fraction of accumulated steps in which that pair was
            played. An empty dict is returned when no step has been
            accumulated yet, instead of raising ``ZeroDivisionError``.
        """
        # The four lists grow in lockstep inside `_accumulate_info`, so
        # checking a single one is enough to detect "nothing accumulated".
        if not self.cc_count:
            return {}

        return {
            "CC": sum(self.cc_count) / len(self.cc_count),
            "DD": sum(self.dd_count) / len(self.dd_count),
            "CD": sum(self.cd_count) / len(self.cd_count),
            "DC": sum(self.dc_count) / len(self.dc_count),
        }

    def _accumulate_info(self, ac0, ac1):
        """Record which action pair was played during one step.

        Args:
            ac0: Action of the first player (``0`` or ``1``).
            ac1: Action of the second player (``0`` or ``1``).
        """
        # Each list receives one boolean per step; `sum` over booleans
        # later counts the steps in which the pair occurred.
        self.cc_count.append(ac0 == 0 and ac1 == 0)
        self.cd_count.append(ac0 == 0 and ac1 == 1)
        self.dc_count.append(ac0 == 1 and ac1 == 0)
        self.dd_count.append(ac0 == 1 and ac1 == 1)
|
|
|
|
|
|
class NPlayersNDiscreteActionsInfoMixin(InfoAccumulationInterface, ABC):
    """
    Mixin class to add logging capability in N player games with
    discrete actions.
    Logs the frequency of action profiles used
    (action profile: the set of actions used during one step by all players).

    Counts are kept in ``self.info_counters``: one entry per action profile
    seen so far, plus the reserved key "n_steps_accumulated" holding the
    total number of recorded steps.
    """

    def _init_info(self):
        """Create the counter dict with only the step total, set to zero."""
        self.info_counters = {"n_steps_accumulated": 0}

    def _reset_info(self):
        """Replace the counter dict with a fresh one (step total at zero)."""
        self.info_counters = {"n_steps_accumulated": 0}

    def _get_episode_info(self):
        """Return the frequency of every action profile seen so far.

        Returns:
            A dict mapping each action-profile key to its count divided by
            the number of accumulated steps. The dict is empty when no
            step has been accumulated.
        """
        total_steps = self.info_counters["n_steps_accumulated"]
        if total_steps <= 0:
            return {}

        # The step total itself is bookkeeping, not an action profile.
        return {
            profile: count / total_steps
            for profile, count in self.info_counters.items()
            if profile != "n_steps_accumulated"
        }

    def _accumulate_info(self, *actions):
        """Record the action profile played during one step.

        Args:
            *actions: The action of every player for this step; they are
                stringified and joined with "_" to form the profile key.
        """
        profile_key = "_".join(str(action) for action in actions)
        counters = self.info_counters
        counters[profile_key] = counters.get(profile_key, 0) + 1
        counters["n_steps_accumulated"] += 1
|