2019-12-30 15:27:32 -05:00
|
|
|
import logging
|
|
|
|
import os
|
2020-03-04 13:00:37 -08:00
|
|
|
from typing import Any
|
2019-12-30 15:27:32 -05:00
|
|
|
|
|
|
|
# Module-level logger; used by the `try_import_*` functions in this module
# to warn when a framework import is skipped.
logger = logging.getLogger(__name__)

# Represents a generic tensor type.
# Plain alias of `typing.Any`, so it documents intent only.
TensorType = Any
|
|
|
|
|
2019-12-30 15:27:32 -05:00
|
|
|
|
2020-01-28 20:07:55 +01:00
|
|
|
def check_framework(framework="tf"):
    """Checks that the dependencies of the given framework are installed.

    Looks at the module-level `tf` and `torch` globals of this module, which
    are None if the respective package could not be imported.

    Args:
        framework (Optional[str]): One of "tf", "torch", or None.

    Returns:
        Optional[str]: The input framework string (or None).

    Raises:
        ImportError: If `framework` is "tf" or "torch" and the respective
            package could not be imported.
        AssertionError: If `framework` is none of "tf", "torch", or None.
    """
    if framework == "tf":
        if tf is None:
            raise ImportError("Could not import tensorflow.")
    elif framework == "torch":
        if torch is None:
            raise ImportError("Could not import torch.")
    elif framework is not None:
        # Raised explicitly (rather than via an `assert` statement) so that
        # the validation is not stripped when running under `python -O`.
        # The exception type stays AssertionError for backward compatibility.
        raise AssertionError(
            "`framework` must be one of 'tf', 'torch', or None, but is "
            "{!r}.".format(framework))
    return framework
|
|
|
|
|
|
|
|
|
|
|
|
def try_import_tf(error=False):
    """Attempts to import TensorFlow, preferring its `compat.v1` API.

    Args:
        error (bool): Whether to raise an error if tf cannot be imported.

    Returns:
        The tf module: `tensorflow.compat.v1` (with v2 behavior disabled)
        if importable, otherwise plain `tensorflow`. None if the
        `RLLIB_TEST_NO_TF_IMPORT` env variable is set, or if tf cannot be
        imported and `error` is False.

    Raises:
        ImportError: If tf cannot be imported and `error` is True.
    """
    # TODO(sven): Make sure, these are reset after each test case
    # that uses them.
    if os.environ.get("RLLIB_TEST_NO_TF_IMPORT") is not None:
        logger.warning("Not importing TensorFlow for test purposes")
        return None

    # Default TF's C++ min log level to "3", unless already set in the env.
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
    try:
        import tensorflow.compat.v1 as tf_v1
        tf_v1.logging.set_verbosity(tf_v1.logging.ERROR)
        tf_v1.disable_v2_behavior()
        return tf_v1
    except ImportError:
        # `compat.v1` is not importable: Fall back to the plain tf import.
        try:
            import tensorflow as plain_tf
            return plain_tf
        except ImportError as import_err:
            if not error:
                return None
            raise import_err
|
|
|
|
|
|
|
|
|
2020-02-19 21:18:45 +01:00
|
|
|
def tf_function(tf_module):
    """Conditional decorator for @tf.function.

    Use @tf_function(tf) instead to avoid errors if tf is not installed.

    Args:
        tf_module: The tf module, or None if tf is not installed.

    Returns:
        A decorator. When applied, it returns `tf_module.function(func)` if
        `tf_module` is given and not executing eagerly; otherwise it returns
        `func` unchanged.
    """

    def decorator(func):
        # Only wrap if tf is there and we are not in eager mode. Without tf,
        # the function is returned as is (won't be used anyways).
        needs_wrapping = (tf_module is not None
                          and not tf_module.executing_eagerly())
        return tf_module.function(func) if needs_wrapping else func

    return decorator
