ray/rllib/contrib/sumo/utils.py
2020-11-03 09:45:03 +01:00

324 lines
12 KiB
Python

""" RLLIB SUMO Utils - SUMO Connector Wrapper
Author: Lara CODECA lara.codeca@gmail.com
See:
https://github.com/lcodeca/rllibsumoutils
https://github.com/lcodeca/rllibsumodocker
for further details.
"""
import collections
from copy import deepcopy
import logging
import os
from pprint import pformat
import sys
from lxml import etree
from ray.rllib.contrib.sumo.connector import SUMOConnector, DEFAULT_CONFIG
# """ Import SUMO library """
if "SUMO_HOME" in os.environ:
sys.path.append(os.path.join(os.environ["SUMO_HOME"], "tools"))
# from traci.exceptions import TraCIException
import traci.constants as tc
else:
sys.exit("please declare environment variable 'SUMO_HOME'")
###############################################################################
logging.basicConfig()
logger = logging.getLogger(__name__)
###############################################################################
def sumo_default_config():
""" Return the default configuration for the SUMO Connector. """
return deepcopy(DEFAULT_CONFIG)
###############################################################################
class SUMOUtils(SUMOConnector):
"""
A wrapper for the interaction with the SUMO simulation that adds
functionalities.
"""
def _initialize_metrics(self):
""" Specific metrics initialization """
# Default TripInfo file metrics
self.tripinfo = collections.defaultdict(dict)
self.personinfo = collections.defaultdict(dict)
###########################################################################
# TRIPINFO FILE
def process_tripinfo_file(self):
"""
Closes the TraCI connections, then reads and process the tripinfo
data. It requires "tripinfo_xml_file" and "tripinfo_xml_schema"
configuration parametes set.
"""
if "tripinfo_keyword" not in self._config:
raise Exception(
"Function process_tripinfo_file requires the parameter "
"'tripinfo_keyword' set.", self._config)
if "tripinfo_xml_schema" not in self._config:
raise Exception(
"Function process_tripinfo_file requires the parameter "
"'tripinfo_xml_schema' set.", self._config)
# Make sure that the simulation is finished and the tripinfo file is
# written.
self.end_simulation()
# Reset the data structures.
self.tripinfo = collections.defaultdict(dict)
self.personinfo = collections.defaultdict(dict)
schema = etree.XMLSchema(file=self._config["tripinfo_xml_schema"])
parser = etree.XMLParser(schema=schema)
tripinfo_file = "{}{}".format(self._sumo_output_prefix,
self._config["tripinfo_keyword"])
tree = etree.parse(tripinfo_file, parser)
logger.info("Processing %s tripinfo file.", tripinfo_file)
for element in tree.getroot():
if element.tag == "tripinfo":
self.tripinfo[element.attrib["id"]] = dict(element.attrib)
elif element.tag == "personinfo":
self.personinfo[element.attrib["id"]] = dict(element.attrib)
stages = []
for stage in element:
stages.append([stage.tag, dict(stage.attrib)])
self.personinfo[element.attrib["id"]]["stages"] = stages
else:
raise Exception("Unrecognized element in the tripinfo file.")
logger.debug("TRIPINFO: \n%s", pformat(self.tripinfo))
logger.debug("PERSONINFO: \n%s", pformat(self.personinfo))
def get_timeloss(self, entity, default=float("NaN")):
""" Returns the timeLoss computed by SUMO for the given entity. """
if entity in self.tripinfo:
logger.debug("TRIPINFO for %s", entity)
if "timeLoss" in self.tripinfo[entity]:
logger.debug("timeLoss %s", self.tripinfo[entity]["timeLoss"])
return float(self.tripinfo[entity]["timeLoss"])
logger.debug("timeLoss not found.")
return default
elif entity in self.personinfo:
logger.debug("PERSONINFO for %s", entity)
logger.debug("%s", pformat(self.personinfo[entity]))
time_loss, ts_found = 0.0, False
for _, stage in self.personinfo[entity]["stages"]:
if "timeLoss" in stage:
logger.debug("timeLoss %s", stage["timeLoss"])
time_loss += float(stage["timeLoss"])
ts_found = True
if not ts_found:
logger.debug("timeLoss not found.")
return default
if time_loss <= 0:
logger.debug("ERROR: timeLoss is %.2f", time_loss)
return default
logger.debug("total timeLoss %.2f", time_loss)
return time_loss
else:
logger.debug("Entity %s not found.", entity)
return default
def get_depart(self, entity, default=float("NaN")):
"""
Returns the departure recorded by SUMO for the given entity.
The functions process_tripinfo_file() needs to be called in advance
to initialize the data structures required.
If the entity does not exist or does not have the value, it returns
the default value.
"""
if entity in self.tripinfo:
logger.debug("TRIPINFO for %s", entity)
if "depart" in self.tripinfo[entity]:
logger.debug("depart %s", self.tripinfo[entity]["depart"])
return float(self.tripinfo[entity]["depart"])
logger.debug("depart not found.")
elif entity in self.personinfo:
logger.debug("PERSONINFO for %s", entity)
logger.debug("%s", pformat(self.personinfo[entity]))
if "depart" in self.personinfo[entity]:
logger.debug("depart %s", self.personinfo[entity]["depart"])
return float(self.personinfo[entity]["depart"])
logger.debug("depart not found.")
else:
logger.debug("Entity %s not found.", entity)
return default
def get_duration(self, entity, default=float("NaN")):
"""
Returns the duration computed by SUMO for the given entity.
The functions process_tripinfo_file() needs to be called in advance
to initialize the data structures required.
If the entity does not exist or does not have the value, it returns
the default value.
"""
if entity in self.tripinfo:
logger.debug("TRIPINFO for %s", entity)
if "duration" in self.tripinfo[entity]:
logger.debug("duration %s", self.tripinfo[entity]["duration"])
return float(self.tripinfo[entity]["duration"])
logger.debug("duration not found.")
elif entity in self.personinfo:
logger.debug("PERSONINFO for %s", entity)
logger.debug("%s", pformat(self.personinfo[entity]))
if "depart" in self.personinfo[entity]:
depart = float(self.personinfo[entity]["depart"])
arrival = depart
for _, stage in self.personinfo[entity]["stages"]:
if "arrival" in stage:
arrival = float(stage["arrival"])
duration = arrival - depart
if duration > 0:
logger.debug("duration %d", duration)
return duration
logger.debug("duration impossible to compute.")
else:
logger.debug("Entity %s not found.", entity)
return default
def get_arrival(self, entity, default=float("NaN")):
"""
Returns the arrival computed by SUMO for the given entity.
The functions process_tripinfo_file() needs to be called in advance
to initialize the data structures required.
If the entity does not exist or does not have the value, it returns
the default value.
"""
if entity in self.tripinfo:
logger.debug("TRIPINFO for %s", entity)
if "arrival" in self.tripinfo[entity]:
logger.debug("arrival %s", self.tripinfo[entity]["arrival"])
return float(self.tripinfo[entity]["arrival"])
logger.debug("arrival not found.")
return default
elif entity in self.personinfo:
logger.debug("PERSONINFO for %s", entity)
arrival, arrival_found = 0.0, False
for _, stage in self.personinfo[entity]["stages"]:
if "arrival" in stage:
logger.debug("arrival %s", stage["arrival"])
arrival = float(stage["arrival"])
arrival_found = True
if not arrival_found:
logger.debug("arrival not found.")
return default
if arrival <= 0:
logger.debug("ERROR: arrival is %.2f", arrival)
return default
logger.debug("total arrival %.2f", arrival)
return arrival
else:
logger.debug("Entity %s not found.", entity)
return default
def get_global_travel_time(self):
"""
Returns the global travel time computed from SUMO tripinfo data.
The functions process_tripinfo_file() needs to be called in advance
to initialize the data structures required.
"""
gtt = 0
for entity in self.tripinfo:
gtt += self.get_duration(entity, default=0.0)
for entity in self.personinfo:
gtt += self.get_duration(entity, default=0.0)
return gtt
###########################################################################
# ROUTING
@staticmethod
def get_mode_parameters(mode):
"""
Return the correst TraCI parameters for the requested mode.
See: https://sumo.dlr.de/docs/TraCI/Simulation_Value_Retrieval.html
#command_0x87_find_intermodal_route
Param: mode, String.
Returns: _mode, _ptype, _vtype
"""
if mode == "public":
return "public", "", ""
if mode == "bicycle":
return "bicycle", "", "bicycle"
if mode == "walk":
return "", "pedestrian", ""
return "car", "", mode # (but car is not always necessary, and it may
# creates unusable alternatives)
def is_valid_route(self, mode, route):
"""
Handle findRoute and findIntermodalRoute results.
Params:
mode, String.
route, return value of findRoute or findIntermodalRoute.
"""
if route is None:
# traci failed
return False
_mode, _ptype, _vtype = self.get_mode_parameters(mode)
if not isinstance(route, (list, tuple)):
# only for findRoute
if len(route.edges) >= 2:
return True
elif _mode == "public":
for stage in route:
if stage.line:
return True
elif _mode in ("car", "bicycle"):
for stage in route:
if stage.type == tc.STAGE_DRIVING and len(stage.edges) >= 2:
return True
else:
for stage in route:
if len(stage.edges) >= 2:
return True
return False
@staticmethod
def cost_from_route(route):
"""
Compute the route cost.
Params:
route, return value of findRoute or findIntermodalRoute.
"""
cost = 0.0
for stage in route:
cost += stage.cost
return cost
@staticmethod
def travel_time_from_route(route):
"""
Compute the route estimated travel time.
Params:
route, return value of findRoute or findIntermodalRoute.
"""
ett = 0.0
for stage in route:
ett += stage.estimatedTime
return ett