[RLlib] Issue 9071 A3C w/ RNN not working due to VF assuming no RNN. (#13238)

This commit is contained in:
Sven Mika 2021-01-19 14:22:36 +01:00 committed by GitHub
parent e74947cc94
commit 2e3655e8a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 251 additions and 255 deletions

View file

@ -1,17 +1,34 @@
"""Note: Keep in sync with changes to VTraceTFPolicy.""" """Note: Keep in sync with changes to VTraceTFPolicy."""
import ray import ray
from ray.rllib.agents.ppo.ppo_tf_policy import ValueNetworkMixin
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages, \ from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing Postprocessing
from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import explained_variance, make_tf_callable from ray.rllib.utils.tf_ops import explained_variance
tf1, tf, tfv = try_import_tf() tf1, tf, tfv = try_import_tf()
def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
# Stub serving backward compatibility.
deprecation_warning(
old="rllib.agents.a3c.a3c_tf_policy.postprocess_advantages",
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
error=False)
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)
class A3CLoss: class A3CLoss:
def __init__(self, def __init__(self,
action_dist, action_dist,
@ -45,46 +62,10 @@ def actor_critic_loss(policy, model, dist_class, train_batch):
return policy.loss.total_loss return policy.loss.total_loss
def postprocess_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
policy.config["use_gae"], policy.config["use_critic"])
def add_value_function_fetch(policy): def add_value_function_fetch(policy):
return {SampleBatch.VF_PREDS: policy.model.value_function()} return {SampleBatch.VF_PREDS: policy.model.value_function()}
class ValueNetworkMixin:
def __init__(self):
@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor([prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor([prev_reward]),
"is_training": tf.convert_to_tensor(False),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
return self.model.value_function()[0]
self._value = value
def stats(policy, train_batch): def stats(policy, train_batch):
return { return {
"cur_lr": tf.cast(policy.cur_lr, tf.float64), "cur_lr": tf.cast(policy.cur_lr, tf.float64),
@ -115,7 +96,7 @@ def clip_gradients(policy, optimizer, loss):
def setup_mixins(policy, obs_space, action_space, config): def setup_mixins(policy, obs_space, action_space, config):
ValueNetworkMixin.__init__(policy) ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
@ -126,7 +107,7 @@ A3CTFPolicy = build_tf_policy(
stats_fn=stats, stats_fn=stats,
grad_stats_fn=grad_stats, grad_stats_fn=grad_stats,
gradients_fn=clip_gradients, gradients_fn=clip_gradients,
postprocess_fn=postprocess_advantages, postprocess_fn=compute_gae_for_sample_batch,
extra_action_fetches_fn=add_value_function_fetch, extra_action_fetches_fn=add_value_function_fetch,
before_loss_init=setup_mixins, before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin, LearningRateSchedule]) mixins=[ValueNetworkMixin, LearningRateSchedule])

View file

@ -1,13 +1,35 @@
import gym
import ray import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \ from ray.rllib.agents.ppo.ppo_torch_policy import ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing Postprocessing
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict
torch, nn = try_import_torch() torch, nn = try_import_torch()
def add_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
# Stub serving backward compatibility.
deprecation_warning(
old="rllib.agents.a3c.a3c_torch_policy.add_advantages",
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
error=False)
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)
def actor_critic_loss(policy, model, dist_class, train_batch): def actor_critic_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch) logits, _ = model.from_batch(train_batch)
values = model.value_function() values = model.value_function()
@ -36,52 +58,27 @@ def loss_and_entropy_stats(policy, train_batch):
} }
def add_advantages(policy,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch[SampleBatch.DONES][-1]
if completed:
last_r = 0.0
else:
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
return compute_advantages(
sample_batch, last_r, policy.config["gamma"], policy.config["lambda"],
policy.config["use_gae"], policy.config["use_critic"])
def model_value_predictions(policy, input_dict, state_batches, model, def model_value_predictions(policy, input_dict, state_batches, model,
action_dist): action_dist):
return {SampleBatch.VF_PREDS: model.value_function()} return {SampleBatch.VF_PREDS: model.value_function()}
def apply_grad_clipping(policy, optimizer, loss):
info = {}
if policy.config["grad_clip"]:
for param_group in optimizer.param_groups:
# Make sure we only pass params with grad != None into torch
# clip_grad_norm_. Would fail otherwise.
params = list(
filter(lambda p: p.grad is not None, param_group["params"]))
if params:
grad_gnorm = nn.utils.clip_grad_norm_(
params, policy.config["grad_clip"])
if isinstance(grad_gnorm, torch.Tensor):
grad_gnorm = grad_gnorm.cpu().numpy()
info["grad_gnorm"] = grad_gnorm
return info
def torch_optimizer(policy, config): def torch_optimizer(policy, config):
return torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])
class ValueNetworkMixin: def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
def _value(self, obs): action_space: gym.spaces.Space,
_ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1]) config: TrainerConfigDict) -> None:
return self.model.value_function()[0] """Call all mixin classes' constructors before PPOPolicy initialization.
Args:
policy (Policy): The Policy object.
obs_space (gym.spaces.Space): The Policy's observation space.
action_space (gym.spaces.Space): The Policy's action space.
config (TrainerConfigDict): The Policy's config.
"""
ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
A3CTorchPolicy = build_policy_class( A3CTorchPolicy = build_policy_class(
@ -90,9 +87,10 @@ A3CTorchPolicy = build_policy_class(
get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
loss_fn=actor_critic_loss, loss_fn=actor_critic_loss,
stats_fn=loss_and_entropy_stats, stats_fn=loss_and_entropy_stats,
postprocess_fn=add_advantages, postprocess_fn=compute_gae_for_sample_batch,
extra_action_out_fn=model_value_predictions, extra_action_out_fn=model_value_predictions,
extra_grad_process_fn=apply_grad_clipping, extra_grad_process_fn=apply_grad_clipping,
optimizer_fn=torch_optimizer, optimizer_fn=torch_optimizer,
before_loss_init=setup_mixins,
mixins=[ValueNetworkMixin], mixins=[ValueNetworkMixin],
) )

View file

@ -8,7 +8,6 @@ from typing import Dict, List, Tuple, Type, Union
import ray import ray
import ray.experimental.tf_utils import ray.experimental.tf_utils
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \ from ray.rllib.agents.sac.sac_tf_policy import postprocess_trajectory, \
validate_spaces validate_spaces
from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \ from ray.rllib.agents.sac.sac_torch_policy import _get_dist_class, stats, \
@ -22,7 +21,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
TrainerConfigDict TrainerConfigDict
from ray.rllib.utils.torch_ops import convert_to_torch_tensor from ray.rllib.utils.torch_ops import apply_grad_clipping, \
convert_to_torch_tensor
torch, nn = try_import_torch() torch, nn = try_import_torch()
F = nn.functional F = nn.functional

View file

@ -1,7 +1,6 @@
import logging import logging
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \ from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \
get_distribution_inputs_and_class, validate_spaces get_distribution_inputs_and_class, validate_spaces
from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
@ -10,7 +9,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import huber_loss, l2_loss from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss, l2_loss
torch, nn = try_import_torch() torch, nn = try_import_torch()

View file

@ -301,17 +301,11 @@ def adam_optimizer(policy: Policy, config: TrainerConfigDict
def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer", def clip_gradients(policy: Policy, optimizer: "tf.keras.optimizers.Optimizer",
loss: TensorType) -> ModelGradients: loss: TensorType) -> ModelGradients:
if policy.config["grad_clip"] is not None: return minimize_and_clip(
grads_and_vars = minimize_and_clip( optimizer,
optimizer, loss,
loss, var_list=policy.q_func_vars,
var_list=policy.q_func_vars, clip_val=policy.config["grad_clip"])
clip_val=policy.config["grad_clip"])
else:
grads_and_vars = optimizer.compute_gradients(
loss, var_list=policy.q_func_vars)
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
return grads_and_vars
def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]: def build_q_stats(policy: Policy, batch) -> Dict[str, TensorType]:

