2020-10-01 16:57:10 +02:00
|
|
|
import gym
|
|
|
|
import numpy as np
|
2020-05-08 08:20:18 +02:00
|
|
|
import random
|
2021-07-06 09:43:47 -07:00
|
|
|
|
|
|
|
from ray.rllib.examples.env.rock_paper_scissors import RockPaperScissors
|
2020-05-08 08:20:18 +02:00
|
|
|
from ray.rllib.policy.policy import Policy
|
2020-10-01 16:57:10 +02:00
|
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
2020-05-08 08:20:18 +02:00
|
|
|
|
|
|
|
|
|
|
|
class AlwaysSameHeuristic(Policy):
    """Pick a random move and stick with it for the entire episode.

    The chosen move is kept as the policy's single recurrent state:
    `get_initial_state` draws it at random, and `compute_actions` echoes that
    state back as the action while passing the state through unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()
        # Feed the previous timestep's "state_out_0" back in as "state_in_0"
        # (shift=-1) so the move drawn at episode start is carried forward.
        # The state is a scalar int32 in the Box range [0, 100].
        self.view_requirements.update({
            "state_in_0": ViewRequirement(
                "state_out_0",
                shift=-1,
                space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
        })

    def get_initial_state(self):
        """Return a one-element state list holding a randomly chosen move."""
        return [
            random.choice([
                RockPaperScissors.ROCK, RockPaperScissors.PAPER,
                RockPaperScissors.SCISSORS
            ])
        ]

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Replay the stored move for every item in the batch.

        Args:
            obs_batch: Ignored; the action depends only on the state.
            state_batches: List of state batches; ``state_batches[0]`` holds
                the move chosen by `get_initial_state` for each batch item.

        Returns:
            Tuple ``(actions, state_outs, extra_fetches)`` where ``actions`` is
            ``state_batches[0]``, ``state_outs`` is ``state_batches`` unchanged
            and ``extra_fetches`` is an empty dict.

        Raises:
            ValueError: If no state batch is supplied. The move lives in the
                state, so there is nothing to act on without it.
        """
        if state_batches is None or len(state_batches) == 0:
            raise ValueError(
                "AlwaysSameHeuristic requires state_batches (the move chosen "
                "in get_initial_state); got none.")
        return state_batches[0], state_batches, {}

    def learn_on_batch(self, samples):
        # Heuristic policy: nothing to learn.
        pass

    def get_weights(self):
        # Heuristic policy: no weights to report.
        pass

    def set_weights(self, weights):
        # Heuristic policy: no weights to load.
        pass
|
2020-05-08 08:20:18 +02:00
|
|
|
|
|
|
|
|
|
|
|
class BeatLastHeuristic(Policy):
    """Play the move that would beat the last move of the opponent."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Counter each observed opponent move.

        Each observation may be either the opponent's last move as a scalar
        index, or a one-hot vector indexed by move (the encoding the original
        code assumed).

        Returns:
            Tuple ``(actions, state_outs, extra_fetches)``: one counter-move
            per observation, an empty state list and an empty dict.

        Raises:
            ValueError: If an observation does not encode a valid move.
        """
        # Maps each move to the move that beats it.
        beats = {
            RockPaperScissors.ROCK: RockPaperScissors.PAPER,
            RockPaperScissors.PAPER: RockPaperScissors.SCISSORS,
            RockPaperScissors.SCISSORS: RockPaperScissors.ROCK,
        }
        moves = (RockPaperScissors.ROCK, RockPaperScissors.PAPER,
                 RockPaperScissors.SCISSORS)

        def successor(x):
            if np.ndim(x) == 0:
                # Raw (non-preprocessed) observation: the move index itself.
                last_move = int(x)
            else:
                # One-hot observation: first move whose slot is set.
                last_move = next((m for m in moves if x[m] == 1), None)
            if last_move not in beats:
                # Previously this fell through and returned None as an action.
                raise ValueError(
                    "Cannot infer the opponent's last move from observation "
                    f"{x!r}.")
            return beats[last_move]

        return [successor(x) for x in obs_batch], [], {}

    def learn_on_batch(self, samples):
        # Heuristic policy: nothing to learn.
        pass

    def get_weights(self):
        # Heuristic policy: no weights to report.
        pass

    def set_weights(self, weights):
        # Heuristic policy: no weights to load.
        pass
|