[rllib] Remove dependency on TensorFlow (#4764)

* remove hard tf dep

* add test

* comment fix

* fix test
Eric Liang 2019-05-10 20:36:18 -07:00 committed by GitHub
parent ccc540adf1
commit 351753aae5
35 changed files with 189 additions and 63 deletions
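
The pattern applied across these files: instead of importing TensorFlow at module load time, each module asks a small helper for it and tolerates getting None back. A minimal sketch of the idea; the helper body matches the rllib.utils hunk further down in this diff, the surrounding scaffolding here is just illustrative:

# Sketch of the optional-import pattern this commit introduces.
import logging
import os

logger = logging.getLogger(__name__)


def try_import_tf():
    # Tests can force the "no TensorFlow" code path via an env var.
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes")
        return None

    try:
        import tensorflow as tf
        return tf
    except ImportError:
        return None


# Call sites replace `import tensorflow as tf` with:
tf = try_import_tf()  # may be None; guard TF-only code with `if tf:`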

View file

@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_dependency.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
/ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py

View file

@ -5,7 +5,9 @@ from __future__ import print_function
from collections import deque, OrderedDict
import numpy as np
import tensorflow as tf
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def unflatten(vector, shapes):

View file

@ -4,7 +4,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import gym
import ray
@ -19,6 +18,9 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class A3CLoss(object):

View file

@ -7,13 +7,15 @@ from __future__ import print_function
import gym
import numpy as np
import tensorflow as tf
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.utils.filter import get_filter
from ray.rllib.models import ModelCatalog
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0):

View file

@ -6,7 +6,9 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def compute_ranks(x):

View file

@ -4,8 +4,6 @@ from __future__ import print_function
from gym.spaces import Box
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
import ray.experimental.tf_utils
@ -18,6 +16,9 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
ACTION_SCOPE = "action"
POLICY_SCOPE = "policy"
@ -397,6 +398,8 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
self.set_pure_exploration_phase(state[2])
def _build_q_network(self, obs, obs_space, action_space, actions):
import tensorflow.contrib.layers as layers
if self.config["use_state_preprocessor"]:
q_model = ModelCatalog.get_model({
"obs": obs,
@ -417,6 +420,8 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
return q_values, q_model
def _build_policy_network(self, obs, obs_space, action_space):
import tensorflow.contrib.layers as layers
if self.config["use_state_preprocessor"]:
model = ModelCatalog.get_model({
"obs": obs,

View file

@ -5,8 +5,6 @@ from __future__ import print_function
from gym.spaces import Discrete
import numpy as np
from scipy.stats import entropy
import tensorflow as tf
import tensorflow.contrib.layers as layers
import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
@ -17,6 +15,9 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
Q_SCOPE = "q_func"
Q_TARGET_SCOPE = "target_q_func"
@ -153,6 +154,8 @@ class QNetwork(object):
v_max=10.0,
sigma0=0.5,
parameter_noise=False):
import tensorflow.contrib.layers as layers
self.model = model
with tf.variable_scope("action_value"):
if hiddens:
@ -263,6 +266,8 @@ class QNetwork(object):
distributions and \sigma are trainable variables which are expected to
vanish along the training procedure
"""
import tensorflow.contrib.layers as layers
in_size = int(action_in.shape[1])
epsilon_in = tf.random_normal(shape=[in_size])

View file

@ -7,13 +7,15 @@ from __future__ import print_function
import gym
import numpy as np
import tensorflow as tf
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def rollout(policy, env, timestep_limit=None, add_noise=False):

View file

@ -6,7 +6,9 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def compute_ranks(x):

View file

@ -34,9 +34,11 @@ from __future__ import print_function
import collections
import tensorflow as tf
from ray.rllib.utils import try_import_tf
nest = tf.contrib.framework.nest
tf = try_import_tf()
if tf:
    nest = tf.contrib.framework.nest
VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [
"vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs",

View file

@ -9,7 +9,6 @@ from __future__ import print_function
import gym
import ray
import numpy as np
import tensorflow as tf
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
@ -21,6 +20,9 @@ from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# Frozen logits of the policy that computed the action
BEHAVIOUR_LOGITS = "behaviour_logits"

View file

@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
@ -15,6 +13,9 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
POLICY_SCOPE = "p_func"
VALUE_SCOPE = "v_func"

View file

@ -2,8 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import ray
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
@ -12,6 +10,9 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class PGLoss(object):

View file

@ -7,7 +7,6 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import logging
import gym
@ -23,6 +22,9 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -3,7 +3,6 @@ from __future__ import division
from __future__ import print_function
import logging
import tensorflow as tf
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
@ -16,6 +15,9 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -10,7 +10,6 @@ import pickle
import six
import time
import tempfile
import tensorflow as tf
from types import FunctionType
import ray
@ -26,12 +25,15 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import try_import_tf
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import Resources, ExportFormat
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
tf = try_import_tf()
logger = logging.getLogger(__name__)
# Max number of times to retry a worker failure. We shouldn't try too many
@ -412,8 +414,13 @@ class Trainer(Trainable):
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
# TODO(ekl) setting the graph is unnecessary for PyTorch agents
with tf.Graph().as_default():
def get_scope():
    if tf:
        return tf.Graph().as_default()
    else:
        return open("/dev/null")  # fake a no-op scope

with get_scope():
    self._init(self.config, self.env_creator)
# Evaluation related
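
A note on the get_scope() workaround above: `with` only needs an object implementing the context-manager protocol, and the open("/dev/null") file handle satisfies that, so the branch degrades to a no-op when TensorFlow is unavailable. On Python 3.7+ the same effect could be had with contextlib.nullcontext(); this is a hypothetical alternative, not what the commit uses, and the file-handle trick presumably keeps Python 2 compatibility:

# Hypothetical Python 3.7+ variant of get_scope(); the commit itself uses
# open("/dev/null") as the no-op context manager.
import contextlib

from ray.rllib.utils import try_import_tf

tf = try_import_tf()


def get_scope():
    if tf:
        return tf.Graph().as_default()
    return contextlib.nullcontext()  # built-in no-op context manager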

View file

@ -59,18 +59,23 @@ def collect_episodes(local_evaluator=None,
timeout_seconds=180):
"""Gathers new episodes metrics tuples from the given evaluators."""
if remote_evaluators:
pending = [
a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators
a.apply.remote(lambda ev: ev.get_metrics())
for a in remote_evaluators
]
collected, _ = ray.wait(
pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
num_metric_batches_dropped = len(pending) - len(collected)
if pending and len(collected) == 0:
raise ValueError(
"Timed out waiting for metrics from workers. You can configure "
"this timeout with `collect_metrics_timeout`.")
"Timed out waiting for metrics from workers. You can "
"configure this timeout with `collect_metrics_timeout`.")
metric_lists = ray_get_and_free(collected)
else:
metric_lists = []
num_metric_batches_dropped = 0
if local_evaluator:
metric_lists.append(local_evaluator.get_metrics())
episodes = []

View file

@ -5,7 +5,6 @@ from __future__ import print_function
import gym
import logging
import pickle
import tensorflow as tf
import ray
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
@ -32,7 +31,9 @@ from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
summarize, enable_periodic_logging
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
# Handle to the current evaluator, which will be set to the most recently
@ -722,8 +723,11 @@ class PolicyEvaluator(EvaluatorInterface):
"Found raw Tuple|Dict space as input to policy graph. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
if tf:
    with tf.variable_scope(name):
        policy_map[name] = cls(obs_space, act_space, merged_conf)
else:
    policy_map[name] = cls(obs_space, act_space, merged_conf)
if self.worker_index == 0:
logger.info("Built policy map: {}".format(policy_map))
logger.info("Built preprocessor map: {}".format(preprocessors))

View file

@ -5,7 +5,6 @@ from __future__ import print_function
import os
import errno
import logging
import tensorflow as tf
import numpy as np
import ray
@ -18,7 +17,9 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -4,13 +4,18 @@ from __future__ import print_function
from collections import namedtuple
import distutils.version
import tensorflow as tf
import numpy as np
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils import try_import_tf
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
tf = try_import_tf()
if tf:
    use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
                     distutils.version.LooseVersion("1.5.0"))
else:
    use_tf150_api = False
@DeveloperAPI

View file

@ -5,7 +5,6 @@ from __future__ import print_function
import gym
import logging
import numpy as np
import tensorflow as tf
from functools import partial
from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
@ -22,6 +21,9 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
from ray.rllib.models.lstm import LSTM
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -2,12 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.slim as slim
from ray.rllib.models.model import Model
from ray.rllib.models.misc import normc_initializer, get_activation_fn
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class FullyConnectedNetwork(Model):
@ -21,6 +21,8 @@ class FullyConnectedNetwork(Model):
model that processes the components separately, use _build_layers_v2().
"""
import tensorflow.contrib.slim as slim
hiddens = options.get("fcnet_hiddens")
activation = get_activation_fn(options.get("fcnet_activation"))

View file

@ -18,12 +18,13 @@ more info.
"""
import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.model import Model
from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class LSTM(Model):
@ -37,6 +38,8 @@ class LSTM(Model):
@override(Model)
def _build_layers_v2(self, input_dict, num_outputs, options):
import tensorflow.contrib.rnn as rnn
cell_size = options.get("lstm_cell_size")
if options.get("lstm_use_prev_action_reward"):
action_dim = int(

View file

@ -2,8 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def normc_initializer(std=1.0):
@ -25,8 +27,11 @@ def conv2d(x,
filter_size=(3, 3),
stride=(1, 1),
pad="SAME",
dtype=tf.float32,
dtype=None,
collections=None):
if dtype is None:
    dtype = tf.float32
with tf.variable_scope(name):
stride_shape = [1, stride[0], stride[1], 1]
filter_shape = [

View file

@ -5,11 +5,13 @@ from __future__ import print_function
from collections import OrderedDict
import gym
import tensorflow as tf
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@PublicAPI

View file

@ -2,12 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.slim as slim
from ray.rllib.models.model import Model
from ray.rllib.models.misc import get_activation_fn, flatten
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class VisionNetwork(Model):
@ -15,6 +15,8 @@ class VisionNetwork(Model):
@override(Model)
def _build_layers_v2(self, input_dict, num_outputs, options):
import tensorflow.contrib.slim as slim
inputs = input_dict["obs"]
filters = options.get("conv_filters")
if not filters:

View file

@ -4,11 +4,13 @@ from __future__ import print_function
import logging
import numpy as np
import tensorflow as tf
import threading
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -4,9 +4,11 @@ from __future__ import print_function
from collections import namedtuple
import logging
import tensorflow as tf
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
# Variable scope in which created variables will be placed under
TOWER_SCOPE_NAME = "tower"

View file

@ -6,7 +6,6 @@ import logging
import math
import numpy as np
from collections import defaultdict
import tensorflow as tf
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
@ -19,6 +18,9 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"
if __name__ == "__main__":
from ray.rllib.agents.a3c import A2CTrainer
assert "tensorflow" not in sys.modules, "TF initially present"
# note: no ray.init(), to test it works without Ray
trainer = A2CTrainer(
env="CartPole-v0", config={
"use_pytorch": True,
"num_workers": 0
})
trainer.train()
assert "tensorflow" not in sys.modules, "TF should not be imported"

View file

@ -1,4 +1,5 @@
import logging
import os
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.filter import Filter
@ -26,6 +27,18 @@ def renamed_class(cls):
return DeprecationWrapper
def try_import_tf():
if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
logger.warning("Not importing TensorFlow for test purposes")
return None
try:
import tensorflow as tf
return tf
except ImportError:
return None
__all__ = [
"Filter",
"FilterManager",
@ -34,4 +47,5 @@ __all__ = [
"merge_dicts",
"deep_update",
"renamed_class",
"try_import_tf",
]
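
For reference, the calling convention the rest of the diff follows; a representative sketch modeled on the vtrace.py hunk above, not any single file's verbatim code:

# Sketch: how call sites swap the hard TF import for the optional one.
from ray.rllib.utils import try_import_tf

tf = try_import_tf()  # None when TF is absent or RLLIB_TEST_NO_TF_IMPORT is set

if tf:
    # Module-level TF-only setup must be guarded, e.g. vtrace's nest alias.
    nest = tf.contrib.framework.nest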

View file

@ -2,7 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def explained_variance(y, pred):

View file

@ -4,7 +4,9 @@ from __future__ import print_function
import numpy as np
import random
import tensorflow as tf
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
def seed(np_seed=0, random_seed=0, tf_seed=0):

View file

@ -6,11 +6,10 @@ import logging
import os
import time
import tensorflow as tf
from tensorflow.python.client import timeline
from ray.rllib.utils.debug import log_once
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
@ -65,6 +64,8 @@ _count = 0
def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None):
    if timeline_dir:
        from tensorflow.python.client import timeline

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        start = time.time()

View file

@ -118,6 +118,10 @@ class TFLogger(Logger):
def _init(self):
    try:
        global tf, use_tf150_api
        if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
            logger.warning("Not importing TensorFlow for test purposes")
            tf = None
        else:
            import tensorflow
            tf = tensorflow
            use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=