View file

@ -4,7 +4,6 @@ from typing import Dict, List, Tuple
import gym import gym
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.dqn.dqn_tf_policy import ( from ray.rllib.agents.dqn.dqn_tf_policy import (
PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE, postprocess_nstep_and_prio) PRIO_WEIGHTS, Q_SCOPE, Q_TARGET_SCOPE, postprocess_nstep_and_prio)
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
@ -20,9 +19,8 @@ from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import (FLOAT_MIN, huber_loss, from ray.rllib.utils.torch_ops import apply_grad_clipping, FLOAT_MIN, \
reduce_mean_ignore_inf, huber_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits
softmax_cross_entropy_with_logits)
from ray.rllib.utils.typing import TensorType, TrainerConfigDict from ray.rllib.utils.typing import TensorType, TrainerConfigDict
torch, nn = try_import_torch() torch, nn = try_import_torch()

View file

@ -1,11 +1,11 @@
import logging import logging
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.dreamer.utils import FreezeParameters from ray.rllib.agents.dreamer.utils import FreezeParameters
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
torch, nn = try_import_torch() torch, nn = try_import_torch()
if torch: if torch:

View file

@ -1,7 +1,6 @@
import logging import logging
import ray import ray
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy from ray.rllib.agents.impala.vtrace_tf_policy import VTraceTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.trainer_template import build_trainer
@ -160,6 +159,7 @@ def get_policy_class(config):
if config["vtrace"]: if config["vtrace"]:
return VTraceTFPolicy return VTraceTFPolicy
else: else:
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
return A3CTFPolicy return A3CTFPolicy

