mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
324 lines
12 KiB
Python
324 lines
12 KiB
Python
""" RLLIB SUMO Utils - SUMO Connector Wrapper
|
|
|
|
Author: Lara CODECA lara.codeca@gmail.com
|
|
|
|
See:
|
|
https://github.com/lcodeca/rllibsumoutils
|
|
https://github.com/lcodeca/rllibsumodocker
|
|
for further details.
|
|
"""
|
|
|
|
import collections
|
|
from copy import deepcopy
|
|
import logging
|
|
import os
|
|
from pprint import pformat
|
|
import sys
|
|
|
|
from lxml import etree
|
|
|
|
from ray.rllib.contrib.sumo.connector import SUMOConnector, DEFAULT_CONFIG
|
|
|
|
# Import the SUMO library: TraCI lives under $SUMO_HOME/tools, which must be
# put on sys.path before importing it. Without SUMO_HOME the process exits.
if "SUMO_HOME" not in os.environ:
    sys.exit("please declare environment variable 'SUMO_HOME'")
sys.path.append(os.path.join(os.environ["SUMO_HOME"], "tools"))
import traci.constants as tc  # noqa: E402  (requires the sys.path tweak above)

###############################################################################

# Module-wide logging setup.
logging.basicConfig()
logger = logging.getLogger(__name__)
|
|
|
|
###############################################################################
|
|
|
|
|
|
def sumo_default_config():
    """
    Provide the SUMO Connector default configuration.

    A deep copy is returned, so callers can freely modify it without
    altering the shared DEFAULT_CONFIG.
    """
    defaults = deepcopy(DEFAULT_CONFIG)
    return defaults
