"""Helper for incrementally building up a single, batched TF session.run."""

import logging
import os
import time

from ray.rllib.utils.framework import try_import_tf
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class _TFRunBuilder:
    """Used to incrementally build up a TensorFlow run.

    This is particularly useful for batching ops from multiple different
    policies in the multi-agent setting.
    """
    def __init__(self, session, debug_name):
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        self._executed = None
    def add_feed_dict(self, feed_dict):
        """Adds tensor->value pairs to feed to the (single) pending run."""
        assert not self._executed
        for k in feed_dict:
            if k in self.feed_dict:
                raise ValueError("Key added twice: {}".format(k))
        self.feed_dict.update(feed_dict)
    def add_fetches(self, fetches):
        """Adds ops to fetch; returns their indices for later `get()` calls."""
        assert not self._executed
        base_index = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base_index, len(self.fetches)))
    def get(self, to_fetch):
        """Runs the session once (lazily) and returns the requested fetches."""
        if self._executed is None:
            try:
                self._executed = _run_timeline(
                    self.session,
                    self.fetches,
                    self.debug_name,
                    self.feed_dict,
                    os.environ.get("TF_TIMELINE_DIR"),
                )
            except Exception:
                logger.exception(
                    "Error fetching: {}, feed_dict={}".format(
                        self.fetches, self.feed_dict
                    )
                )
                # Re-raise with the original traceback intact.
                raise
        if isinstance(to_fetch, int):
            return self._executed[to_fetch]
        elif isinstance(to_fetch, list):
            return [self.get(x) for x in to_fetch]
        elif isinstance(to_fetch, tuple):
            return tuple(self.get(x) for x in to_fetch)
        else:
            raise ValueError("Unsupported fetch type: {}".format(to_fetch))
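# Example usage (a minimal sketch; `sess`, `obs_ph`, `obs_batch`, and
# `action_op` are hypothetical names, not part of this module). Feeds and
# fetches from multiple policies are accumulated first; the single batched
# session.run is deferred until the first `get()` call:
#
#   builder = _TFRunBuilder(sess, "multi_agent_step")
#   builder.add_feed_dict({obs_ph: obs_batch})
#   action_idx = builder.add_fetches([action_op])[0]
#   actions = builder.get(action_idx)  # triggers the one batched run

# Counter used to rotate timeline dump file names; the `% 10` below means
# file names cycle, so at most ten dumps per process/debug name are kept.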
_count = 0


def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
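    """Runs `ops` in `sess`, optionally dumping a Chrome trace to timeline_dir."""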
    if feed_dict is None:
        feed_dict = {}

    if timeline_dir:
        from tensorflow.python.client import timeline

        # Trace via the TF1 API (tf.RunOptions is not exposed at the top
        # level in TF2, so use tf1 consistently here).
        run_options = tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE)
        run_metadata = tf1.RunMetadata()
        start = time.time()
        fetches = sess.run(
            ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
        )
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        global _count
        outf = os.path.join(
            timeline_dir,
            "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
        )
        _count += 1
        # Write via a context manager so the trace file is always closed.
        with open(outf, "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        logger.info(
            "Wrote tf timeline ({} s) to {}".format(
                time.time() - start, os.path.abspath(outf)
            )
        )
    else:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable."
            )
        fetches = sess.run(ops, feed_dict=feed_dict)
    return fetches
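# To capture traces (illustrative invocation; the script name is hypothetical):
#
#   TF_TIMELINE_DIR=/tmp/tf-timelines python train.py
#
# then load the resulting timeline-*.json files in chrome://tracing to view
# the op-level execution timeline.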