View file

@ -3,7 +3,6 @@ import logging
import numpy as np import numpy as np
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
import ray.rllib.agents.impala.vtrace_torch as vtrace import ray.rllib.agents.impala.vtrace_torch as vtrace
from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
@ -11,8 +10,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule, \ from ray.rllib.policy.torch_policy import LearningRateSchedule, \
EntropyCoeffSchedule EntropyCoeffSchedule
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ from ray.rllib.utils.torch_ops import apply_grad_clipping, \
sequence_mask explained_variance, global_norm, sequence_mask
torch, nn = try_import_torch() torch, nn = try_import_torch()

View file

@ -1,10 +1,10 @@
import logging import logging
import ray import ray
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ from ray.rllib.agents.ppo.ppo_tf_policy import vf_preds_fetches, \
vf_preds_fetches, compute_and_clip_gradients, setup_config, \ compute_and_clip_gradients, setup_config, ValueNetworkMixin
ValueNetworkMixin from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
from ray.rllib.evaluation.postprocessing import Postprocessing Postprocessing
from ray.rllib.models.utils import get_activation_fn from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy_template import build_tf_policy
@ -422,7 +422,7 @@ MAMLTFPolicy = build_tf_policy(
stats_fn=maml_stats, stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn, optimizer_fn=maml_optimizer_fn,
extra_action_fetches_fn=vf_preds_fetches, extra_action_fetches_fn=vf_preds_fetches,
postprocess_fn=postprocess_ppo_gae, postprocess_fn=compute_gae_for_sample_batch,
gradients_fn=compute_and_clip_gradients, gradients_fn=compute_and_clip_gradients,
before_init=setup_config, before_init=setup_config,
before_loss_init=setup_mixins, before_loss_init=setup_mixins,

View file

@ -1,15 +1,15 @@
import logging import logging
import ray import ray
from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \ from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
ValueNetworkMixin ValueNetworkMixin
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -355,7 +355,7 @@ MAMLTorchPolicy = build_policy_class(
stats_fn=maml_stats, stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn, optimizer_fn=maml_optimizer_fn,
extra_action_out_fn=vf_preds_fetches, extra_action_out_fn=vf_preds_fetches,
postprocess_fn=postprocess_ppo_gae, postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping, extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config, before_init=setup_config,
after_init=setup_mixins, after_init=setup_mixins,

View file

@ -3,18 +3,18 @@ import logging
from typing import Tuple, Type from typing import Tuple, Type
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \ from ray.rllib.agents.maml.maml_torch_policy import setup_mixins, \
maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin maml_loss, maml_stats, maml_optimizer_fn, KLCoeffMixin
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import TrainerConfigDict from ray.rllib.utils.typing import TrainerConfigDict
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -85,7 +85,7 @@ MBMPOTorchPolicy = build_policy_class(
stats_fn=maml_stats, stats_fn=maml_stats,
optimizer_fn=maml_optimizer_fn, optimizer_fn=maml_optimizer_fn,
extra_action_out_fn=vf_preds_fetches, extra_action_out_fn=vf_preds_fetches,
postprocess_fn=postprocess_ppo_gae, postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping, extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config, before_init=setup_config,
after_init=setup_mixins, after_init=setup_mixins,

View file

@ -13,9 +13,9 @@ from typing import Dict, List, Optional, Type, Union
from ray.rllib.agents.impala import vtrace_tf as vtrace from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \ from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
clip_gradients, choose_optimizer clip_gradients, choose_optimizer
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
@ -338,8 +338,8 @@ def postprocess_trajectory(
SampleBatch: The postprocessed, modified SampleBatch (or a new one). SampleBatch: The postprocessed, modified SampleBatch (or a new one).
""" """
if not policy.config["vtrace"]: if not policy.config["vtrace"]:
sample_batch = postprocess_ppo_gae(policy, sample_batch, sample_batch = compute_gae_for_sample_batch(
other_agent_batches, episode) policy, sample_batch, other_agent_batches, episode)
# TODO: (sven) remove this del once we have trajectory view API fully in # TODO: (sven) remove this del once we have trajectory view API fully in
# place. # place.

View file

@ -10,7 +10,6 @@ import numpy as np
import logging import logging
from typing import Type from typing import Type
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
import ray.rllib.agents.impala.vtrace_torch as vtrace import ray.rllib.agents.impala.vtrace_torch as vtrace
from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \ from ray.rllib.agents.impala.vtrace_torch_policy import make_time_major, \
choose_optimizer choose_optimizer
@ -27,8 +26,8 @@ from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import LearningRateSchedule from ray.rllib.policy.torch_policy import LearningRateSchedule
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import explained_variance, global_norm, \ from ray.rllib.utils.torch_ops import apply_grad_clipping, explained_variance,\
sequence_mask global_norm, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict from ray.rllib.utils.typing import TensorType, TrainerConfigDict
torch, nn = try_import_torch() torch, nn = try_import_torch()

View file

@ -8,7 +8,7 @@ from typing import Dict, List, Optional, Type, Union
import ray import ray
from ray.rllib.evaluation.episode import MultiAgentEpisode from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_advantages, \ from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing Postprocessing
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
@ -160,71 +160,6 @@ def vf_preds_fetches(policy: Policy) -> Dict[str, TensorType]:
} }
def postprocess_ppo_gae(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Postprocesses a trajectory and returns the processed trajectory.
The trajectory contains only data from one episode and from one agent.
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
contain a truncated (at-the-end) episode, in case the
`config.rollout_fragment_length` was reached by the sampler.
- If `config.batch_mode=complete_episodes`, sample_batch will contain
exactly one episode (no matter how long).
New columns can be added to sample_batch and existing ones may be altered.
Args:
policy (Policy): The Policy used to generate the trajectory
(`sample_batch`)
sample_batch (SampleBatch): The SampleBatch to postprocess.
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
object in which the agents operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
# Trajectory is actually complete -> last r=0.0.
if sample_batch[SampleBatch.DONES][-1]:
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(
sample_batch, index="last")
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"],
use_critic=policy.config.get("use_critic", True))
return batch
def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer,
loss: TensorType) -> ModelGradients: loss: TensorType) -> ModelGradients:
"""Gradients computing function (from loss tensor, using local optimizer). """Gradients computing function (from loss tensor, using local optimizer).
@ -392,13 +327,29 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
def postprocess_ppo_gae(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
# Stub serving backward compatibility.
deprecation_warning(
old="rllib.agents.ppo.ppo_tf_policy.postprocess_ppo_gae",
new="rllib.evaluation.postprocessing.compute_gae_for_sample_batch",
error=False)
return compute_gae_for_sample_batch(policy, sample_batch,
other_agent_batches, episode)
# Build a child class of `DynamicTFPolicy`, given the custom functions defined # Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above. # above.
PPOTFPolicy = build_tf_policy( PPOTFPolicy = build_tf_policy(
name="PPOTFPolicy", name="PPOTFPolicy",
loss_fn=ppo_surrogate_loss, loss_fn=ppo_surrogate_loss,
get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
postprocess_fn=postprocess_ppo_gae, postprocess_fn=compute_gae_for_sample_batch,
stats_fn=kl_and_loss_stats, stats_fn=kl_and_loss_stats,
gradients_fn=compute_and_clip_gradients, gradients_fn=compute_and_clip_gradients,
extra_action_fetches_fn=vf_preds_fetches, extra_action_fetches_fn=vf_preds_fetches,

View file

@ -7,10 +7,9 @@ import numpy as np
from typing import Dict, List, Type, Union from typing import Dict, List, Type, Union
import ray import ray
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \ from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
setup_config Postprocessing
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy import Policy
@ -19,8 +18,8 @@ from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \ from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
LearningRateSchedule LearningRateSchedule
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \ from ray.rllib.utils.torch_ops import apply_grad_clipping, \
explained_variance, sequence_mask convert_to_torch_tensor, explained_variance, sequence_mask
from ray.rllib.utils.typing import TensorType, TrainerConfigDict from ray.rllib.utils.typing import TensorType, TrainerConfigDict
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -279,7 +278,7 @@ PPOTorchPolicy = build_policy_class(
loss_fn=ppo_surrogate_loss, loss_fn=ppo_surrogate_loss,
stats_fn=kl_and_loss_stats, stats_fn=kl_and_loss_stats,
extra_action_out_fn=vf_preds_fetches, extra_action_out_fn=vf_preds_fetches,
postprocess_fn=postprocess_ppo_gae, postprocess_fn=compute_gae_for_sample_batch,
extra_grad_process_fn=apply_grad_clipping, extra_grad_process_fn=apply_grad_clipping,
before_init=setup_config, before_init=setup_config,
before_loss_init=setup_mixins, before_loss_init=setup_mixins,

View file

@ -5,11 +5,12 @@ import unittest
import ray import ray
from ray.rllib.agents.callbacks import DefaultCallbacks from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \ from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as \
postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf ppo_surrogate_loss_tf
from ray.rllib.agents.ppo.ppo_torch_policy import postprocess_ppo_gae as \ from ray.rllib.agents.ppo.ppo_torch_policy import ppo_surrogate_loss as \
postprocess_ppo_gae_torch, ppo_surrogate_loss as ppo_surrogate_loss_torch ppo_surrogate_loss_torch
from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.models.torch.torch_action_dist import TorchCategorical
@ -212,11 +213,8 @@ class TestPPO(unittest.TestCase):
# Check the variable is initially zero. # Check the variable is initially zero.
init_std = get_value() init_std = get_value()
assert init_std == 0.0, init_std assert init_std == 0.0, init_std
batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
if fw in ["tf2", "tf", "tfe"]: if fw == "torch":
batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy())
else:
batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH.copy())
batch = policy._lazy_tensor_dict(batch) batch = policy._lazy_tensor_dict(batch)
policy.learn_on_batch(batch) policy.learn_on_batch(batch)
@ -255,11 +253,9 @@ class TestPPO(unittest.TestCase):
# to train_batch dict. # to train_batch dict.
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] = # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
# [0.50005, -0.505, 0.5] # [0.50005, -0.505, 0.5]
if fw in ["tf2", "tf", "tfe"]: train_batch = compute_gae_for_sample_batch(policy,
train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH.copy()) FAKE_BATCH.copy())
else: if fw == "torch":
train_batch = postprocess_ppo_gae_torch(
policy, FAKE_BATCH.copy())
train_batch = policy._lazy_tensor_dict(train_batch) train_batch = policy._lazy_tensor_dict(train_batch)
# Check Advantage values. # Check Advantage values.

