[RLlib] Support >1 loss terms and optimizers for framework=tf2 (already supported for framework=[tf|torch]) (#19269)
commit bd2d2079d2 (parent 635010d460)
2 changed files with 61 additions and 22 deletions
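For context, a rough sketch of what this change enables on the user side: with framework=tf2 (eager mode), a policy may now define several loss terms together with a matching list of optimizers, gated by the config flag `_tf_policy_handles_more_than_one_loss` that appears in the diff below. The helper and callback names used here (`build_tf_policy`, `loss_fn`, `optimizer_fn`, `two_term_loss_fn`, `two_optimizers_fn`) follow the usual RLlib policy-template conventions and are illustrative only, not taken from this commit:

```python
# Illustrative sketch only: a policy with two loss terms and two optimizers.
# Assumes RLlib's build_tf_policy template; returning lists from loss_fn /
# optimizer_fn requires "_tf_policy_handles_more_than_one_loss": True.
import tensorflow as tf
from ray.rllib.policy.tf_policy_template import build_tf_policy


def two_term_loss_fn(policy, model, dist_class, train_batch):
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    # First term: a simple policy-gradient objective; second term: an
    # auxiliary regularizer (purely for illustration).
    pg_loss = -tf.reduce_mean(
        action_dist.logp(train_batch["actions"]) * train_batch["rewards"])
    aux_loss = 1e-3 * tf.reduce_mean(tf.square(model_out))
    return [pg_loss, aux_loss]  # One entry per optimizer below.


def two_optimizers_fn(policy, config):
    # One optimizer per loss term, in the same order as the losses above.
    return [
        tf.keras.optimizers.Adam(learning_rate=config["lr"]),
        tf.keras.optimizers.Adam(learning_rate=1e-4),
    ]


MyMultiLossPolicy = build_tf_policy(
    name="MyMultiLossPolicy",
    loss_fn=two_term_loss_fn,
    optimizer_fn=two_optimizers_fn,
)
```

A trainer built on such a policy would also need `"_tf_policy_handles_more_than_one_loss": True` and `"framework": "tf2"` in its config for the eager code path changed below to take the multi-loss branch.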
@@ -59,7 +59,7 @@ class TestAPPO(unittest.TestCase):
         num_iterations = 2
 
         # Only supported for tf so far.
-        for _ in framework_iterator(config, frameworks="tf"):
+        for _ in framework_iterator(config, frameworks=("tf2", "tf")):
             trainer = ppo.APPOTrainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
@@ -334,8 +334,11 @@ def build_eager_tf_policy(
             if getattr(self, "exploration", None):
                 optimizers = self.exploration.get_exploration_optimizer(
                     optimizers)
-            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
-            #  Just like torch Policy does.
-            self._optimizer = optimizers[0] if optimizers else None
+            # The list of local (tf) optimizers (one per loss term).
+            self._optimizers: List[LocalOptimizer] = optimizers
+            # Backward compatibility: A user's policy may only support a single
+            # loss term and optimizer (no lists).
+            self._optimizer: LocalOptimizer = \
+                optimizers[0] if optimizers else None
 
@@ -737,42 +740,78 @@ def build_eager_tf_policy(
         def _get_is_training_placeholder(self):
             return tf.convert_to_tensor(self._is_training)
 
-        def _apply_gradients(self, grads_and_vars):
-            if apply_gradients_fn:
-                apply_gradients_fn(self, self._optimizer, grads_and_vars)
-            else:
-                self._optimizer.apply_gradients(
-                    [(g, v) for g, v in grads_and_vars if g is not None])
-
         @with_lock
         def _compute_gradients(self, samples):
             """Computes and returns grads as eager tensors."""
 
-            with tf.GradientTape(persistent=compute_gradients_fn is not None) \
-                    as tape:
-                loss = loss_fn(self, self.model, self.dist_class, samples)
-
+            # Gather all variables for which to calculate losses.
             if isinstance(self.model, tf.keras.Model):
                 variables = self.model.trainable_variables
             else:
                 variables = self.model.trainable_variables()
 
+            # Calculate the loss(es) inside a tf GradientTape.
+            with tf.GradientTape(persistent=compute_gradients_fn is not None) \
+                    as tape:
+                losses = loss_fn(self, self.model, self.dist_class, samples)
+            losses = force_list(losses)
+
+            # User provided a compute_gradients_fn.
             if compute_gradients_fn:
-                grads_and_vars = compute_gradients_fn(self,
-                                                      OptimizerWrapper(tape),
-                                                      loss)
+                # Wrap our tape inside a wrapper, such that the resulting
+                # object looks like a "classic" tf.optimizer. This way, custom
+                # compute_gradients_fn will work on both tf static graph
+                # and tf-eager.
+                optimizer = OptimizerWrapper(tape)
+                # More than one loss terms/optimizers.
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    grads_and_vars = compute_gradients_fn(
+                        self, [optimizer] * len(losses), losses)
+                # Only one loss and one optimizer.
+                else:
+                    grads_and_vars = [
+                        compute_gradients_fn(self, optimizer, losses[0])
+                    ]
+            # Default: Compute gradients using the above tape.
             else:
-                grads_and_vars = list(
-                    zip(tape.gradient(loss, variables), variables))
+                grads_and_vars = [
+                    list(zip(tape.gradient(loss, variables), variables))
+                    for loss in losses
+                ]
 
             if log_once("grad_vars"):
-                for _, v in grads_and_vars:
-                    logger.info("Optimizing variable {}".format(v.name))
+                for g_and_v in grads_and_vars:
+                    for g, v in g_and_v:
+                        if g is not None:
+                            logger.info(f"Optimizing variable {v.name}")
 
-            grads = [g for g, v in grads_and_vars]
+            # `grads_and_vars` is returned a list (len=num optimizers/losses)
+            # of lists of (grad, var) tuples.
+            if self.config["_tf_policy_handles_more_than_one_loss"]:
+                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
+            # `grads_and_vars` is returned as a list of (grad, var) tuples.
+            else:
+                grads_and_vars = grads_and_vars[0]
+                grads = [g for g, _ in grads_and_vars]
+
             stats = self._stats(self, samples, grads)
             return grads_and_vars, stats
 
+        def _apply_gradients(self, grads_and_vars):
+            if apply_gradients_fn:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
+                else:
+                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
+            else:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    for i, o in enumerate(self._optimizers):
+                        o.apply_gradients([(g, v) for g, v in grads_and_vars[i]
+                                           if g is not None])
+                else:
+                    self._optimizer.apply_gradients(
+                        [(g, v) for g, v in grads_and_vars if g is not None])
+
         def _stats(self, outputs, samples, grads):
 
             fetches = {}
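For readers less familiar with tf.GradientTape, here is a minimal standalone sketch (plain TensorFlow 2, not RLlib code) of the mechanic the new `_compute_gradients` / `_apply_gradients` path relies on: compute all loss terms under one persistent tape, build one list of (grad, var) tuples per loss, and let each loss term's own optimizer apply its list.

```python
# Standalone illustration of the multi-loss / multi-optimizer pattern above.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))
variables = model.trainable_variables

# One optimizer per loss term.
optimizers = [tf.keras.optimizers.Adam(1e-3), tf.keras.optimizers.SGD(1e-2)]

x = tf.random.normal((8, 4))
y = tf.zeros((8, 1))

# A persistent tape is needed because tape.gradient() is called once per loss.
with tf.GradientTape(persistent=True) as tape:
    out = model(x)
    losses = [tf.reduce_mean((out - y) ** 2), tf.reduce_mean(tf.abs(out))]

# One list of (grad, var) tuples per loss term, mirroring `grads_and_vars`
# in the eager policy code.
grads_and_vars = [
    list(zip(tape.gradient(loss, variables), variables)) for loss in losses
]
del tape  # Drop the persistent tape explicitly once gradients are taken.

# Apply each loss term's gradients with its own optimizer, skipping None grads.
for optim, g_and_v in zip(optimizers, grads_and_vars):
    optim.apply_gradients([(g, v) for g, v in g_and_v if g is not None])
```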