2019-12-30 15:27:32 -05:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
|
|
|
|
|
|
|
# Module-level TensorFlow handle used by `check()`. `check()` guards every use
# with `tf is not None`, so this is presumably None when TensorFlow is not
# available (TODO confirm against `try_import_tf`).
tf = try_import_tf()
|
|
|
|
|
|
|
|
|
|
|
|
def _assert_with_inversion(assert_fn, x, y, false, **kwargs):
    """Runs a `np.testing` assertion on (x, y), inverting it if `false`.

    Args:
        assert_fn (callable): The numpy assertion to run. It is called as
            `assert_fn(x, y, **kwargs)` and must raise an AssertionError
            on mismatch.
        x (any): The value to be compared.
        y (any): The expected value.
        false (bool): If truthy, succeed only if `assert_fn` fails.
        **kwargs: Forwarded unchanged to `assert_fn`.

    Raises:
        AssertionError: If `false` is falsy and `assert_fn` fails, or if
            `false` is truthy and `assert_fn` succeeds.
    """
    try:
        assert_fn(x, y, **kwargs)
    except AssertionError:
        # A mismatch is only an error if x and y were expected to match.
        if not false:
            raise
    else:
        # Raised from the `else` clause (not from inside the `try`), so the
        # handler above cannot swallow the "should differ" failure.
        assert not false, \
            "ERROR: x ({}) is the same as y ({})!".format(x, y)


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """
    Checks two structures (dict, tuple, list,
    np.array, float, int, etc..) for (almost) numeric identity.
    All numbers in the two structures have to match up to `decimals` digits
    after the floating point. Uses assertions.

    Dicts, tuples and lists are traversed recursively; bools, Nones, strs
    and ints are compared exactly; arrays of dtype object are compared
    element-wise for equality; everything else is treated as numeric.

    Args:
        x (any): The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y (any): The expected value to be compared to `x`. This must not
            be a Tensor.
        decimals (int): The number of digits after the floating point up to
            which all numeric values have to match.
        atol (float): Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol (float): Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false (bool): Whether to check that x and y are NOT the same.

    Raises:
        AssertionError: If the comparison fails (or, with `false` set, if
            x and y turn out to be the same).
        ValueError: If `y` is a Tensor.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), \
            "ERROR: If x is dict, y needs to be a dict as well!"
        # Start from y's keys so that keys missing from x are left over
        # after the loop and reported below.
        y_keys = set(y.keys())
        for key, value in x.items():
            assert key in y, \
                "ERROR: y does not have x's key='{}'! y={}".format(key, y)
            check(
                value,
                y[key],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
            y_keys.remove(key)
        assert not y_keys, \
            "ERROR: y contains keys ({}) that are not in x! y={}".\
            format(list(y_keys), y)
    # A tuple type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)),\
            "ERROR: If x is tuple, y needs to be a tuple as well!"
        assert len(y) == len(x),\
            "ERROR: y does not have the same length as x ({} vs {})!".\
            format(len(y), len(x))
        for x_item, y_item in zip(x, y):
            check(
                x_item,
                y_item,
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false:
            assert bool(x) is not bool(y), \
                "ERROR: x ({}) is y ({})!".format(x, y)
        else:
            assert bool(x) is bool(y), \
                "ERROR: x ({}) is not y ({})!".format(x, y)
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false:
            assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
        else:
            assert x == y, \
                "ERROR: x ({}) is not the same as y ({})!".format(x, y)
    # String comparison.
    # NOTE: The builtin `object` is used instead of the `np.object` alias,
    # which no longer exists in recent numpy versions.
    elif hasattr(x, "dtype") and x.dtype == object:
        _assert_with_inversion(np.testing.assert_array_equal, x, y, false)
    # Everything else (assume numeric).
    else:
        if tf is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, tf.Tensor):
                raise ValueError("`y` (expected value) must not be a Tensor. "
                                 "Use numpy.ndarray instead")
            if isinstance(x, tf.Tensor):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    x = x.numpy()
                # Otherwise, evaluate the tensor in a temporary session.
                else:
                    with tf.Session() as sess:
                        x = sess.run(x)
                # Re-dispatch on the evaluated value (it may be a bool or
                # an object array) and stop here, so that the comparison
                # is performed exactly once.
                return check(
                    x,
                    y,
                    decimals=decimals,
                    atol=atol,
                    rtol=rtol,
                    false=false)

        # Using decimals.
        if atol is None and rtol is None:
            _assert_with_inversion(
                np.testing.assert_almost_equal,
                x,
                y,
                false,
                decimal=decimals)
        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            _assert_with_inversion(
                np.testing.assert_allclose,
                x,
                y,
                false,
                atol=0 if atol is None else atol,
                rtol=1e-7 if rtol is None else rtol)
|