import logging
import os
import sys
from typing import Any, Optional

import numpy as np

from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType

logger = logging.getLogger(__name__)

# Represents a generic tensor type.
TensorType = TensorType

# Either a plain tensor, or a dict or tuple of tensors (or StructTensors).
TensorStructType = TensorStructType

def try_import_tf(error=False):
    """Tries importing tf and returns the module (or None).

    Args:
        error (bool): Whether to raise an error if tf cannot be imported.

    Returns:
        Tuple:
        - tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
        - tf module (resulting from `import tensorflow`). Either tf1.x or
          2.x.
        - The actually installed tf version as int: 1 or 2.

    Raises:
        ImportError: If error=True and tf is not installed.
    """
    # Make sure these are reset after each test case that uses them:
    # del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow for test purposes.")
        return None, None, None

    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Try to reuse an already imported tf module. This avoids going through
    # the initial import steps below and thereby switching off v2_behavior
    # a second time (switching off v2 behavior twice breaks all-framework
    # tests for eager).
    was_imported = False
    if "tensorflow" in sys.modules:
        tf_module = sys.modules["tensorflow"]
        was_imported = True
    else:
        try:
            import tensorflow as tf_module
        except ImportError as e:
            if error:
                raise e
            return None, None, None

    # Try "reducing" tf to tf.compat.v1.
    try:
        tf1_module = tf_module.compat.v1
        if not was_imported:
            tf1_module.disable_v2_behavior()
    # No compat.v1 -> return tf as-is.
    except AttributeError:
        tf1_module = tf_module

    if not hasattr(tf_module, "__version__"):
        version = 1  # sphinx doc gen
    else:
        version = 2 if "2." in tf_module.__version__[:2] else 1

    return tf1_module, tf_module, version

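# Illustrative usage sketch (not part of the RLlib API): shows how the
# (tf1, tf, version) tuple returned by `try_import_tf()` is typically
# unpacked and guarded. The name `_example_try_import_tf` is hypothetical
# and only included for documentation purposes.
def _example_try_import_tf():
    """Minimal sketch: unpack and guard the try_import_tf() return tuple."""
    tf1, tf, tfv = try_import_tf()
    if tf is None:
        # TensorFlow is not installed (or was skipped for tests).
        return None
    # `tfv` is 1 or 2; `tf1` always exposes the 1.x-style API
    # (tf.compat.v1 when running under TF 2.x).
    logger.info("Found TensorFlow %s (major version %s).",
                getattr(tf, "__version__", "unknown"), tfv)
    return tf1, tf, tfv
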
def tf_function(tf_module):
    """Conditional decorator for @tf.function.

    Use @tf_function(tf) instead to avoid errors if tf is not installed.
    """

    # The actual decorator to use (pass in `tf`, which could be None).
    def decorator(func):
        # If tf is not installed or we are executing eagerly -> return the
        # function as-is (no tracing/compilation needed).
        if tf_module is None or tf_module.executing_eagerly():
            return func
        # If tf is installed, return the @tf.function-decorated function.
        return tf_module.function(func)

    return decorator

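# Illustrative usage sketch (not part of the RLlib API): `tf_function` is
# meant to replace a bare @tf.function so that code importing this module
# still loads when TensorFlow is missing. The name `_example_tf_function`
# is hypothetical.
def _example_tf_function():
    """Minimal sketch: conditionally compile a function with @tf.function."""
    _, tf, _ = try_import_tf()

    @tf_function(tf)
    def add(a, b):
        return a + b

    # If tf is installed and not executing eagerly, `add` is now wrapped in
    # tf.function; otherwise it was returned unchanged.
    return add
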
def try_import_tfp(error=False):
    """Tries importing tfp and returns the module (or None).

    Args:
        error (bool): Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module.

    Raises:
        ImportError: If error=True and tfp is not installed.
    """
    if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
        logger.warning("Not importing TensorFlow Probability for test "
                       "purposes.")
        return None

    try:
        import tensorflow_probability as tfp
        return tfp
    except ImportError as e:
        if error:
            raise e
        return None

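# Illustrative usage sketch (not part of the RLlib API), assuming
# tensorflow_probability may or may not be installed. The name
# `_example_try_import_tfp` is hypothetical.
def _example_try_import_tfp():
    """Minimal sketch: guard code paths that need tensorflow_probability."""
    tfp = try_import_tfp()
    if tfp is None:
        return None
    # Only touch tfp symbols behind the None-check.
    return tfp.distributions.Normal(loc=0.0, scale=1.0)
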
# Fake module for torch.nn.
class NNStub:
    def __init__(self, *a, **kw):
        # Fake nn.functional module within torch.nn.
        self.functional = None
        self.Module = ModuleStub


# Fake class for torch.nn.Module to allow it to be inherited from.
class ModuleStub:
    def __init__(self, *a, **kw):
        raise ImportError("Could not import `torch`.")

def try_import_torch(error=False):
    """Tries importing torch and returns the module (or None).

    Args:
        error (bool): Whether to raise an error if torch cannot be imported.

    Returns:
        tuple: torch AND torch.nn modules (or the stubs from `_torch_stubs()`
            if torch is not installed).

    Raises:
        ImportError: If error=True and PyTorch is not installed.
    """
    if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
        logger.warning("Not importing PyTorch for test purposes.")
        return _torch_stubs()

    try:
        import torch
        import torch.nn as nn
        return torch, nn
    except ImportError as e:
        if error:
            raise e
        return _torch_stubs()


def _torch_stubs():
    nn = NNStub()
    return None, nn

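# Illustrative usage sketch (not part of the RLlib API): when torch is
# missing, `try_import_torch()` returns (None, NNStub()), so torch-dependent
# code should check the first element before use. The name
# `_example_try_import_torch` is hypothetical.
def _example_try_import_torch():
    """Minimal sketch: guard torch-dependent code with the (torch, nn) pair."""
    torch, nn = try_import_torch()
    if torch is None:
        # `nn` is an NNStub here; instantiating nn.Module would raise
        # ImportError("Could not import `torch`.").
        return None
    return torch.zeros(3), nn.ReLU()
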
def get_variable(value,
                 framework: str = "tf",
                 trainable: bool = False,
                 tf_name: str = "unnamed-variable",
                 torch_tensor: bool = False,
                 device: Optional[str] = None,
                 shape: Optional[TensorShape] = None,
                 dtype: Optional[Any] = None):
    """Creates a framework-specific variable from an initial value.

    Args:
        value (any): The initial value to use. In the non-tf case, this will
            be returned as-is. In the tf case, this could be a tf-Initializer
            object.
        framework (str): One of "tf", "torch", or None.
        trainable (bool): Whether the generated variable should be
            trainable (tf) / require_grad (torch) or not (default: False).
        tf_name (str): For framework="tf": An optional name for the
            tf.Variable.
        torch_tensor (bool): For framework="torch": Whether to actually create
            a torch.tensor, or just a python value (default).
        device (Optional[torch.Device]): An optional torch device to use for
            the created torch tensor.
        shape (Optional[TensorShape]): An optional shape to use iff `value`
            does not have any (e.g. if it's an initializer w/o explicit
            value).
        dtype (Optional[TensorType]): An optional dtype to use iff `value`
            does not have any (e.g. if it's an initializer w/o explicit
            value). This should always be a numpy dtype (e.g. np.float32,
            np.int64).

    Returns:
        any: A framework-specific variable (tf.Variable, torch.tensor, or
            python primitive).
    """
    if framework in ["tf2", "tf", "tfe"]:
        import tensorflow as tf
        dtype = dtype or getattr(
            value, "dtype", tf.float32
            if isinstance(value, float) else tf.int32
            if isinstance(value, int) else None)
        return tf.compat.v1.get_variable(
            tf_name,
            initializer=value,
            dtype=dtype,
            trainable=trainable,
            **({} if shape is None else {
                "shape": shape
            }))
    elif framework == "torch" and torch_tensor is True:
        torch, _ = try_import_torch()
        var_ = torch.from_numpy(value)
        if dtype in [torch.float32, np.float32]:
            var_ = var_.float()
        elif dtype in [torch.int32, np.int32]:
            var_ = var_.int()
        elif dtype in [torch.float64, np.float64]:
            var_ = var_.double()

        if device:
            var_ = var_.to(device)
        var_.requires_grad = trainable
        return var_
    # torch or None: Return python primitive.
    return value

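# Illustrative usage sketch (not part of the RLlib API): create the same
# logical variable for the torch and the no-framework code path. The name
# `_example_get_variable` is hypothetical; the tf path is omitted here since
# it additionally requires a TF graph/variable scope.
def _example_get_variable():
    """Minimal sketch: torch tensor vs. plain python value."""
    init = np.array([1.0, 2.0, 3.0], dtype=np.float32)

    # framework=None (or "torch" with torch_tensor=False): `init` is
    # returned unchanged.
    plain = get_variable(init, framework=None)

    torch, _ = try_import_torch()
    if torch is None:
        return plain, None

    # framework="torch" with torch_tensor=True: a float32 torch tensor with
    # requires_grad taken from `trainable`.
    tensor = get_variable(
        init,
        framework="torch",
        torch_tensor=True,
        dtype=np.float32,
        trainable=True)
    return plain, tensor
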
# TODO: (sven) move to models/utils.py
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework-specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu", "tanh", "swish", "linear", or
            None (default: None).
        framework (str): One of "tf" or "torch".

    Returns:
        A framework-specific activation function, e.g. tf.nn.tanh or
            torch.nn.ReLU. None if name is "linear" or None.

    Raises:
        ValueError: If name is an unknown activation function.
    """
    if framework == "torch":
        if name in ["linear", None]:
            return None
        if name == "swish":
            from ray.rllib.utils.torch_ops import Swish
            return Swish
        _, nn = try_import_torch()
        if name == "relu":
            return nn.ReLU
        elif name == "tanh":
            return nn.Tanh
    else:
        if name in ["linear", None]:
            return None
        tf1, tf, tfv = try_import_tf()
        fn = getattr(tf.nn, name, None)
        if fn is not None:
            return fn

    raise ValueError("Unknown activation ({}) for framework={}!".format(
        name, framework))

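# Illustrative usage sketch (not part of the RLlib API): note the asymmetry
# in return types -- the torch branch returns an nn.Module *class* (to be
# instantiated), while the tf branch returns a ready-to-call function such
# as tf.nn.relu. The name `_example_get_activation_fn` is hypothetical.
def _example_get_activation_fn():
    """Minimal sketch: resolve and apply an activation in both frameworks."""
    outputs = {}
    torch, _ = try_import_torch()
    if torch is not None:
        # Torch branch returns a class such as nn.Tanh -> instantiate first.
        act_cls = get_activation_fn("tanh", framework="torch")
        outputs["torch"] = act_cls()(torch.zeros(4))
    _, tf, _ = try_import_tf()
    if tf is not None:
        # TF branch returns a plain callable such as tf.nn.relu.
        act_fn = get_activation_fn("relu", framework="tf")
        outputs["tf"] = act_fn(tf.zeros(4))
    return outputs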