diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index fdc73ed2b..bdec2a8bf 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -144,9 +144,9 @@ def tf2_compat_logger(config, logdir): else: import tensorflow as tf use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >= - distutils.version.LooseVersion("1.14.0")) + distutils.version.LooseVersion("2.0.0")) if use_tf2_api: - tf = tf.compat.v2 # setting this for 1.14 + tf = tf.compat.v2 # setting this for TF2.0 return TF2Logger(config, logdir) else: return TFLogger(config, logdir) @@ -154,11 +154,13 @@ def tf2_compat_logger(config, logdir): class TF2Logger(Logger): def _init(self): + from tensorflow.python.eager import context + self._context = context self._file_writer = tf.summary.create_file_writer(self.logdir) def on_result(self, result): - with tf.device("/CPU:0"): - with self._file_writer.as_default(): + with tf.device("/CPU:0"), self._context.eager_mode(): + with tf.summary.record_if(True), self._file_writer.as_default(): step = result.get( TIMESTEPS_TOTAL) or result[TRAINING_ITERATION] @@ -197,10 +199,8 @@ def to_tf_values(result, path): class TFLogger(Logger): def _init(self): - logger.info( - "Initializing TFLogger instead of TF2Logger. We recommend " - "migrating to TF2.0. This class will be removed in the future.") - self._file_writer = tf.summary.FileWriter(self.logdir) + logger.info("Initializing TFLogger instead of TF2Logger.") + self._file_writer = tf.compat.v1.summary.FileWriter(self.logdir) def on_result(self, result): tmp = result.copy()