|
|
|
|
|
|
###############################################################################
|
|
|
|
|
|
class SUMOUtils(SUMOConnector):
    """
    A wrapper for the interaction with the SUMO simulation that adds
    functionalities: parsing of the tripinfo output file, per-entity metrics
    (timeLoss, depart, duration, arrival, global travel time) and helpers to
    evaluate the results of TraCI routing queries.
    """

    def _initialize_metrics(self):
        """ Specific metrics initialization """
        # Default TripInfo file metrics: entity id -> dict of XML attributes.
        self.tripinfo = collections.defaultdict(dict)
        self.personinfo = collections.defaultdict(dict)

    ###########################################################################
    # TRIPINFO FILE

    def process_tripinfo_file(self):
        """
        Closes the TraCI connections, then reads and processes the tripinfo
        data. It requires the "tripinfo_keyword" and "tripinfo_xml_schema"
        configuration parameters to be set.

        The file read is "<self._sumo_output_prefix><tripinfo_keyword>" and
        it is validated against the XML schema file "tripinfo_xml_schema".
        Every <tripinfo> element is stored in self.tripinfo and every
        <personinfo> element in self.personinfo, both keyed by their "id"
        attribute. A personinfo entry additionally holds, under "stages",
        the list of its child elements as [tag, attributes] pairs.
        XML comments and processing instructions are ignored.

        Raises:
            Exception: if a required configuration parameter is missing, or
                if the file contains an element that is neither a tripinfo
                nor a personinfo.
        """
        for required in ("tripinfo_keyword", "tripinfo_xml_schema"):
            if required not in self._config:
                raise Exception(
                    "Function process_tripinfo_file requires the parameter "
                    "'{}' set.".format(required), self._config)

        # Make sure that the simulation is finished and the tripinfo file is
        # written.
        self.end_simulation()

        # Reset the data structures.
        self.tripinfo = collections.defaultdict(dict)
        self.personinfo = collections.defaultdict(dict)

        schema = etree.XMLSchema(file=self._config["tripinfo_xml_schema"])
        parser = etree.XMLParser(schema=schema)
        tripinfo_file = "{}{}".format(self._sumo_output_prefix,
                                      self._config["tripinfo_keyword"])
        tree = etree.parse(tripinfo_file, parser)

        logger.info("Processing %s tripinfo file.", tripinfo_file)
        for element in tree.getroot():
            # lxml yields comments / processing instructions while iterating;
            # their tag is not a string. They are not data: skip them.
            if not isinstance(element.tag, str):
                continue
            if element.tag == "tripinfo":
                self.tripinfo[element.attrib["id"]] = dict(element.attrib)
            elif element.tag == "personinfo":
                self.personinfo[element.attrib["id"]] = dict(element.attrib)
                self.personinfo[element.attrib["id"]]["stages"] = [
                    [stage.tag, dict(stage.attrib)]
                    for stage in element
                    if isinstance(stage.tag, str)
                ]
            else:
                raise Exception("Unrecognized element in the tripinfo file.")
        logger.debug("TRIPINFO: \n%s", pformat(self.tripinfo))
        logger.debug("PERSONINFO: \n%s", pformat(self.personinfo))

    def get_timeloss(self, entity, default=float("NaN")):
        """
        Returns the timeLoss computed by SUMO for the given entity.

        For a tripinfo entity it is its "timeLoss" attribute; for a
        personinfo entity it is the sum of the "timeLoss" attributes of its
        stages. The functions process_tripinfo_file() needs to be called in
        advance to initialize the data structures required.

        If the entity does not exist or does not have the value, it returns
        the default value. For persons, the default is also returned when
        the summed timeLoss is not strictly positive.
        """
        if entity in self.tripinfo:
            logger.debug("TRIPINFO for %s", entity)
            if "timeLoss" in self.tripinfo[entity]:
                logger.debug("timeLoss %s", self.tripinfo[entity]["timeLoss"])
                return float(self.tripinfo[entity]["timeLoss"])
            logger.debug("timeLoss not found.")
            return default
        elif entity in self.personinfo:
            logger.debug("PERSONINFO for %s", entity)
            logger.debug("%s", pformat(self.personinfo[entity]))
            time_loss, ts_found = 0.0, False
            for _, stage in self.personinfo[entity]["stages"]:
                if "timeLoss" in stage:
                    logger.debug("timeLoss %s", stage["timeLoss"])
                    time_loss += float(stage["timeLoss"])
                    ts_found = True
            if not ts_found:
                logger.debug("timeLoss not found.")
                return default
            if time_loss <= 0:
                logger.debug("ERROR: timeLoss is %.2f", time_loss)
                return default
            logger.debug("total timeLoss %.2f", time_loss)
            return time_loss
        else:
            logger.debug("Entity %s not found.", entity)
        return default

    def get_depart(self, entity, default=float("NaN")):
        """
        Returns the departure recorded by SUMO for the given entity.

        The functions process_tripinfo_file() needs to be called in advance
        to initialize the data structures required.

        If the entity does not exist or does not have the value, it returns
        the default value.
        """
        if entity in self.tripinfo:
            logger.debug("TRIPINFO for %s", entity)
            if "depart" in self.tripinfo[entity]:
                logger.debug("depart %s", self.tripinfo[entity]["depart"])
                return float(self.tripinfo[entity]["depart"])
            logger.debug("depart not found.")
        elif entity in self.personinfo:
            logger.debug("PERSONINFO for %s", entity)
            logger.debug("%s", pformat(self.personinfo[entity]))
            if "depart" in self.personinfo[entity]:
                logger.debug("depart %s", self.personinfo[entity]["depart"])
                return float(self.personinfo[entity]["depart"])
            logger.debug("depart not found.")
        else:
            logger.debug("Entity %s not found.", entity)
        return default

    def get_duration(self, entity, default=float("NaN")):
        """
        Returns the duration computed by SUMO for the given entity.

        For a tripinfo entity it is its "duration" attribute; for a
        personinfo entity it is the "arrival" of its last stage that has one
        minus the person "depart".

        The functions process_tripinfo_file() needs to be called in advance
        to initialize the data structures required.

        If the entity does not exist or does not have the value, it returns
        the default value. For persons, the default is also returned when
        the computed duration is not strictly positive.
        """
        if entity in self.tripinfo:
            logger.debug("TRIPINFO for %s", entity)
            if "duration" in self.tripinfo[entity]:
                logger.debug("duration %s", self.tripinfo[entity]["duration"])
                return float(self.tripinfo[entity]["duration"])
            logger.debug("duration not found.")
        elif entity in self.personinfo:
            logger.debug("PERSONINFO for %s", entity)
            logger.debug("%s", pformat(self.personinfo[entity]))
            if "depart" in self.personinfo[entity]:
                depart = float(self.personinfo[entity]["depart"])
                # The last stage carrying an arrival time wins.
                arrival = depart
                for _, stage in self.personinfo[entity]["stages"]:
                    if "arrival" in stage:
                        arrival = float(stage["arrival"])
                duration = arrival - depart
                if duration > 0:
                    # %.2f (not %d) so the float is not truncated in the log.
                    logger.debug("duration %.2f", duration)
                    return duration
            logger.debug("duration impossible to compute.")
        else:
            logger.debug("Entity %s not found.", entity)
        return default

    def get_arrival(self, entity, default=float("NaN")):
        """
        Returns the arrival computed by SUMO for the given entity.

        For a personinfo entity it is the "arrival" of its last stage that
        has one.

        The functions process_tripinfo_file() needs to be called in advance
        to initialize the data structures required.

        If the entity does not exist or does not have the value, it returns
        the default value. For persons, the default is also returned when
        the arrival is not strictly positive.
        """
        if entity in self.tripinfo:
            logger.debug("TRIPINFO for %s", entity)
            if "arrival" in self.tripinfo[entity]:
                logger.debug("arrival %s", self.tripinfo[entity]["arrival"])
                return float(self.tripinfo[entity]["arrival"])
            logger.debug("arrival not found.")
            return default
        elif entity in self.personinfo:
            logger.debug("PERSONINFO for %s", entity)
            arrival, arrival_found = 0.0, False
            for _, stage in self.personinfo[entity]["stages"]:
                if "arrival" in stage:
                    logger.debug("arrival %s", stage["arrival"])
                    arrival = float(stage["arrival"])
                    arrival_found = True
            if not arrival_found:
                logger.debug("arrival not found.")
                return default
            if arrival <= 0:
                logger.debug("ERROR: arrival is %.2f", arrival)
                return default
            logger.debug("total arrival %.2f", arrival)
            return arrival
        else:
            logger.debug("Entity %s not found.", entity)
        return default

    def get_global_travel_time(self):
        """
        Returns the global travel time computed from SUMO tripinfo data.

        It is the sum of get_duration() over every tripinfo and personinfo
        entity, counting 0.0 for entities whose duration is unavailable.

        The functions process_tripinfo_file() needs to be called in advance
        to initialize the data structures required.
        """
        gtt = 0
        for entity in self.tripinfo:
            gtt += self.get_duration(entity, default=0.0)
        for entity in self.personinfo:
            gtt += self.get_duration(entity, default=0.0)
        return gtt

    ###########################################################################
    # ROUTING

    @staticmethod
    def get_mode_parameters(mode):
        """
        Return the correct TraCI parameters for the requested mode.
        See: https://sumo.dlr.de/docs/TraCI/Simulation_Value_Retrieval.html
            #command_0x87_find_intermodal_route

        Param: mode, String.
        Returns: _mode, _ptype, _vtype
        """
        if mode == "public":
            return "public", "", ""
        if mode == "bicycle":
            return "bicycle", "", "bicycle"
        if mode == "walk":
            return "", "pedestrian", ""
        # Any other mode is treated as a vehicle type driven by car (but car
        # is not always necessary, and it may create unusable alternatives).
        return "car", "", mode

    def is_valid_route(self, mode, route):
        """
        Handle findRoute and findIntermodalRoute results.

        Params:
            mode, String.
            route, return value of findRoute (a single stage) or
                findIntermodalRoute (a list/tuple of stages).

        Returns True when the route contains a usable stage for the mode:
        at least two edges for findRoute or generic modes, a stage with a
        line for "public", a driving stage with at least two edges for
        "car"/"bicycle". Returns False otherwise, including when TraCI
        failed and route is None.
        """
        if route is None:
            # traci failed
            return False
        _mode, _, _ = self.get_mode_parameters(mode)
        if not isinstance(route, (list, tuple)):
            # only for findRoute
            if len(route.edges) >= 2:
                return True
        elif _mode == "public":
            for stage in route:
                if stage.line:
                    return True
        elif _mode in ("car", "bicycle"):
            for stage in route:
                if stage.type == tc.STAGE_DRIVING and len(stage.edges) >= 2:
                    return True
        else:
            for stage in route:
                if len(stage.edges) >= 2:
                    return True
        return False

    @staticmethod
    def _route_stages(route):
        """ Return the stages of a routing result as a sequence. """
        # findIntermodalRoute returns a list/tuple of stages, while findRoute
        # returns a single stage object, which is not iterable (the same
        # distinction is made in is_valid_route).
        if route is None:
            raise TypeError("route is None: the TraCI routing query failed.")
        if isinstance(route, (list, tuple)):
            return route
        return [route]

    @staticmethod
    def cost_from_route(route):
        """
        Compute the route cost as the sum of the cost of every stage.
        Params:
            route, return value of findRoute or findIntermodalRoute.
        """
        cost = 0.0
        for stage in SUMOUtils._route_stages(route):
            cost += stage.cost
        return cost

    @staticmethod
    def travel_time_from_route(route):
        """
        Compute the route estimated travel time as the sum of the
        estimatedTime of every stage.
        Params:
            route, return value of findRoute or findIntermodalRoute.
        """
        ett = 0.0
        for stage in SUMOUtils._route_stages(route):
            ett += stage.estimatedTime
        return ett
|