View file

@ -73,8 +73,7 @@ DEFAULT_CONFIG = with_common_config({
"timesteps_per_iteration": 100, "timesteps_per_iteration": 100,
# === Replay buffer === # === Replay buffer ===
# Size of the replay buffer. Note that if async_updates is set, then # Size of the replay buffer (in time steps).
# each worker will have a replay buffer of this size.
"buffer_size": int(1e6), "buffer_size": int(1e6),
# If True prioritized replay buffer will be used. # If True prioritized replay buffer will be used.
"prioritized_replay": False, "prioritized_replay": False,
@ -104,9 +103,7 @@ DEFAULT_CONFIG = with_common_config({
# Update the replay buffer with this many samples at once. Note that this # Update the replay buffer with this many samples at once. Note that this
# setting applies per-worker if num_workers > 1. # setting applies per-worker if num_workers > 1.
"rollout_fragment_length": 1, "rollout_fragment_length": 1,
# Size of a batched sampled from replay buffer for training. Note that # Size of a batched sampled from replay buffer for training.
# if async_updates is set, then each worker returns gradients for a
# batch of this size.
"train_batch_size": 256, "train_batch_size": 256,
# Update the target network every `target_network_update_freq` steps. # Update the target network every `target_network_update_freq` steps.
"target_network_update_freq": 0, "target_network_update_freq": 0,

View file

@ -9,7 +9,6 @@ from typing import Dict, List, Optional, Tuple, Type, Union
import ray import ray
import ray.experimental.tf_utils import ray.experimental.tf_utils
from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \ from ray.rllib.agents.sac.sac_tf_policy import build_sac_model, \
postprocess_trajectory, validate_spaces postprocess_trajectory, validate_spaces
from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS from ray.rllib.agents.dqn.dqn_tf_policy import PRIO_WEIGHTS
@ -23,7 +22,7 @@ from ray.rllib.models.torch.torch_action_dist import (
TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta) TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.torch_ops import huber_loss from ray.rllib.utils.torch_ops import apply_grad_clipping, huber_loss
from ray.rllib.utils.typing import LocalOptimizer, TensorType, \ from ray.rllib.utils.typing import LocalOptimizer, TensorType, \
TrainerConfigDict TrainerConfigDict

View file

@ -265,17 +265,11 @@ class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
@override(TFPolicy) @override(TFPolicy)
def gradients(self, optimizer, loss): def gradients(self, optimizer, loss):
if self.config["grad_norm_clipping"] is not None: self.gvs = {
self.gvs = { k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
k: minimize_and_clip(optimizer, self.losses[k], self.vars[k], self.config["grad_norm_clipping"])
self.config["grad_norm_clipping"]) for k, optimizer in self.optimizers.items()
for k, optimizer in self.optimizers.items() }
}
else:
self.gvs = {
k: optimizer.compute_gradients(self.losses[k], self.vars[k])
for k, optimizer in self.optimizers.items()
}
return self.gvs["critic"] + self.gvs["actor"] return self.gvs["critic"] + self.gvs["actor"]
@override(TFPolicy) @override(TFPolicy)

View file

@ -1,22 +1,12 @@
import numpy as np import numpy as np
import scipy.signal import scipy.signal
from typing import Dict, Optional
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import AgentID
def discount_cumsum(x: np.ndarray, gamma: float) -> float:
"""Calculates the discounted cumulative sum over a reward sequence `x`.
y[t] - discount*y[t+1] = x[t]
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
Args:
gamma (float): The discount factor gamma.
Returns:
float: The discounted cumulative sum over the reward sequence `x`.
"""
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]
class Postprocessing: class Postprocessing:
@ -89,3 +79,83 @@ def compute_advantages(rollout: SampleBatch,
Postprocessing.ADVANTAGES].astype(np.float32) Postprocessing.ADVANTAGES].astype(np.float32)
return rollout return rollout
def compute_gae_for_sample_batch(
policy: Policy,
sample_batch: SampleBatch,
other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
"""Adds GAE (generalized advantage estimations) to a trajectory.
The trajectory contains only data from one episode and from one agent.
- If `config.batch_mode=truncate_episodes` (default), sample_batch may
contain a truncated (at-the-end) episode, in case the
`config.rollout_fragment_length` was reached by the sampler.
- If `config.batch_mode=complete_episodes`, sample_batch will contain
exactly one episode (no matter how long).
New columns can be added to sample_batch and existing ones may be altered.
Args:
policy (Policy): The Policy used to generate the trajectory
(`sample_batch`)
sample_batch (SampleBatch): The SampleBatch to postprocess.
other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
dict of AgentIDs mapping to other agents' trajectory data (from the
same episode). NOTE: The other agents use the same policy.
episode (Optional[MultiAgentEpisode]): Optional multi-agent episode
object in which the agents operated.
Returns:
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
# Trajectory is actually complete -> last r=0.0.
if sample_batch[SampleBatch.DONES][-1]:
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if policy.config.get("_use_trajectory_view_api"):
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(
sample_batch, index="last")
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"],
use_critic=policy.config.get("use_critic", True))
return batch
def discount_cumsum(x: np.ndarray, gamma: float) -> float:
"""Calculates the discounted cumulative sum over a reward sequence `x`.
y[t] - discount*y[t+1] = x[t]
reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
Args:
gamma (float): The discount factor gamma.
Returns:
float: The discounted cumulative sum over the reward sequence `x`.
"""
return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]

View file

@ -14,11 +14,11 @@ mspacman-sac-tf:
# state-preprocessor=Our default Atari Conv2D-net. # state-preprocessor=Our default Atari Conv2D-net.
use_state_preprocessor: true use_state_preprocessor: true
Q_model: Q_model:
hidden_activation: relu fcnet_hiddens: [512]
hidden_layer_sizes: [512] fcnet_activation: relu
policy_model: policy_model:
hidden_activation: relu fcnet_hiddens: [512]
hidden_layer_sizes: [512] fcnet_activation: relu
# Do hard syncs. # Do hard syncs.
# Soft-syncs seem to work less reliably for discrete action spaces. # Soft-syncs seem to work less reliably for discrete action spaces.
tau: 1.0 tau: 1.0

View file

@ -92,7 +92,7 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
variable is clipped to `clip_val` variable is clipped to `clip_val`
""" """
# Accidentally passing values < 0.0 will break all gradients. # Accidentally passing values < 0.0 will break all gradients.
assert clip_val > 0.0, clip_val assert clip_val is None or clip_val > 0.0, clip_val
if tf.executing_eagerly(): if tf.executing_eagerly():
tape = optimizer.tape tape = optimizer.tape
@ -102,10 +102,8 @@ def minimize_and_clip(optimizer, objective, var_list, clip_val=10.0):
grads_and_vars = optimizer.compute_gradients( grads_and_vars = optimizer.compute_gradients(
objective, var_list=var_list) objective, var_list=var_list)
for i, (grad, var) in enumerate(grads_and_vars): return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v)
if grad is not None: for (g, v) in grads_and_vars if g is not None]
grads_and_vars[i] = (tf.clip_by_norm(grad, clip_val), var)
return grads_and_vars
def make_tf_callable(session_or_none, dynamic_shape=False): def make_tf_callable(session_or_none, dynamic_shape=False):

View file

@ -14,6 +14,30 @@ FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38 FLOAT_MAX = 3.4e38
def apply_grad_clipping(policy, optimizer, loss):
"""Applies gradient clipping to already computed grads inside `optimizer`.
Args:
policy (TorchPolicy): The TorchPolicy, which calculated `loss`.
optimizer (torch.optim.Optimizer): A local torch optimizer object.
loss (torch.Tensor): The torch loss tensor.
"""
info = {}
if policy.config["grad_clip"]:
for param_group in optimizer.param_groups:
# Make sure we only pass params with grad != None into torch
# clip_grad_norm_. Would fail otherwise.
params = list(
filter(lambda p: p.grad is not None, param_group["params"]))
if params:
grad_gnorm = nn.utils.clip_grad_norm_(
params, policy.config["grad_clip"])
if isinstance(grad_gnorm, torch.Tensor):
grad_gnorm = grad_gnorm.cpu().numpy()
info["grad_gnorm"] = grad_gnorm
return info
def atanh(x): def atanh(x):
return 0.5 * torch.log((1 + x) / (1 - x)) return 0.5 * torch.log((1 + x) / (1 - x))