[sgd/tune][minor] more tf ports (#5953)

This commit is contained in:
Richard Liaw 2019-10-21 16:46:16 -07:00 committed by GitHub
parent 235dec8aa3
commit 252a5d13ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 3 deletions

View file

@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
import argparse
from tensorflow.data import Dataset
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
@ -43,8 +43,8 @@ def simple_dataset(config):
x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)
train_dataset = Dataset.from_tensor_slices((x_train, y_train))
test_dataset = Dataset.from_tensor_slices((x_test, y_test))
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
batch_size)
test_dataset = test_dataset.repeat().batch(batch_size)

View file

@ -148,6 +148,11 @@ def tf2_compat_logger(config, logdir, trial=None):
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
distutils.version.LooseVersion("1.15.0"))
if use_tf2_api:
# This is temporarily for RLlib because it disables v2 behavior...
from tensorflow.python import tf2
if not tf2.enabled():
tf = tf.compat.v1
return TFLogger(config, logdir, trial)
tf = tf.compat.v2 # setting this for TF2.0
return TF2Logger(config, logdir, trial)
else:
@ -166,6 +171,10 @@ class TF2Logger(Logger):
"""
def _init(self):
global tf
if tf is None:
import tensorflow as tf
tf = tf.compat.v2 # setting this for TF2.0
self._file_writer = None
self._hp_logged = False
@ -237,6 +246,10 @@ class TFLogger(Logger):
"""
def _init(self):
global tf
if tf is None:
import tensorflow as tf
tf = tf.compat.v1 # setting this for regular TF logger
logger.debug("Initializing TFLogger instead of TF2Logger.")
self._file_writer = tf.summary.FileWriter(self.logdir)