# Origin: ray/rllib/algorithms/alpha_zero/ranked_rewards.py
# (listed as 78 lines, 2.6 KiB, Python)

from copy import deepcopy
import numpy as np
class RankedRewardsBuffer:
    """Sliding window of rewards used to binarize new rewards.

    Rewards are stored in insertion order in ``self.buffer`` (a plain list).
    Once the list has reached ``buffer_max_length`` entries, every new reward
    evicts the oldest one. ``normalize`` compares a reward against the
    configured percentile of the stored rewards and maps it to ``-1.0`` or
    ``1.0``.
    """

    def __init__(self, buffer_max_length, percentile):
        """Create an empty buffer.

        Args:
            buffer_max_length: Number of rewards after which the oldest entry
                is dropped on each insert.
            percentile: Percentile of the stored rewards used as threshold;
                forwarded unchanged as the ``q`` argument of
                ``np.percentile``.
        """
        self.buffer_max_length = buffer_max_length
        self.percentile = percentile
        self.buffer = []

    def add_reward(self, reward):
        """Append ``reward``, dropping the oldest entry if the buffer is full."""
        if len(self.buffer) >= self.buffer_max_length:
            # Trim in place rather than rebuilding the list on every insert.
            # Slice deletion is a no-op on an empty list, so this is safe even
            # for a non-positive ``buffer_max_length``.
            del self.buffer[:1]
        self.buffer.append(reward)

    def normalize(self, reward):
        """Map ``reward`` to ``-1.0`` or ``1.0`` relative to the buffer.

        Args:
            reward: Scalar reward to rank against the stored rewards.

        Returns:
            ``-1.0`` if ``reward`` is strictly below the configured percentile
            of the stored rewards, otherwise ``1.0``. With an empty buffer
            there is no threshold to fall below, so ``1.0`` is returned.
        """
        if not self.buffer:
            # A percentile of an empty sequence is undefined; handle the case
            # explicitly instead of handing an empty list to np.percentile.
            return 1.0
        reward_threshold = np.percentile(self.buffer, self.percentile)
        return -1.0 if reward < reward_threshold else 1.0

    def get_state(self):
        """Return the stored rewards as a new NumPy array."""
        return np.array(self.buffer)

    def set_state(self, state):
        """Replace the stored rewards with a copy of ``state``.

        Args:
            state: Iterable of rewards (e.g. the array returned by
                ``get_state``). ``None`` leaves the buffer untouched.
        """
        if state is not None:
            self.buffer = list(state)
def get_r2_env_wrapper(env_creator, r2_config):
    """Build an env wrapper class that ranks terminal rewards.

    Args:
        env_creator: Callable that takes an env config and returns the
            environment to wrap.
        r2_config: Mapping read for the keys ``"buffer_max_length"``,
            ``"percentile"``, ``"initialize_buffer"`` and (only when
            ``"initialize_buffer"`` is truthy) ``"num_init_rewards"``.

    Returns:
        The ``RankedRewardsEnvWrapper`` class itself (not an instance),
        closed over ``env_creator`` and ``r2_config``.
    """

    class RankedRewardsEnvWrapper:
        """Wraps an env and replaces its terminal reward with a ranked one.

        The wrapped env must provide ``reset``, ``step`` (returning a
        4-tuple ``obs, reward, done, info``), ``get_state`` and
        ``set_state``, plus ``action_space`` / ``observation_space``.
        """

        def __init__(self, env_config):
            """Create the wrapped env and its rewards buffer."""
            wrapped = env_creator(env_config)
            self.env = wrapped
            self.action_space = wrapped.action_space
            self.observation_space = wrapped.observation_space
            self.r2_buffer = RankedRewardsBuffer(
                r2_config["buffer_max_length"],
                r2_config["percentile"],
            )
            if r2_config["initialize_buffer"]:
                self._initialize_buffer(r2_config["num_init_rewards"])

        def _initialize_buffer(self, num_init_rewards=100):
            """Fill the buffer by playing episodes with a random policy.

            Runs ``num_init_rewards`` episodes; in each one, actions are
            sampled in proportion to ``obs["action_mask"]`` until the env
            reports ``done``, and the reward of that last step is stored.
            The mask is used with ``.sum()`` and ``.shape[0]``, so it is
            presumably a 1-D NumPy array -- confirm against the wrapped env.
            """
            for _ in range(num_init_rewards):
                observation = self.env.reset()
                # The episode always takes at least one step.
                while True:
                    action_mask = observation["action_mask"]
                    action_probs = action_mask / action_mask.sum()
                    candidates = np.arange(action_mask.shape[0])
                    chosen = np.random.choice(candidates, p=action_probs)
                    observation, last_reward, finished, _ = self.env.step(chosen)
                    if finished:
                        break
                # One reward per episode: the reward of the final step.
                self.r2_buffer.add_reward(last_reward)

        def reset(self):
            """Reset the wrapped env and return its observation."""
            first_obs = self.env.reset()
            return first_obs

        def step(self, action):
            """Step the wrapped env; rank the reward if the episode ended.

            Non-terminal rewards are passed through unchanged.
            """
            obs, raw_reward, done, info = self.env.step(action)
            reward = self.r2_buffer.normalize(raw_reward) if done else raw_reward
            return obs, reward, done, info

        def get_state(self):
            """Return a deep copy of the env state and the buffer state."""
            return deepcopy(
                {
                    "env_state": self.env.get_state(),
                    "buffer_state": self.r2_buffer.get_state(),
                }
            )

        def set_state(self, state):
            """Restore env and buffer from a ``get_state`` dict.

            Returns:
                Whatever the wrapped env's ``set_state`` returns.
            """
            restored_obs = self.env.set_state(state["env_state"])
            self.r2_buffer.set_state(state["buffer_state"])
            return restored_obs

    return RankedRewardsEnvWrapper