[sgd/tune][minor] more tf ports (#5953)

Richard Liaw 2019-10-21 16:46:16 -07:00 committed by GitHub
parent 235dec8aa3
commit 252a5d13ed
2 changed files with 16 additions and 3 deletions


@@ -3,7 +3,7 @@ from __future__ import division
 from __future__ import print_function

 import argparse
-from tensorflow.data import Dataset
+import tensorflow as tf
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import Dense
 import numpy as np
@@ -43,8 +43,8 @@ def simple_dataset(config):
     x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
     x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)
-    train_dataset = Dataset.from_tensor_slices((x_train, y_train))
-    test_dataset = Dataset.from_tensor_slices((x_test, y_test))
+    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
     train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
         batch_size)
     test_dataset = test_dataset.repeat().batch(batch_size)
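For context, the port above switches the example to the canonical tf.data.Dataset access path (import tensorflow as tf, then tf.data.Dataset), which is available in both TF 1.15 and 2.x. Below is a minimal, self-contained sketch of that pattern feeding a Keras model; the synthetic data, batch size, and step count are illustrative stand-ins for linear_dataset, NUM_TRAIN_SAMPLES, and batch_size in the example, not values taken from it.

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Illustrative linear data (y = 2x); stands in for linear_dataset().
x_train = np.linspace(0.0, 1.0, 100, dtype=np.float32).reshape(-1, 1)
y_train = 2 * x_train

# Public tf.data API instead of "from tensorflow.data import Dataset".
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(len(x_train)).repeat().batch(16)

model = Sequential([Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")
model.fit(train_dataset, epochs=2, steps_per_epoch=len(x_train) // 16)

Because the dataset repeats indefinitely, steps_per_epoch has to be given to fit explicitly.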


@@ -148,6 +148,11 @@ def tf2_compat_logger(config, logdir, trial=None):
         use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
                        distutils.version.LooseVersion("1.15.0"))
         if use_tf2_api:
+            # This is temporarily for RLlib because it disables v2 behavior...
+            from tensorflow.python import tf2
+            if not tf2.enabled():
+                tf = tf.compat.v1
+                return TFLogger(config, logdir, trial)
             tf = tf.compat.v2  # setting this for TF2.0
             return TF2Logger(config, logdir, trial)
         else:
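The added branch covers the case where a TF2-capable TensorFlow (1.15+ or 2.x) is installed but v2 behavior has been turned off, as RLlib does, so the v1-style TFLogger is the right backend. A rough sketch of that dispatch in isolation; pick_tf_namespace is a name chosen here for illustration, not part of the change:

import distutils.version

import tensorflow as tf


def pick_tf_namespace():
    # TF >= 1.15 ships both the compat.v1 and compat.v2 namespaces.
    use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
                   distutils.version.LooseVersion("1.15.0"))
    if use_tf2_api:
        # tf2.enabled() reports whether v2 behavior is actually active; it is
        # False when something has called tf.compat.v1.disable_v2_behavior().
        from tensorflow.python import tf2
        if not tf2.enabled():
            return tf.compat.v1  # v2 available, but v1 behavior in effect
        return tf.compat.v2
    return tf  # older TensorFlow: only the v1-style API exists

The hunk above does the same thing inline, binding the chosen namespace to the module-level tf before constructing the matching logger.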
@@ -166,6 +171,10 @@ class TF2Logger(Logger):
     """

     def _init(self):
+        global tf
+        if tf is None:
+            import tensorflow as tf
+            tf = tf.compat.v2  # setting this for TF2.0
         self._file_writer = None
         self._hp_logged = False
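This _init hook (and the matching one in TFLogger below) makes the TensorFlow import lazy: the module keeps a tf name that stays None until a logger is actually constructed, so merely importing the logger module does not require TensorFlow. A condensed sketch of the pattern; the class name is illustrative, and the module-level placeholder is assumed from the if tf is None check:

tf = None  # module-level placeholder, assumed from the "if tf is None" check


class LazyTFWriter:  # illustrative stand-in for TF2Logger
    def _init(self):
        global tf
        if tf is None:
            # "global tf" makes this import bind the module-level name,
            # so later methods can use tf without importing again.
            import tensorflow as tf
            tf = tf.compat.v2  # pin the module to the v2 namespace
        self._file_writer = None

TFLogger below follows the same pattern but pins tf.compat.v1 instead.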
@@ -237,6 +246,10 @@ class TFLogger(Logger):
     """

     def _init(self):
+        global tf
+        if tf is None:
+            import tensorflow as tf
+            tf = tf.compat.v1  # setting this for regular TF logger
         logger.debug("Initializing TFLogger instead of TF2Logger.")
         self._file_writer = tf.summary.FileWriter(self.logdir)
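The two loggers pin different namespaces because the v1 and v2 summary APIs are not interchangeable: TFLogger writes through the graph-style tf.summary.FileWriter seen above, while TF2Logger goes through the eager-friendly v2 summary writer (tf.compat.v2.summary). A hedged sketch of the two write paths; the log directory, tag, and values are illustrative, and tensorflow.python.tf2 is assumed to be available (TF 1.14+):

import tensorflow as tf
from tensorflow.python import tf2

logdir = "/tmp/tune_logs_example"  # illustrative path

if tf2.enabled():
    # v2 behavior active: the TF2Logger-style path.
    writer = tf.compat.v2.summary.create_file_writer(logdir)
    with writer.as_default():
        tf.compat.v2.summary.scalar("loss", 0.5, step=1)
    writer.flush()
else:
    # v1 behavior (or v2 disabled, e.g. by RLlib): the TFLogger-style path.
    v1 = tf.compat.v1
    writer = v1.summary.FileWriter(logdir)
    summary = v1.Summary(value=[v1.Summary.Value(tag="loss", simple_value=0.5)])
    writer.add_summary(summary, global_step=1)
    writer.flush()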