ray/rllib/examples/sumo_env_local.py
#!/usr/bin/env python3
""" Example Trainer for RLLIB + SUMO Utlis
Author: Lara CODECA lara.codeca@gmail.com
See:
https://github.com/lcodeca/rllibsumoutils
https://github.com/lcodeca/rllibsumodocker
for further details.
"""
import argparse
from copy import deepcopy
import logging
import os
import pathlib
from pprint import pformat

import ray
from ray import tune
from ray.rllib.agents.ppo import ppo
from ray.rllib.examples.simulators.sumo import marlenvironment
from ray.rllib.utils.test_utils import check_learning_achieved

logging.basicConfig(level=logging.WARN)
logger = logging.getLogger("ppotrain")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sumo-connect-lib",
    type=str,
    default="libsumo",
    choices=["libsumo", "traci"],
    help="The SUMO connector to import. "
    "Requires the env variable SUMO_HOME set.")
parser.add_argument(
    "--sumo-gui",
    action="store_true",
    help="Enables the SUMO GUI. Possible only with the TraCI connector.")
parser.add_argument(
    "--sumo-config-file",
    type=str,
    default=None,
    help="The SUMO configuration file for the scenario.")
parser.add_argument(
    "--from-checkpoint",
    type=str,
    default=None,
    help="Full path to a checkpoint file for restoring a previously saved "
    "Trainer state.")
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=10)
parser.add_argument("--stop-reward", type=float, default=30000.0)
parser.add_argument("--stop-timesteps", type=int, default=10000000)
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init()
    tune.register_env("sumo_test_env", marlenvironment.env_creator)
    # Algorithm.
    policy_class = ppo.PPOTFPolicy
    # Start from a copy of the default PPO config so the shared
    # ppo.DEFAULT_CONFIG dict is not mutated.
    config = deepcopy(ppo.DEFAULT_CONFIG)
    config["framework"] = "tf"
    config["gamma"] = 0.99
    config["lambda"] = 0.95
    config["log_level"] = "WARN"
    config["lr"] = 0.001
    config["min_iter_time_s"] = 5
    config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
    config["num_workers"] = args.num_workers
    config["rollout_fragment_length"] = 200
    config["sgd_minibatch_size"] = 256
    config["train_batch_size"] = 4000
    config["batch_mode"] = "complete_episodes"
    config["no_done_at_end"] = True
    # Load default Scenario configuration for the LEARNING ENVIRONMENT.
    scenario_config = deepcopy(marlenvironment.DEFAULT_SCENARIO_CONFING)
    scenario_config["seed"] = 42
    scenario_config["log_level"] = "INFO"
    scenario_config["sumo_config"]["sumo_connector"] = args.sumo_connect_lib
    scenario_config["sumo_config"]["sumo_gui"] = args.sumo_gui
    if args.sumo_config_file is not None:
        scenario_config["sumo_config"]["sumo_cfg"] = args.sumo_config_file
    else:
        filename = "{}/simulators/sumo/scenario/sumo.cfg.xml".format(
            pathlib.Path(__file__).parent.absolute())
        scenario_config["sumo_config"]["sumo_cfg"] = filename
    scenario_config["sumo_config"]["sumo_params"] = [
        "--collision.action", "warn"
    ]
    scenario_config["sumo_config"]["trace_file"] = True
    scenario_config["sumo_config"]["end_of_sim"] = 3600  # [s]
    # Number of traci.simulationStep() calls for each learning step.
    scenario_config["sumo_config"]["update_freq"] = 10
    scenario_config["sumo_config"]["log_level"] = "INFO"
    logger.info("Scenario Configuration: \n %s", pformat(scenario_config))
    # Associate the agents with their configuration.
    agent_init = {
        "agent_0": deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
        "agent_1": deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
    }
    logger.info("Agents Configuration: \n %s", pformat(agent_init))

    # MARL Environment Init.
    env_config = {
        "agent_init": agent_init,
        "scenario_config": scenario_config,
    }
    marl_env = marlenvironment.SUMOTestMultiAgentEnv(env_config)
    # Config for the PPO trainer from the MARLEnv.
    policies = {}
    for agent in marl_env.get_agents():
        agent_policy_params = {}
        policies[agent] = (policy_class, marl_env.get_obs_space(agent),
                           marl_env.get_action_space(agent),
                           agent_policy_params)
    config["multiagent"]["policies"] = policies
    config["multiagent"]["policy_mapping_fn"] = lambda agent_id: agent_id
    # Train all the per-agent policies (their IDs are the agent IDs).
    config["multiagent"]["policies_to_train"] = list(policies.keys())
config["env"] = "sumo_test_env"
config["env_config"] = env_config
logger.info("PPO Configuration: \n %s", pformat(config))
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    # Run the experiment.
    results = tune.run(
        "PPO",
        config=config,
        stop=stop,
        verbose=1,
        checkpoint_freq=10,
        restore=args.from_checkpoint)

    # And check the results.
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
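
# Example usage (assumes SUMO is installed and SUMO_HOME is set):
#   python sumo_env_local.py --num-workers 2 --stop-iters 5
# To watch the simulation, switch to the TraCI connector and enable the GUI:
#   python sumo_env_local.py --sumo-connect-lib traci --sumo-gui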