ray/rllib/examples/hierarchical_training.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix two LINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py; without them the script fails with a py2 print-related error.
2020-01-09 00:15:48 -08:00

224 lines
7.8 KiB
Python

"""Example of hierarchical training using the multi-agent API.
The example env is that of a "windy maze". The agent observes the current wind
direction and can either choose to stand still, or move in that direction.
You can try out the env directly with:
$ python hierarchical_training.py --flat
A simple hierarchical formulation involves a high-level agent that issues goals
(i.e., go north / south / east / west), and a low-level agent that executes
these goals over a number of time-steps. This can be implemented as a
multi-agent environment with a top-level agent and low-level agents spawned
for each higher-level action. The lower level agent is rewarded for moving
in the right direction.
You can try this formulation with:
$ python hierarchical_training.py  # reaches ~100 reward after ~100k timesteps
Note that the hierarchical formulation actually converges slightly slower than
using --flat in this example.
"""
import argparse
import random
import gym
from gym.spaces import Box, Discrete, Tuple
import logging
import ray
from ray import tune
from ray.tune import function
from ray.rllib.env import MultiAgentEnv
# Command-line interface. By default the hierarchical multi-agent setup is
# trained; --flat trains a single policy on the plain WindyMazeEnv instead.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--flat",
    action="store_true",
    help="Train a single flat PPO policy on WindyMazeEnv instead of the "
    "hierarchical multi-agent formulation.")
# Agent has to traverse the maze from the starting position S -> F.
# Observation: ([x_pos, y_pos], wind_direction)
# Action space: stay still OR move in current wind direction
#
# Every row must have the same width because the env indexes the map as a
# rectangular grid (map[x][y] for all y < width of the first row). The
# leading spaces on the two middle rows are walkable cells that are walled
# off from the corridor; they only pad those rows to the full width.
MAP_DATA = """
#########
#S      #
####### #
      # #
      # #
####### #
#F      #
#########"""

logger = logging.getLogger(__name__)
class WindyMazeEnv(gym.Env):
    """Single-agent grid maze in which the agent can only move with the wind.

    Each step a new wind direction is drawn uniformly from 0=N, 1=E, 2=S,
    3=W. The agent observes ``[[x, y], wind_direction]`` and either stays
    put (action 0) or moves one cell in the current wind direction
    (action 1). Reaching the "F" cell yields a reward of 100 and ends the
    episode; otherwise the episode ends after 200 steps with reward 0.
    """

    def __init__(self, env_config):
        """Parse MAP_DATA into a grid and locate the start/finish cells.

        Args:
            env_config: Ignored by this env.

        Raises:
            ValueError: if the map is empty, not rectangular, or lacks an
                "S" or an "F" cell.
        """
        self.map = [m for m in MAP_DATA.split("\n") if m]
        if not self.map:
            raise ValueError("MAP_DATA contains no rows")
        self.x_dim = len(self.map)
        self.y_dim = len(self.map[0])
        logger.info("Loaded map %s %s", self.x_dim, self.y_dim)

        # The movement/bounds logic treats the map as a rectangle, so a
        # ragged row would otherwise surface later as an IndexError.
        ragged = [
            i for i, row in enumerate(self.map) if len(row) != self.y_dim
        ]
        if ragged:
            raise ValueError(
                "All maze rows must have width {}; rows {} differ".format(
                    self.y_dim, ragged))

        # If "S"/"F" appear more than once, the last occurrence wins.
        self.start_pos = None
        self.end_pos = None
        for x, row in enumerate(self.map):
            for y, cell in enumerate(row):
                if cell == "S":
                    self.start_pos = (x, y)
                elif cell == "F":
                    self.end_pos = (x, y)
        if self.start_pos is None or self.end_pos is None:
            raise ValueError(
                "Maze must contain a start 'S' and a finish 'F' cell")
        logger.info("Start pos %s end pos %s", self.start_pos, self.end_pos)

        self.observation_space = Tuple([
            Box(0, 100, shape=(2, )),  # (x, y)
            Discrete(4),  # wind direction (N, E, S, W)
        ])
        self.action_space = Discrete(2)  # whether to move or not

    def reset(self):
        """Start a new episode at "S" with a fresh random wind direction."""
        self.wind_direction = random.choice([0, 1, 2, 3])
        self.pos = self.start_pos
        self.num_steps = 0
        return [[self.pos[0], self.pos[1]], self.wind_direction]

    def step(self, action):
        """Optionally move with the wind, then redraw the wind direction.

        Returns:
            (obs, reward, done, info): reward is 100 on reaching "F",
            else 0; done once "F" is reached or after 200 steps.
        """
        if action == 1:
            self.pos = self._get_new_pos(self.pos, self.wind_direction)
        self.num_steps += 1
        self.wind_direction = random.choice([0, 1, 2, 3])
        at_goal = self.pos == self.end_pos
        done = at_goal or self.num_steps >= 200
        return ([[self.pos[0], self.pos[1]], self.wind_direction],
                100 * int(at_goal), done, {})

    def _get_new_pos(self, pos, direction):
        """Return the cell one step from ``pos`` in ``direction``.

        A move that would leave the grid or enter a wall ("#") is blocked
        and ``pos`` is returned unchanged.

        Raises:
            ValueError: if ``direction`` is not one of 0, 1, 2, 3.
        """
        if direction == 0:
            new_pos = (pos[0] - 1, pos[1])
        elif direction == 1:
            new_pos = (pos[0], pos[1] + 1)
        elif direction == 2:
            new_pos = (pos[0] + 1, pos[1])
        elif direction == 3:
            new_pos = (pos[0], pos[1] - 1)
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                "Invalid direction {!r}; expected 0, 1, 2 or 3".format(
                    direction))
        x, y = new_pos
        if (0 <= x < self.x_dim and 0 <= y < self.y_dim
                and self.map[x][y] != "#"):
            return new_pos
        return pos  # did not move
