Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[rllib] Remove dependency on TensorFlow (#4764)
* remove hard tf dep
* add test
* comment fix
* fix test
This commit is contained in:
parent ccc540adf1
commit 351753aae5
35 changed files with 189 additions and 63 deletions
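
Rather than importing TensorFlow unconditionally, every rllib module now obtains it through a `try_import_tf()` helper (added to `ray/rllib/utils`, see the diff below) and guards any module-level use of `tf`, so rllib can be imported and run with PyTorch-only agents on installs without TensorFlow. A condensed sketch of the idiom the diff applies file by file (the `float_type` binding is an illustrative stand-in, not from the diff):

    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()  # the real module, or None if TF is unavailable

    if tf:
        # Module-level attribute access must be guarded, since tf may be None.
        float_type = tf.float32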
@@ -289,6 +289,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_dependency.py
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py
 
@@ -5,7 +5,9 @@ from __future__ import print_function
 
 from collections import deque, OrderedDict
 import numpy as np
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def unflatten(vector, shapes):
@@ -4,7 +4,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
 import gym
 
 import ray
@@ -19,6 +18,9 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 class A3CLoss(object):
@@ -7,13 +7,15 @@ from __future__ import print_function
 
 import gym
 import numpy as np
-import tensorflow as tf
 
 import ray
 import ray.experimental.tf_utils
 from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.models import ModelCatalog
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0):
@@ -6,7 +6,9 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def compute_ranks(x):
@@ -4,8 +4,6 @@ from __future__ import print_function
 
 from gym.spaces import Box
 import numpy as np
-import tensorflow as tf
-import tensorflow.contrib.layers as layers
 
 import ray
 import ray.experimental.tf_utils
@@ -18,6 +16,9 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 ACTION_SCOPE = "action"
 POLICY_SCOPE = "policy"
@@ -397,6 +398,8 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
         self.set_pure_exploration_phase(state[2])
 
     def _build_q_network(self, obs, obs_space, action_space, actions):
+        import tensorflow.contrib.layers as layers
+
         if self.config["use_state_preprocessor"]:
             q_model = ModelCatalog.get_model({
                 "obs": obs,
@@ -417,6 +420,8 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
         return q_values, q_model
 
     def _build_policy_network(self, obs, obs_space, action_space):
+        import tensorflow.contrib.layers as layers
+
         if self.config["use_state_preprocessor"]:
             model = ModelCatalog.get_model({
                 "obs": obs,
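
Note that the `tensorflow.contrib.layers` import moves from module scope into `_build_q_network` and `_build_policy_network`: a top-level `import tensorflow.contrib.layers` would fail outright when TensorFlow is absent, whereas a deferred import only runs once a network is actually built. The same deferral is applied to `contrib.slim` and `contrib.rnn` in the model files further down. A minimal sketch of the pattern (the class and method names here are illustrative, not from the diff):

    class QFunctionSketch(object):
        def build(self, obs):
            # Deferred import: executed only when a TF graph is being built,
            # so importing this module does not itself require TensorFlow.
            import tensorflow.contrib.layers as layers
            return layers.fully_connected(obs, num_outputs=1)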
@@ -5,8 +5,6 @@ from __future__ import print_function
 from gym.spaces import Discrete
 import numpy as np
 from scipy.stats import entropy
-import tensorflow as tf
-import tensorflow.contrib.layers as layers
 
 import ray
 from ray.rllib.evaluation.sample_batch import SampleBatch
@@ -17,6 +15,9 @@ from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
     LearningRateSchedule
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 Q_SCOPE = "q_func"
 Q_TARGET_SCOPE = "target_q_func"
@@ -153,6 +154,8 @@ class QNetwork(object):
                  v_max=10.0,
                  sigma0=0.5,
                  parameter_noise=False):
+        import tensorflow.contrib.layers as layers
+
         self.model = model
         with tf.variable_scope("action_value"):
             if hiddens:
@@ -263,6 +266,8 @@ class QNetwork(object):
         distributions and \sigma are trainable variables which are expected to
         vanish along the training procedure
         """
+        import tensorflow.contrib.layers as layers
+
         in_size = int(action_in.shape[1])
 
         epsilon_in = tf.random_normal(shape=[in_size])
@@ -7,13 +7,15 @@ from __future__ import print_function
 
 import gym
 import numpy as np
-import tensorflow as tf
 
 import ray
 import ray.experimental.tf_utils
 from ray.rllib.evaluation.sampler import _unbatch_tuple_actions
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.filter import get_filter
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def rollout(policy, env, timestep_limit=None, add_noise=False):
@@ -6,7 +6,9 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def compute_ranks(x):
@@ -34,9 +34,11 @@ from __future__ import print_function
 
 import collections
 
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
 
-nest = tf.contrib.framework.nest
+tf = try_import_tf()
+if tf:
+    nest = tf.contrib.framework.nest
 
 VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [
     "vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs",
@@ -9,7 +9,6 @@ from __future__ import print_function
 import gym
 import ray
 import numpy as np
-import tensorflow as tf
 from ray.rllib.agents.impala import vtrace
 from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
 from ray.rllib.evaluation.policy_graph import PolicyGraph
@@ -21,6 +20,9 @@ from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.explained_variance import explained_variance
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 # Frozen logits of the policy that computed the action
 BEHAVIOUR_LOGITS = "behaviour_logits"
@@ -2,8 +2,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
-
 import ray
 from ray.rllib.models import ModelCatalog
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -15,6 +13,9 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
 from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars
 from ray.rllib.utils.explained_variance import explained_variance
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 POLICY_SCOPE = "p_func"
 VALUE_SCOPE = "v_func"
@@ -2,8 +2,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
-
 import ray
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -12,6 +10,9 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.sample_batch import SampleBatch
 from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 class PGLoss(object):
@@ -7,7 +7,6 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
-import tensorflow as tf
 import logging
 import gym
 
@@ -23,6 +22,9 @@ from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.explained_variance import explained_variance
 from ray.rllib.models.action_dist import MultiCategorical
 from ray.rllib.evaluation.postprocessing import compute_advantages
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 logger = logging.getLogger(__name__)
 
@@ -3,7 +3,6 @@ from __future__ import division
 from __future__ import print_function
 
 import logging
-import tensorflow as tf
 
 import ray
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
@@ -16,6 +15,9 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.explained_variance import explained_variance
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 logger = logging.getLogger(__name__)
 
@@ -10,7 +10,6 @@ import pickle
 import six
 import time
 import tempfile
-import tensorflow as tf
 from types import FunctionType
 
 import ray
@@ -26,12 +25,15 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
 from ray.rllib.utils import FilterManager, deep_update, merge_dicts
 from ray.rllib.utils.memory import ray_get_and_free
+from ray.rllib.utils import try_import_tf
 from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
 from ray.tune.trainable import Trainable
 from ray.tune.trial import Resources, ExportFormat
 from ray.tune.logger import UnifiedLogger
 from ray.tune.result import DEFAULT_RESULTS_DIR
 
+tf = try_import_tf()
+
 logger = logging.getLogger(__name__)
 
 # Max number of times to retry a worker failure. We shouldn't try too many
@@ -412,8 +414,13 @@ class Trainer(Trainable):
         if self.config.get("log_level"):
             logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
 
-        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
-        with tf.Graph().as_default():
+        def get_scope():
+            if tf:
+                return tf.Graph().as_default()
+            else:
+                return open("/dev/null")  # fake a no-op scope
+
+        with get_scope():
             self._init(self.config, self.env_creator)
 
         # Evaluation related
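
The `get_scope()` helper keeps Trainer setup working without TensorFlow: a file object is itself a context manager (entering returns the file, exiting closes it), so `open("/dev/null")` serves as a do-nothing stand-in for `tf.Graph().as_default()`. On Python 3.7+ the idiomatic no-op would be `contextlib.nullcontext()`, but rllib still supported Python 2 here; a standalone sketch of the equivalent:

    import contextlib

    tf = None  # simulate TensorFlow being unavailable

    def get_scope():
        if tf:
            return tf.Graph().as_default()
        # Any do-nothing context manager works; the commit uses
        # open("/dev/null"), which newer code could write as:
        return contextlib.nullcontext()

    with get_scope():
        pass  # initialize policies without touching a TF graph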
@@ -59,18 +59,23 @@ def collect_episodes(local_evaluator=None,
                      timeout_seconds=180):
     """Gathers new episodes metrics tuples from the given evaluators."""
 
-    pending = [
-        a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators
-    ]
-    collected, _ = ray.wait(
-        pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
-    num_metric_batches_dropped = len(pending) - len(collected)
-    if pending and len(collected) == 0:
-        raise ValueError(
-            "Timed out waiting for metrics from workers. You can configure "
-            "this timeout with `collect_metrics_timeout`.")
+    if remote_evaluators:
+        pending = [
+            a.apply.remote(lambda ev: ev.get_metrics())
+            for a in remote_evaluators
+        ]
+        collected, _ = ray.wait(
+            pending, num_returns=len(pending), timeout=timeout_seconds * 1.0)
+        num_metric_batches_dropped = len(pending) - len(collected)
+        if pending and len(collected) == 0:
+            raise ValueError(
+                "Timed out waiting for metrics from workers. You can "
+                "configure this timeout with `collect_metrics_timeout`.")
+        metric_lists = ray_get_and_free(collected)
+    else:
+        metric_lists = []
+        num_metric_batches_dropped = 0
 
-    metric_lists = ray_get_and_free(collected)
     if local_evaluator:
         metric_lists.append(local_evaluator.get_metrics())
     episodes = []
@@ -5,7 +5,6 @@ from __future__ import print_function
 import gym
 import logging
 import pickle
-import tensorflow as tf
 
 import ray
 from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
@@ -32,7 +31,9 @@ from ray.rllib.utils.debug import disable_log_once_globally, log_once, \
     summarize, enable_periodic_logging
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.utils import try_import_tf
 
+tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
 # Handle to the current evaluator, which will be set to the most recently
@@ -722,8 +723,11 @@ class PolicyEvaluator(EvaluatorInterface):
                         "Found raw Tuple|Dict space as input to policy graph. "
                         "Please preprocess these observations with a "
                         "Tuple|DictFlatteningPreprocessor.")
-                with tf.variable_scope(name):
+                if tf:
+                    with tf.variable_scope(name):
+                        policy_map[name] = cls(obs_space, act_space, merged_conf)
+                else:
                     policy_map[name] = cls(obs_space, act_space, merged_conf)
         if self.worker_index == 0:
             logger.info("Built policy map: {}".format(policy_map))
             logger.info("Built preprocessor map: {}".format(preprocessors))
@@ -5,7 +5,6 @@ from __future__ import print_function
 import os
 import errno
 import logging
-import tensorflow as tf
 import numpy as np
 
 import ray
@@ -18,7 +17,9 @@ from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.debug import log_once, summarize
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
+from ray.rllib.utils import try_import_tf
 
+tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
 
@@ -4,13 +4,18 @@ from __future__ import print_function
 
 from collections import namedtuple
 import distutils.version
-import tensorflow as tf
 import numpy as np
 
 from ray.rllib.utils.annotations import override, DeveloperAPI
+from ray.rllib.utils import try_import_tf
 
-use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
-                 distutils.version.LooseVersion("1.5.0"))
+tf = try_import_tf()
+
+if tf:
+    use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
+                     distutils.version.LooseVersion("1.5.0"))
+else:
+    use_tf150_api = False
 
 
 @DeveloperAPI
@@ -5,7 +5,6 @@ from __future__ import print_function
 import gym
 import logging
 import numpy as np
-import tensorflow as tf
 from functools import partial
 
 from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
@@ -22,6 +21,9 @@ from ray.rllib.models.fcnet import FullyConnectedNetwork
 from ray.rllib.models.visionnet import VisionNetwork
 from ray.rllib.models.lstm import LSTM
 from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 logger = logging.getLogger(__name__)
 
@@ -2,12 +2,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
 from ray.rllib.models.model import Model
 from ray.rllib.models.misc import normc_initializer, get_activation_fn
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 class FullyConnectedNetwork(Model):
@@ -21,6 +21,8 @@ class FullyConnectedNetwork(Model):
         model that processes the components separately, use _build_layers_v2().
         """
 
+        import tensorflow.contrib.slim as slim
+
         hiddens = options.get("fcnet_hiddens")
         activation = get_activation_fn(options.get("fcnet_activation"))
 
@@ -18,12 +18,13 @@ more info.
 """
 
 import numpy as np
-import tensorflow as tf
-import tensorflow.contrib.rnn as rnn
 
 from ray.rllib.models.misc import linear, normc_initializer
 from ray.rllib.models.model import Model
 from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 class LSTM(Model):
@@ -37,6 +38,8 @@ class LSTM(Model):
 
     @override(Model)
     def _build_layers_v2(self, input_dict, num_outputs, options):
+        import tensorflow.contrib.rnn as rnn
+
         cell_size = options.get("lstm_cell_size")
         if options.get("lstm_use_prev_action_reward"):
             action_dim = int(
@@ -2,8 +2,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
 import numpy as np
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def normc_initializer(std=1.0):
@@ -25,8 +27,11 @@ def conv2d(x,
            filter_size=(3, 3),
            stride=(1, 1),
            pad="SAME",
-           dtype=tf.float32,
+           dtype=None,
            collections=None):
+    if dtype is None:
+        dtype = tf.float32
+
     with tf.variable_scope(name):
         stride_shape = [1, stride[0], stride[1], 1]
         filter_shape = [
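
The `conv2d` signature change is subtle but necessary: default argument values are evaluated once, when the `def` statement runs at import time, so `dtype=tf.float32` would raise an `AttributeError` on a TF-less install even if `conv2d` were never called. Using `dtype=None` and resolving the real default inside the body defers the `tf.float32` lookup to call time. The general shape of the fix (the `cast_to` helper is illustrative, not from the diff):

    def cast_to(x, dtype=None):
        # Resolve the optional-module default at call time, not import time.
        if dtype is None:
            dtype = tf.float32
        return tf.cast(x, dtype)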
@@ -5,11 +5,13 @@ from __future__ import print_function
 from collections import OrderedDict
 
 import gym
-import tensorflow as tf
 
 from ray.rllib.models.misc import linear, normc_initializer
 from ray.rllib.models.preprocessors import get_preprocessor
 from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 @PublicAPI
@@ -2,12 +2,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
-
 from ray.rllib.models.model import Model
 from ray.rllib.models.misc import get_activation_fn, flatten
 from ray.rllib.utils.annotations import override
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 class VisionNetwork(Model):
@@ -15,6 +15,8 @@ class VisionNetwork(Model):
 
     @override(Model)
     def _build_layers_v2(self, input_dict, num_outputs, options):
+        import tensorflow.contrib.slim as slim
+
         inputs = input_dict["obs"]
         filters = options.get("conv_filters")
         if not filters:
@@ -4,11 +4,13 @@ from __future__ import print_function
 
 import logging
 import numpy as np
-import tensorflow as tf
 import threading
 
 from ray.rllib.evaluation.sample_batch import MultiAgentBatch
 from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils import try_import_tf
 
+tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
@@ -4,9 +4,11 @@ from __future__ import print_function
 
 from collections import namedtuple
 import logging
-import tensorflow as tf
 
 from ray.rllib.utils.debug import log_once, summarize
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 # Variable scope in which created variables will be placed under
 TOWER_SCOPE_NAME = "tower"
@@ -6,7 +6,6 @@ import logging
 import math
 import numpy as np
 from collections import defaultdict
-import tensorflow as tf
 
 import ray
 from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
@@ -19,6 +18,9 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.timer import TimerStat
 from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
     MultiAgentBatch
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 logger = logging.getLogger(__name__)
 
python/ray/rllib/tests/test_dependency.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"
+
+if __name__ == "__main__":
+    from ray.rllib.agents.a3c import A2CTrainer
+    assert "tensorflow" not in sys.modules, "TF initially present"
+
+    # note: no ray.init(), to test it works without Ray
+    trainer = A2CTrainer(
+        env="CartPole-v0", config={
+            "use_pytorch": True,
+            "num_workers": 0
+        })
+    trainer.train()
+
+    assert "tensorflow" not in sys.modules, "TF should not be imported"
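
The test flips `RLLIB_TEST_NO_TF_IMPORT` before anything from rllib is imported, so every `try_import_tf()` call returns `None`, then trains a PyTorch A2C agent end to end and asserts that `tensorflow` never shows up in `sys.modules`. CI runs it directly, matching the docker command added at the top of this diff:

    python /ray/python/ray/rllib/tests/test_dependency.py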
@@ -1,4 +1,5 @@
 import logging
+import os
 
 from ray.rllib.utils.filter_manager import FilterManager
 from ray.rllib.utils.filter import Filter
@@ -26,6 +27,18 @@ def renamed_class(cls):
     return DeprecationWrapper
 
 
+def try_import_tf():
+    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
+        logger.warning("Not importing TensorFlow for test purposes")
+        return None
+
+    try:
+        import tensorflow as tf
+        return tf
+    except ImportError:
+        return None
+
+
 __all__ = [
     "Filter",
     "FilterManager",
@@ -34,4 +47,5 @@ __all__ = [
     "merge_dicts",
     "deep_update",
     "renamed_class",
+    "try_import_tf",
 ]
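
`try_import_tf()` returns `None` instead of raising, so modules import cleanly on TF-less installs and only fail if TensorFlow functionality is actually exercised; the `RLLIB_TEST_NO_TF_IMPORT` variable lets CI simulate a missing TensorFlow even when it is installed. Since the environment is checked on every call, the behavior can be verified interactively (assuming a Python session where TensorFlow is installed):

    import os
    from ray.rllib.utils import try_import_tf

    os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1"
    assert try_import_tf() is None  # helper honors the test flag

    del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    assert try_import_tf() is not None  # real import succeeds again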
@@ -2,7 +2,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def explained_variance(y, pred):
@@ -4,7 +4,9 @@ from __future__ import print_function
 
 import numpy as np
 import random
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def seed(np_seed=0, random_seed=0, tf_seed=0):
@@ -6,11 +6,10 @@ import logging
 import os
 import time
 
-import tensorflow as tf
-from tensorflow.python.client import timeline
-
 from ray.rllib.utils.debug import log_once
+from ray.rllib.utils import try_import_tf
 
+tf = try_import_tf()
 logger = logging.getLogger(__name__)
 
 
@@ -65,6 +64,8 @@ _count = 0
 
 def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None):
     if timeline_dir:
+        from tensorflow.python.client import timeline
+
         run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
         run_metadata = tf.RunMetadata()
         start = time.time()
@@ -118,6 +118,10 @@ class TFLogger(Logger):
     def _init(self):
         try:
             global tf, use_tf150_api
-            import tensorflow
-            tf = tensorflow
+            if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
+                logger.warning("Not importing TensorFlow for test purposes")
+                tf = None
+            else:
+                import tensorflow
+                tf = tensorflow
             use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=