mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
21 lines
670 B
Python
21 lines
670 B
Python
from ray.rllib.utils import try_import_tf, try_import_torch
|
|
|
|
tf = try_import_tf()
|
|
torch, nn = try_import_torch()
|
|
|
|
|
|
def explained_variance(y, pred, framework="tf"):
    """Compute the explained variance of `pred` with respect to `y`.

    The value is ``1 - Var[y - pred] / Var[y]``, where both variances are
    taken along axis 0, and the result is clipped from below at -1.0.

    Args:
        y: Target values; a TensorFlow tensor when `framework` is "tf",
            otherwise a torch tensor.
        pred: Predicted values, from the same framework as `y`.
        framework: "tf" selects the TensorFlow implementation; any other
            value selects the PyTorch implementation.

    Returns:
        The clipped explained variance as a tensor of the selected
        framework. In the PyTorch case the result is broadcast against a
        one-element lower-bound tensor, so it has at least one dimension.
    """
    if framework == "tf":
        _, y_var = tf.nn.moments(y, axes=[0])
        _, diff_var = tf.nn.moments(y - pred, axes=[0])
        return tf.maximum(-1.0, 1 - (diff_var / y_var))

    y_var = torch.var(y, dim=[0])
    diff_var = torch.var(y - pred, dim=[0])
    ratio = 1 - (diff_var / y_var)
    # Create the lower bound on the device the inputs actually live on.
    # Choosing the device from `torch.cuda.is_available()` breaks when CUDA
    # is present but the tensors are on the CPU (device mismatch in max).
    min_ = torch.tensor([-1.0], device=ratio.device)
    return torch.max(min_, ratio)
|