class HierarchicalWindyMazeEnv(MultiAgentEnv):
    """Two-level multi-agent wrapper around WindyMazeEnv.

    A single "high_level_agent" picks a goal direction (0-3). A new
    low-level agent with id "low_level_<k>" then acts in the underlying flat
    env for up to 25 steps, rewarded for moving in the goal direction. After
    that, control returns to the high-level agent. Exactly one agent acts
    per step.
    """

    def __init__(self, env_config):
        super().__init__()
        self.flat_env = WindyMazeEnv(env_config)

    def reset(self):
        """Reset the flat env and hand the first observation to the high level.

        Returns:
            dict mapping "high_level_agent" to the flat env's observation.
        """
        self.cur_obs = self.flat_env.reset()
        self.current_goal = None
        self.steps_remaining_at_level = None
        self.num_high_level_steps = 0
        # current low level agent id. This must be unique for each high level
        # step since agent ids cannot be reused.
        self.low_level_agent_id = "low_level_{}".format(
            self.num_high_level_steps)
        return {
            "high_level_agent": self.cur_obs,
        }

    def step(self, action_dict):
        """Dispatch the single acting agent's action to the matching level.

        Args:
            action_dict: dict with exactly one entry, keyed either by
                "high_level_agent" or by the current low-level agent id.

        Returns:
            (obs, rew, done, info) dicts keyed by agent id.
        """
        assert len(action_dict) == 1, action_dict
        if "high_level_agent" in action_dict:
            return self._high_level_step(action_dict["high_level_agent"])
        return self._low_level_step(next(iter(action_dict.values())))

    def _high_level_step(self, action):
        """Record the new goal and spawn a fresh low-level agent to pursue it.

        The low-level agent is given 25 steps and observes
        ``[flat_obs, goal]``.
        """
        # Bug fix: the old message had no "{}" placeholder, so the goal was
        # silently dropped from the log line.
        logger.debug("High level agent sets goal %s", action)
        self.current_goal = action
        self.steps_remaining_at_level = 25
        self.num_high_level_steps += 1
        self.low_level_agent_id = "low_level_{}".format(
            self.num_high_level_steps)
        obs = {self.low_level_agent_id: [self.cur_obs, self.current_goal]}
        rew = {self.low_level_agent_id: 0}
        done = {"__all__": False}
        return obs, rew, done, {}

    def _low_level_step(self, action):
        """Apply a low-level action to the flat env and compute rewards.

        Low-level reward: +1 for moving to the cell in the goal direction,
        -1 for moving anywhere else, 0 for not moving. When the flat episode
        ends, or the low-level agent's step budget runs out, the high-level
        agent also receives an observation (and the flat env's reward on
        termination, else 0).
        """
        logger.debug("Low level agent step %s", action)
        self.steps_remaining_at_level -= 1
        cur_pos = tuple(self.cur_obs[0])
        goal_pos = self.flat_env._get_new_pos(cur_pos, self.current_goal)

        # Step in the actual env
        f_obs, f_rew, f_done, _ = self.flat_env.step(action)
        new_pos = tuple(f_obs[0])
        self.cur_obs = f_obs

        # Calculate low-level agent observation and reward
        obs = {self.low_level_agent_id: [f_obs, self.current_goal]}
        if new_pos == cur_pos:
            low_rew = 0  # did not move
        elif new_pos == goal_pos:
            low_rew = 1  # moved in the goal direction
        else:
            low_rew = -1  # moved, but not in the goal direction
        rew = {self.low_level_agent_id: low_rew}

        # Handle env termination & transitions back to higher level
        done = {"__all__": False}
        if f_done:
            done["__all__"] = True
            logger.debug("high level final reward %s", f_rew)
            rew["high_level_agent"] = f_rew
            obs["high_level_agent"] = f_obs
        elif self.steps_remaining_at_level == 0:
            # Retire this low-level agent; the high level acts next.
            done[self.low_level_agent_id] = True
            rew["high_level_agent"] = 0
            obs["high_level_agent"] = f_obs

        return obs, rew, done, {}
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    if args.flat:
        # Baseline: one PPO policy trained directly on the flat maze.
        flat_config = {
            "env": WindyMazeEnv,
            "num_workers": 0,
        }
        tune.run("PPO", config=flat_config)
    else:
        # Instantiated only so its observation/action spaces can be reused.
        maze = WindyMazeEnv(None)

        def policy_mapping_fn(agent_id):
            """Send "low_level_*" agents to the low-level policy."""
            if not agent_id.startswith("low_level_"):
                return "high_level_policy"
            return "low_level_policy"

        # High level: sees the maze observation, emits one of 4 goals.
        high_level_spec = (
            None,
            maze.observation_space,
            Discrete(4),
            {"gamma": 0.9},
        )
        # Low level: sees (maze observation, goal), chooses stay/move.
        low_level_spec = (
            None,
            Tuple([maze.observation_space, Discrete(4)]),
            maze.action_space,
            {"gamma": 0.0},
        )
        hierarchical_config = {
            "env": HierarchicalWindyMazeEnv,
            "num_workers": 0,
            "log_level": "INFO",
            "entropy_coeff": 0.01,
            "multiagent": {
                "policies": {
                    "high_level_policy": high_level_spec,
                    "low_level_policy": low_level_spec,
                },
                "policy_mapping_fn": function(policy_mapping_fn),
            },
        }
        tune.run("PPO", config=hierarchical_config)