|
|
|
|
|
|
|
|
|
2020-01-28 20:07:55 +01:00
|
|
|
def try_import_tfp(error=False):
    """Attempts to import TensorFlow Probability.

    Args:
        error (bool): Whether to raise an error if tfp cannot be imported.

    Returns:
        The tfp module. None if the `RLLIB_TEST_NO_TF_IMPORT` env variable
        is set, or if tfp cannot be imported and `error` is False.

    Raises:
        ImportError: If tfp cannot be imported and `error` is True.
    """
    if os.environ.get("RLLIB_TEST_NO_TF_IMPORT") is not None:
        logger.warning("Not importing TensorFlow Probability for test "
                       "purposes.")
        return None

    try:
        import tensorflow_probability
    except ImportError as import_err:
        if not error:
            return None
        raise import_err
    return tensorflow_probability
|
|
|
|
|
|
|
|
|
2020-03-28 19:08:31 -07:00
|
|
|
# Fake module for torch.nn.
class NNStub:
    """Empty stand-in for `torch.nn`, used when torch cannot be imported.

    Carries no attributes of its own; `try_import_torch` attaches a `Module`
    attribute to an instance of this class.
    """
|
|
|
|
|
|
|
|
|
|
|
|
# Fake class for torch.nn.Module to allow it to be inherited from.
class ModuleStub:
    """Stand-in for `torch.nn.Module` when torch is not installed.

    Subclassing works, but any attempt to instantiate raises an ImportError.
    """

    def __init__(self, *args, **kwargs):
        # Accepts (and ignores) any arguments; instantiation always fails.
        raise ImportError("Could not import `torch`.")
|
|
|
|
|
|
|
|
|
2020-01-28 20:07:55 +01:00
|
|
|
def try_import_torch(error=False):
    """Attempts to import torch and torch.nn.

    Args:
        error (bool): Whether to raise an error if torch cannot be imported.

    Returns:
        tuple: torch AND torch.nn modules. `(None, None)` if the
        `RLLIB_TEST_NO_TORCH_IMPORT` env variable is set. If torch cannot
        be imported and `error` is False: `(None, nn_stub)`, where `nn_stub`
        is an `NNStub` instance whose `Module` attribute is `ModuleStub`.

    Raises:
        ImportError: If torch cannot be imported and `error` is True.
    """
    if os.environ.get("RLLIB_TEST_NO_TORCH_IMPORT") is not None:
        logger.warning("Not importing Torch for test purposes.")
        return None, None

    try:
        import torch as torch_module
        import torch.nn as torch_nn
        return torch_module, torch_nn
    except ImportError as import_err:
        if error:
            raise import_err

    # torch is not available: Hand out a stub for `nn`, so that classes
    # inheriting from `nn.Module` can still be defined.
    nn_stub = NNStub()
    nn_stub.Module = ModuleStub
    return None, nn_stub
|
2020-02-11 00:22:07 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_variable(value, framework="tf", tf_name="unnamed-variable"):
    """Returns a framework-specific variable for the given initial value.

    Args:
        value (any): The initial value to use. In the non-tf case, this will
            be returned as is.
        framework (str): One of "tf", "torch", or None.
        tf_name (str): An optional name for the variable. Only for tf.

    Returns:
        any: A framework-specific variable (tf.Variable or python primitive).
    """
    # torch or None: Return python primitive.
    if framework != "tf":
        return value

    import tensorflow as tf

    # Fallback dtype; only used if `value` has no `dtype` attribute itself.
    if isinstance(value, float):
        fallback_dtype = tf.float32
    elif isinstance(value, int):
        fallback_dtype = tf.int32
    else:
        fallback_dtype = None
    variable_dtype = getattr(value, "dtype", fallback_dtype)

    return tf.compat.v1.get_variable(
        tf_name, initializer=value, dtype=variable_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level framework handles, resolved once at import time of this
# module. `tf` and `torch` are None if the respective import is skipped or
# fails (both calls use the default `error=False`); `check_framework` looks
# these globals up.
# This call should never happen inside a module's functions/classes
# as it would re-disable tf-eager.
tf = try_import_tf()
# Only the torch module itself is kept here; the second return value (nn or
# its stub) is discarded.
torch, _ = try_import_torch()
|