ray/rllib/examples/policy/rock_paper_scissors_dummies.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

92 lines
2.5 KiB
Python

import gym
import numpy as np
import random
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.view_requirement import ViewRequirement
# Integer codes for the three moves. They are returned as actions, compared
# against raw (Discrete) observations, and used as indices into one-hot
# observation vectors.
ROCK, PAPER, SCISSORS = 0, 1, 2
class AlwaysSameHeuristic(Policy):
    """Pick a random move and stick with it for the entire episode.

    The chosen move is kept in the policy's recurrent state: it is sampled once
    by `get_initial_state()` and then returned unchanged, both as the action
    and as the next state, by `compute_actions()`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()
        # Request that the previous step's "state_out_0" be fed back in as
        # "state_in_0" (shift=-1), so the move chosen at episode start is
        # carried forward from step to step.
        # NOTE(review): the declared bounds (0..100) are wider than the values
        # actually stored here (ROCK..SCISSORS); kept as-is.
        self.view_requirements.update(
            {
                "state_in_0": ViewRequirement(
                    "state_out_0",
                    shift=-1,
                    space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32),
                )
            }
        )

    def get_initial_state(self):
        """Return a one-element state list holding a randomly chosen move."""
        return [random.choice([ROCK, PAPER, SCISSORS])]

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        """Play the stored move and pass the state through unchanged.

        Observations are ignored; `state_batches` must be provided (it carries
        the move picked by `get_initial_state()`).

        Returns:
            Tuple of (actions, state_outs, extra_fetches), where the actions
            are the first state batch itself.
        """
        return state_batches[0], state_batches, {}

    def learn_on_batch(self, samples):
        """No-op: this heuristic policy has nothing to learn."""

    def get_weights(self):
        """No-op: this heuristic policy has no weights; returns None."""
        return None

    def set_weights(self, weights):
        """No-op: this heuristic policy has no weights to set."""
class BeatLastHeuristic(Policy):
    """Play the move that would beat the last move of the opponent."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        """Return, per observation, the move that beats the opponent's last.

        Each observation holds the opponent's last move in one of two forms:
        a plain integer (Discrete observation space, no preprocessing) or a
        one-hot vector (auto-preprocessed). ROCK is answered with PAPER, PAPER
        with SCISSORS and SCISSORS with ROCK. Any other observation (for
        instance the trailing one-hot slot, an all-zero vector, or an
        out-of-range integer) gets a uniformly random move.

        Returns:
            Tuple of (actions, state_outs, extra_fetches); the policy is
            stateless, so state_outs is empty.
        """
        # The observation encoding is fixed for this policy; decide it once
        # instead of re-checking for every observation.
        raw_discrete = isinstance(self.observation_space, gym.spaces.Discrete)

        def successor(x):
            # Plain integer inputs (works w/o one-hot preprocessing).
            if raw_discrete:
                if x == ROCK:
                    return PAPER
                elif x == PAPER:
                    return SCISSORS
                elif x == SCISSORS:
                    return ROCK
            # One-hot (auto-preprocessed) inputs.
            else:
                if x[ROCK] == 1:
                    return PAPER
                elif x[PAPER] == 1:
                    return SCISSORS
                elif x[SCISSORS] == 1:
                    return ROCK
            # Shared fallback for both encodings. Previously the one-hot branch
            # only handled x[-1] == 1 and silently returned None (an invalid
            # action) for any other unmatched vector, e.g. all zeros.
            return random.choice([ROCK, PAPER, SCISSORS])

        return [successor(x) for x in obs_batch], [], {}

    def learn_on_batch(self, samples):
        """No-op: this heuristic policy has nothing to learn."""
        pass

    def get_weights(self):
        """No-op: this heuristic policy has no weights; returns None."""
        pass

    def set_weights(self, weights):
        """No-op: this heuristic policy has no weights to set."""
        pass