2020-10-01 16:57:10 +02:00
|
|
|
import gym
|
|
|
|
import numpy as np
|
2020-05-08 08:20:18 +02:00
|
|
|
import random
|
|
|
|
from ray.rllib.policy.policy import Policy
|
2020-10-01 16:57:10 +02:00
|
|
|
from ray.rllib.policy.view_requirement import ViewRequirement
|
2020-05-08 08:20:18 +02:00
|
|
|
|
2021-07-22 07:55:07 -07:00
|
|
|
# Integer encodings of the three Rock-Paper-Scissors moves. These serve as
# both action values and (Discrete) observation values for the policies below.
ROCK = 0
PAPER = 1
SCISSORS = 2
|
|
|
|
|
2020-05-08 08:20:18 +02:00
|
|
|
|
|
|
|
class AlwaysSameHeuristic(Policy):
    """Pick a random move and stick with it for the entire episode.

    The chosen move is carried through the episode as recurrent state:
    `get_initial_state` draws it once, and `compute_actions` simply echoes
    the stored state back as both the action and the next state.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()
        # Feed the previous step's "state_out_0" back in as "state_in_0"
        # (shift=-1), so the initially sampled move persists across steps.
        persistent_move = ViewRequirement(
            "state_out_0",
            shift=-1,
            space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32),
        )
        self.view_requirements.update({"state_in_0": persistent_move})

    def get_initial_state(self):
        """Draw the move to be repeated for the whole episode."""
        return [random.choice([ROCK, PAPER, SCISSORS])]

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        """Repeat the stored move: the action batch IS the state batch."""
        actions = state_batches[0]
        return actions, state_batches, {}
|
2020-05-08 08:20:18 +02:00
|
|
|
|
|
|
|
|
|
|
|
class BeatLastHeuristic(Policy):
    """Play the move that would beat the last move of the opponent."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        """Return, for each observation (the opponent's last move), the
        move that beats it.

        Observations may be raw Discrete ints or one-hot vectors (the
        auto-preprocessed form); both are handled. Unknown / initial
        observations yield a uniformly random move.
        """

        def successor(x):
            # Make this also work w/o one-hot preprocessing.
            if isinstance(self.observation_space, gym.spaces.Discrete):
                if x == ROCK:
                    return PAPER
                elif x == PAPER:
                    return SCISSORS
                elif x == SCISSORS:
                    return ROCK
                else:
                    # Not a valid move (e.g. initial dummy obs) -> random.
                    return random.choice([ROCK, PAPER, SCISSORS])
            # One-hot (auto-preprocessed) inputs.
            else:
                if x[ROCK] == 1:
                    return PAPER
                elif x[PAPER] == 1:
                    return SCISSORS
                elif x[SCISSORS] == 1:
                    return ROCK
                # No move bit set (e.g. the initial dummy observation).
                # Bugfix: previously this was only handled when x[-1] == 1;
                # any other unmatched vector fell through and implicitly
                # returned None, producing an invalid action downstream.
                return random.choice([ROCK, PAPER, SCISSORS])

        return [successor(x) for x in obs_batch], [], {}

    def learn_on_batch(self, samples):
        """No-op: heuristic policies do not learn."""
        pass

    def get_weights(self):
        """No-op: heuristic policies carry no weights."""
        pass

    def set_weights(self, weights):
        """No-op: heuristic policies carry no weights."""
        pass
|