mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[sgd/tune][minor] more tf ports (#5953)
This commit is contained in:
parent
235dec8aa3
commit
252a5d13ed
2 changed files with 16 additions and 3 deletions
|
@ -3,7 +3,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from tensorflow.data import Dataset
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import Dense
|
||||
import numpy as np
|
||||
|
@ -43,8 +43,8 @@ def simple_dataset(config):
|
|||
x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
|
||||
x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)
|
||||
|
||||
train_dataset = Dataset.from_tensor_slices((x_train, y_train))
|
||||
test_dataset = Dataset.from_tensor_slices((x_test, y_test))
|
||||
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
|
||||
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
|
||||
batch_size)
|
||||
test_dataset = test_dataset.repeat().batch(batch_size)
|
||||
|
|
|
@ -148,6 +148,11 @@ def tf2_compat_logger(config, logdir, trial=None):
|
|||
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
||||
distutils.version.LooseVersion("1.15.0"))
|
||||
if use_tf2_api:
|
||||
# This is temporarily for RLlib because it disables v2 behavior...
|
||||
from tensorflow.python import tf2
|
||||
if not tf2.enabled():
|
||||
tf = tf.compat.v1
|
||||
return TFLogger(config, logdir, trial)
|
||||
tf = tf.compat.v2 # setting this for TF2.0
|
||||
return TF2Logger(config, logdir, trial)
|
||||
else:
|
||||
|
@ -166,6 +171,10 @@ class TF2Logger(Logger):
|
|||
"""
|
||||
|
||||
def _init(self):
|
||||
global tf
|
||||
if tf is None:
|
||||
import tensorflow as tf
|
||||
tf = tf.compat.v2 # setting this for TF2.0
|
||||
self._file_writer = None
|
||||
self._hp_logged = False
|
||||
|
||||
|
@ -237,6 +246,10 @@ class TFLogger(Logger):
|
|||
"""
|
||||
|
||||
def _init(self):
|
||||
global tf
|
||||
if tf is None:
|
||||
import tensorflow as tf
|
||||
tf = tf.compat.v1 # setting this for regular TF logger
|
||||
logger.debug("Initializing TFLogger instead of TF2Logger.")
|
||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue