"""Example of hierarchical training using the multi-agent API.
|
|
|
|
|
|
|
|
The example env is that of a "windy maze". The agent observes the current wind
|
|
|
|
direction and can either choose to stand still, or move in that direction.
|
|
|
|
|
|
|
|
You can try out the env directly with:
|
|
|
|
|
|
|
|
$ python hierarchical_training.py --flat
|
|
|
|
|
|
|
|
A simple hierarchical formulation involves a high-level agent that issues goals
|
|
|
|
(i.e., go north / south / east / west), and a low-level agent that executes
|
|
|
|
these goals over a number of time-steps. This can be implemented as a
|
|
|
|
multi-agent environment with a top-level agent and low-level agents spawned
|
|
|
|
for each higher-level action. The lower level agent is rewarded for moving
|
|
|
|
in the right direction.
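
The policy structure used below can be summarized roughly as follows (spaces
taken from the config in this file):

    high_level_policy: observes the maze, emits a Discrete(4) goal
    low_level_policy:  observes (maze observation, goal), emits a maze action

Low-level agent ids are prefixed with "low_level_", which is how the policy
mapping function defined below routes them to the low-level policy.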

You can try this formulation with:

    $ python hierarchical_training.py  # gets ~100 rew after ~100k timesteps

Note that the hierarchical formulation actually converges slightly slower than
using --flat in this example.
"""

import argparse
import logging

from gym.spaces import Discrete, Tuple

import ray
from ray import tune
from ray.tune import function
from ray.rllib.examples.env.windy_maze_env import WindyMazeEnv, \
    HierarchicalWindyMazeEnv
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument("--flat", action="store_true")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--torch", action="store_true")
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-reward", type=float, default=0.0)
parser.add_argument("--stop-timesteps", type=int, default=100000)

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    # Tune stops a trial as soon as any one of these criteria is reached.
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    if args.flat:
        # Train a single flat policy directly on the windy maze env.
        results = tune.run(
            "PPO",
            stop=stop,
            config={
                "env": WindyMazeEnv,
                "num_workers": 0,
                "framework": "torch" if args.torch else "tf",
            },
        )
    else:
        # Env instance used only to read off observation/action spaces below.
        maze = WindyMazeEnv(None)

        def policy_mapping_fn(agent_id):
            # Low-level agents are spawned with ids prefixed "low_level_";
            # the single top-level agent maps to the high-level policy.
            if agent_id.startswith("low_level_"):
                return "low_level_policy"
            else:
                return "high_level_policy"

        config = {
            "env": HierarchicalWindyMazeEnv,
            "num_workers": 0,
            "log_level": "INFO",
            "entropy_coeff": 0.01,
            "multiagent": {
                # Each policy spec is a tuple of
                # (policy_cls (None = default), obs_space, act_space, config).
                "policies": {
                    # High-level policy: sees the maze observation and emits
                    # one of four goals (north / south / east / west).
                    "high_level_policy": (None, maze.observation_space,
                                          Discrete(4), {
                                              "gamma": 0.9
                                          }),
                    # Low-level policy: sees (maze observation, goal) and acts
                    # in the maze; gamma=0.0 makes it optimize only the
                    # immediate reward for moving toward the goal.
                    "low_level_policy": (None,
                                         Tuple([
                                             maze.observation_space,
                                             Discrete(4)
                                         ]), maze.action_space, {
                                             "gamma": 0.0
                                         }),
                },
                "policy_mapping_fn": function(policy_mapping_fn),
            },
            "framework": "torch" if args.torch else "tf",
        }

        results = tune.run("PPO", stop=stop, config=config)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()