# ray/rllib/examples/env/windy_maze_env.py

import gym
from gym.spaces import Box, Discrete, Tuple
import logging
import random

from ray.rllib.env import MultiAgentEnv

logger = logging.getLogger(__name__)

# Agent has to traverse the maze from the starting position S -> F.
# Observation space: [x_pos, y_pos, wind_direction]
# Action space: stay still OR move in the current wind direction.
MAP_DATA = """
#########
#S      #
####### #
#       #
#       #
####### #
#F      #
#########"""
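
# Map legend: "#" = wall, " " = free cell, "S" = start, "F" = goal.
# Positions are (row, col) tuples: x indexes rows (top to bottom) and
# y indexes columns (left to right), matching self.map[x][y] below.
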
class WindyMazeEnv(gym.Env):
    def __init__(self, env_config):
        self.map = [m for m in MAP_DATA.split("\n") if m]
        self.x_dim = len(self.map)
        self.y_dim = len(self.map[0])
        logger.info("Loaded map {} {}".format(self.x_dim, self.y_dim))
        for x in range(self.x_dim):
            for y in range(self.y_dim):
                if self.map[x][y] == "S":
                    self.start_pos = (x, y)
                elif self.map[x][y] == "F":
                    self.end_pos = (x, y)
        logger.info("Start pos {} end pos {}".format(self.start_pos, self.end_pos))
        self.observation_space = Tuple(
            [
                Box(0, 100, shape=(2,)),  # (x, y)
                Discrete(4),  # wind direction (N, E, S, W)
            ]
        )
        self.action_space = Discrete(2)  # whether to move or not

    def reset(self):
        self.wind_direction = random.choice([0, 1, 2, 3])
        self.pos = self.start_pos
        self.num_steps = 0
        return [[self.pos[0], self.pos[1]], self.wind_direction]

    def step(self, action):
        if action == 1:
            self.pos = self._get_new_pos(self.pos, self.wind_direction)
        self.num_steps += 1
        self.wind_direction = random.choice([0, 1, 2, 3])
        at_goal = self.pos == self.end_pos
        done = at_goal or self.num_steps >= 200
        return (
            [[self.pos[0], self.pos[1]], self.wind_direction],
            100 * int(at_goal),
            done,
            {},
        )

    def _get_new_pos(self, pos, direction):
        if direction == 0:
            new_pos = (pos[0] - 1, pos[1])
        elif direction == 1:
            new_pos = (pos[0], pos[1] + 1)
        elif direction == 2:
            new_pos = (pos[0] + 1, pos[1])
        elif direction == 3:
            new_pos = (pos[0], pos[1] - 1)
        if (
            new_pos[0] >= 0
            and new_pos[0] < self.x_dim
            and new_pos[1] >= 0
            and new_pos[1] < self.y_dim
            and self.map[new_pos[0]][new_pos[1]] != "#"
        ):
            return new_pos
        else:
            return pos  # did not move
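
# HierarchicalWindyMazeEnv wraps the flat env in a two-level multi-agent
# scheme: a single "high_level_agent" picks a goal direction (0-3), then a
# fresh low-level agent gets up to 25 flat-env steps and earns +1/-1 for
# moving onto/off the cell implied by that goal direction. The final flat
# reward is passed back to the high-level agent when the episode ends.
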
class HierarchicalWindyMazeEnv(MultiAgentEnv):
    def __init__(self, env_config):
        super().__init__()
        self._skip_env_checking = True
        self.flat_env = WindyMazeEnv(env_config)

    def reset(self):
        self.cur_obs = self.flat_env.reset()
        self.current_goal = None
        self.steps_remaining_at_level = None
        self.num_high_level_steps = 0
        # current low level agent id. This must be unique for each high level
        # step since agent ids cannot be reused.
        self.low_level_agent_id = "low_level_{}".format(self.num_high_level_steps)
        return {
            "high_level_agent": self.cur_obs,
        }

    def step(self, action_dict):
        assert len(action_dict) == 1, action_dict
        if "high_level_agent" in action_dict:
            return self._high_level_step(action_dict["high_level_agent"])
        else:
            return self._low_level_step(list(action_dict.values())[0])

    def _high_level_step(self, action):
        logger.debug("High level agent sets goal")
        self.current_goal = action
        self.steps_remaining_at_level = 25
        self.num_high_level_steps += 1
        self.low_level_agent_id = "low_level_{}".format(self.num_high_level_steps)
        obs = {self.low_level_agent_id: [self.cur_obs, self.current_goal]}
        rew = {self.low_level_agent_id: 0}
        done = {"__all__": False}
        return obs, rew, done, {}

    def _low_level_step(self, action):
        logger.debug("Low level agent step {}".format(action))
        self.steps_remaining_at_level -= 1
        cur_pos = tuple(self.cur_obs[0])
        goal_pos = self.flat_env._get_new_pos(cur_pos, self.current_goal)

        # Step in the actual env
        f_obs, f_rew, f_done, _ = self.flat_env.step(action)
        new_pos = tuple(f_obs[0])
        self.cur_obs = f_obs

        # Calculate low-level agent observation and reward
        obs = {self.low_level_agent_id: [f_obs, self.current_goal]}
        if new_pos != cur_pos:
            if new_pos == goal_pos:
                rew = {self.low_level_agent_id: 1}
            else:
                rew = {self.low_level_agent_id: -1}
        else:
            rew = {self.low_level_agent_id: 0}

        # Handle env termination & transitions back to higher level
        done = {"__all__": False}
        if f_done:
            done["__all__"] = True
            logger.debug("high level final reward {}".format(f_rew))
            rew["high_level_agent"] = f_rew
            obs["high_level_agent"] = f_obs
        elif self.steps_remaining_at_level == 0:
            done[self.low_level_agent_id] = True
            rew["high_level_agent"] = 0
            obs["high_level_agent"] = f_obs
        return obs, rew, done, {}
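
# A minimal smoke-test sketch, not part of the original example: roll out
# the flat env with random actions, then drive the hierarchical env by
# hand. It assumes the old gym API used above (reset() returns only obs,
# step() returns a 4-tuple) and passes env_config=None, which works here
# because neither constructor reads env_config.
if __name__ == "__main__":
    env = WindyMazeEnv(None)
    obs = env.reset()
    done = False
    episode_return = 0.0
    while not done:
        obs, rew, done, _ = env.step(env.action_space.sample())
        episode_return += rew
    print("Flat env return:", episode_return)

    hier_env = HierarchicalWindyMazeEnv(None)
    obs = hier_env.reset()
    done = {"__all__": False}
    while not done["__all__"]:
        # Exactly one agent needs to act per step: the high-level agent
        # picks a goal direction in Discrete(4); a low-level agent picks
        # a stay/move action in Discrete(2).
        if "high_level_agent" in obs:
            action_dict = {"high_level_agent": random.choice([0, 1, 2, 3])}
        else:
            agent_id = next(iter(obs))
            action_dict = {agent_id: random.choice([0, 1])}
        obs, rew, done, _ = hier_env.step(action_dict)
    print("Hierarchical episode finished.")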