mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[sgd/tune][minor] more tf ports (#5953)
This commit is contained in:
parent
235dec8aa3
commit
252a5d13ed
2 changed files with 16 additions and 3 deletions
|
@ -3,7 +3,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from tensorflow.data import Dataset
|
import tensorflow as tf
|
||||||
from tensorflow.keras.models import Sequential
|
from tensorflow.keras.models import Sequential
|
||||||
from tensorflow.keras.layers import Dense
|
from tensorflow.keras.layers import Dense
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -43,8 +43,8 @@ def simple_dataset(config):
|
||||||
x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
|
x_train, y_train = linear_dataset(size=NUM_TRAIN_SAMPLES)
|
||||||
x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)
|
x_test, y_test = linear_dataset(size=NUM_TEST_SAMPLES)
|
||||||
|
|
||||||
train_dataset = Dataset.from_tensor_slices((x_train, y_train))
|
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
|
||||||
test_dataset = Dataset.from_tensor_slices((x_test, y_test))
|
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
|
||||||
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
|
train_dataset = train_dataset.shuffle(NUM_TRAIN_SAMPLES).repeat().batch(
|
||||||
batch_size)
|
batch_size)
|
||||||
test_dataset = test_dataset.repeat().batch(batch_size)
|
test_dataset = test_dataset.repeat().batch(batch_size)
|
||||||
|
|
|
@ -148,6 +148,11 @@ def tf2_compat_logger(config, logdir, trial=None):
|
||||||
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
use_tf2_api = (distutils.version.LooseVersion(tf.__version__) >=
|
||||||
distutils.version.LooseVersion("1.15.0"))
|
distutils.version.LooseVersion("1.15.0"))
|
||||||
if use_tf2_api:
|
if use_tf2_api:
|
||||||
|
# This is temporarily for RLlib because it disables v2 behavior...
|
||||||
|
from tensorflow.python import tf2
|
||||||
|
if not tf2.enabled():
|
||||||
|
tf = tf.compat.v1
|
||||||
|
return TFLogger(config, logdir, trial)
|
||||||
tf = tf.compat.v2 # setting this for TF2.0
|
tf = tf.compat.v2 # setting this for TF2.0
|
||||||
return TF2Logger(config, logdir, trial)
|
return TF2Logger(config, logdir, trial)
|
||||||
else:
|
else:
|
||||||
|
@ -166,6 +171,10 @@ class TF2Logger(Logger):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
|
global tf
|
||||||
|
if tf is None:
|
||||||
|
import tensorflow as tf
|
||||||
|
tf = tf.compat.v2 # setting this for TF2.0
|
||||||
self._file_writer = None
|
self._file_writer = None
|
||||||
self._hp_logged = False
|
self._hp_logged = False
|
||||||
|
|
||||||
|
@ -237,6 +246,10 @@ class TFLogger(Logger):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _init(self):
|
def _init(self):
|
||||||
|
global tf
|
||||||
|
if tf is None:
|
||||||
|
import tensorflow as tf
|
||||||
|
tf = tf.compat.v1 # setting this for regular TF logger
|
||||||
logger.debug("Initializing TFLogger instead of TF2Logger.")
|
logger.debug("Initializing TFLogger instead of TF2Logger.")
|
||||||
self._file_writer = tf.summary.FileWriter(self.logdir)
|
self._file_writer = tf.summary.FileWriter(self.logdir)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue