mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Lazily create summary writer for TF2 logger. (#5631)
This commit is contained in:
parent
8236936189
commit
3ea9062419
1 changed files with 9 additions and 5 deletions
|
@ -154,11 +154,13 @@ def tf2_compat_logger(config, logdir):
|
||||||
|
|
||||||
class TF2Logger(Logger):
|
class TF2Logger(Logger):
|
||||||
def _init(self):
|
def _init(self):
|
||||||
from tensorflow.python.eager import context
|
self._file_writer = None
|
||||||
self._context = context
|
|
||||||
self._file_writer = tf.summary.create_file_writer(self.logdir)
|
|
||||||
|
|
||||||
def on_result(self, result):
|
def on_result(self, result):
|
||||||
|
if self._file_writer is None:
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
self._context = context
|
||||||
|
self._file_writer = tf.summary.create_file_writer(self.logdir)
|
||||||
with tf.device("/CPU:0"), self._context.eager_mode():
|
with tf.device("/CPU:0"), self._context.eager_mode():
|
||||||
with tf.summary.record_if(True), self._file_writer.as_default():
|
with tf.summary.record_if(True), self._file_writer.as_default():
|
||||||
step = result.get(
|
step = result.get(
|
||||||
|
@ -181,10 +183,12 @@ class TF2Logger(Logger):
|
||||||
self._file_writer.flush()
|
self._file_writer.flush()
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
self._file_writer.flush()
|
if self._file_writer is not None:
|
||||||
|
self._file_writer.flush()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._file_writer.close()
|
if self._file_writer is not None:
|
||||||
|
self._file_writer.close()
|
||||||
|
|
||||||
|
|
||||||
def to_tf_values(result, path):
|
def to_tf_values(result, path):
|
||||||
|
|
Loading…
Add table
Reference in a new issue