ray/rllib/utils/tf_run_builder.py

import logging
import os
import time

from ray.util.debug import log_once
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class TFRunBuilder:
    """Used to incrementally build up a TensorFlow run.

    This is particularly useful for batching ops from multiple different
    policies in the multi-agent setting.
    """

    def __init__(self, session, debug_name):
        self.session = session
        self.debug_name = debug_name
        self.feed_dict = {}
        self.fetches = []
        self._executed = None

    def add_feed_dict(self, feed_dict):
        """Adds entries to the shared feed dict; duplicate keys are rejected."""
        assert not self._executed
        for k in feed_dict:
            if k in self.feed_dict:
                raise ValueError("Key added twice: {}".format(k))
        self.feed_dict.update(feed_dict)

    def add_fetches(self, fetches):
        """Adds fetches to the shared list and returns their integer indices."""
        assert not self._executed
        base_index = len(self.fetches)
        self.fetches.extend(fetches)
        return list(range(base_index, len(self.fetches)))

    def get(self, to_fetch):
        """Executes the batched run (once) and resolves the given fetch indices."""
        if self._executed is None:
            try:
                self._executed = run_timeline(
                    self.session,
                    self.fetches,
                    self.debug_name,
                    self.feed_dict,
                    os.environ.get("TF_TIMELINE_DIR"),
                )
            except Exception as e:
                logger.exception(
                    "Error fetching: {}, feed_dict={}".format(
                        self.fetches, self.feed_dict
                    )
                )
                raise e
        if isinstance(to_fetch, int):
            return self._executed[to_fetch]
        elif isinstance(to_fetch, list):
            return [self.get(x) for x in to_fetch]
        elif isinstance(to_fetch, tuple):
            return tuple(self.get(x) for x in to_fetch)
        else:
            raise ValueError("Unsupported fetch type: {}".format(to_fetch))


# Module-level counter used to rotate timeline output filenames (mod 10).
_count = 0


def run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None):
    """Runs `ops` in `sess`, optionally dumping a Chrome-trace timeline.

    If `timeline_dir` is set, the run is traced and the trace is written to a
    JSON file that can be inspected in chrome://tracing.
    """
    if feed_dict is None:
        feed_dict = {}
    if timeline_dir:
        from tensorflow.python.client import timeline

        run_options = tf1.RunOptions(trace_level=tf1.RunOptions.FULL_TRACE)
        run_metadata = tf1.RunMetadata()
        start = time.time()
        fetches = sess.run(
            ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict
        )
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        global _count
        outf = os.path.join(
            timeline_dir,
            "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10),
        )
        _count += 1
        # Use a context manager so the trace file is closed instead of leaked.
        with open(outf, "w") as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        logger.info(
            "Wrote tf timeline ({} s) to {}".format(
                time.time() - start, os.path.abspath(outf)
            )
        )
    else:
        if log_once("tf_timeline"):
            logger.info(
                "Executing TF run without tracing. To dump TF timeline traces "
                "to disk, set the TF_TIMELINE_DIR environment variable."
            )
        fetches = sess.run(ops, feed_dict=feed_dict)
    return fetches
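

# Illustrative sketch (not part of the original file): timeline tracing is
# controlled entirely by the TF_TIMELINE_DIR environment variable, which
# TFRunBuilder.get() passes into run_timeline(). The directory path below is
# a hypothetical example.
def _example_enable_timeline_tracing():
    os.makedirs("/tmp/tf_timelines", exist_ok=True)
    os.environ["TF_TIMELINE_DIR"] = "/tmp/tf_timelines"
    # Subsequent TFRunBuilder.get() calls will now write Chrome-trace files
    # named timeline-<debug_name>-<pid>-<count>.json into that directory.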