[RLlib] Support >1 loss terms and optimizers for framework=tf2 (already supported for framework=[tf|torch]) (#19269)

Sven Mika, 2021-10-10 12:19:47 +02:00 (committed by GitHub)
parent 635010d460
commit bd2d2079d2
2 changed files with 61 additions and 22 deletions
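In practice, the new code path is opt-in via the `_tf_policy_handles_more_than_one_loss` config key used throughout the diff below. A minimal sketch of enabling it for APPO under tf2; `APPOTrainer` and the config key appear in this diff, the rest is standard RLlib usage of that era:

    import ray
    import ray.rllib.agents.ppo as ppo

    ray.init()
    config = ppo.appo.DEFAULT_CONFIG.copy()
    config["framework"] = "tf2"
    # Opt in: the eager policy keeps one optimizer per loss term.
    config["_tf_policy_handles_more_than_one_loss"] = True
    trainer = ppo.APPOTrainer(config=config, env="CartPole-v0")
    print(trainer.train())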


@@ -59,7 +59,7 @@ class TestAPPO(unittest.TestCase):
         num_iterations = 2

         # Only supported for tf so far.
-        for _ in framework_iterator(config, frameworks="tf"):
+        for _ in framework_iterator(config, frameworks=("tf2", "tf")):
             trainer = ppo.APPOTrainer(config=config, env="CartPole-v0")
             for i in range(num_iterations):
                 results = trainer.train()
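Note: `framework_iterator` is RLlib's test utility that runs the loop body once per requested framework. A rough sketch of its effect here (simplified; the real helper also handles eager/session setup):

    def framework_iterator_sketch(config, frameworks=("tf2", "tf")):
        # Simplified stand-in for ray.rllib.utils.test_utils.framework_iterator:
        # set config["framework"] per requested framework and yield it.
        for fw in frameworks:
            config["framework"] = fw
            yield fw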


@@ -334,8 +334,11 @@ def build_eager_tf_policy(
             if getattr(self, "exploration", None):
                 optimizers = self.exploration.get_exploration_optimizer(
                     optimizers)
-            # TODO: (sven) Allow tf policy to have more than 1 optimizer.
-            #  Just like torch Policy does.
-            self._optimizer = optimizers[0] if optimizers else None
+            # The list of local (tf) optimizers (one per loss term).
+            self._optimizers: List[LocalOptimizer] = optimizers
+            # Backward compatibility: A user's policy may only support a single
+            # loss term and optimizer (no lists).
+            self._optimizer: LocalOptimizer = \
+                optimizers[0] if optimizers else None
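The `optimizers` list stored above typically comes from the policy's `optimizer_fn`. A hedged sketch of one returning two optimizers, one per loss term (`config["lr"]` is standard; the `lr_vf` key is an illustrative assumption, not part of this diff):

    import tensorflow as tf

    def optimizer_fn(policy, config):
        # One optimizer per loss term; matched positionally against the
        # losses returned by loss_fn.
        return [
            tf.keras.optimizers.Adam(config["lr"]),     # policy-loss optimizer
            tf.keras.optimizers.Adam(config["lr_vf"]),  # value-loss optimizer (assumed key)
        ]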
@@ -737,42 +740,78 @@ def build_eager_tf_policy(
         def _get_is_training_placeholder(self):
             return tf.convert_to_tensor(self._is_training)

-        def _apply_gradients(self, grads_and_vars):
-            if apply_gradients_fn:
-                apply_gradients_fn(self, self._optimizer, grads_and_vars)
-            else:
-                self._optimizer.apply_gradients(
-                    [(g, v) for g, v in grads_and_vars if g is not None])
-
         @with_lock
         def _compute_gradients(self, samples):
             """Computes and returns grads as eager tensors."""

-            with tf.GradientTape(persistent=compute_gradients_fn is not None) \
-                    as tape:
-                loss = loss_fn(self, self.model, self.dist_class, samples)
-
+            # Gather all variables for which to calculate losses.
             if isinstance(self.model, tf.keras.Model):
                 variables = self.model.trainable_variables
             else:
                 variables = self.model.trainable_variables()

+            # Calculate the loss(es) inside a tf GradientTape.
+            with tf.GradientTape(persistent=compute_gradients_fn is not None) \
+                    as tape:
+                losses = loss_fn(self, self.model, self.dist_class, samples)
+            losses = force_list(losses)
+
+            # User provided a compute_gradients_fn.
             if compute_gradients_fn:
-                grads_and_vars = compute_gradients_fn(self,
-                                                      OptimizerWrapper(tape),
-                                                      loss)
+                # Wrap our tape inside a wrapper, such that the resulting
+                # object looks like a "classic" tf.optimizer. This way, custom
+                # `compute_gradients_fn` will work on both tf static graph
+                # and tf-eager.
+                optimizer = OptimizerWrapper(tape)
+                # More than one loss term/optimizer.
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    grads_and_vars = compute_gradients_fn(
+                        self, [optimizer] * len(losses), losses)
+                # Only one loss and one optimizer.
+                else:
+                    grads_and_vars = [
+                        compute_gradients_fn(self, optimizer, losses[0])
+                    ]
+            # Default: Compute gradients using the above tape.
             else:
-                grads_and_vars = list(
-                    zip(tape.gradient(loss, variables), variables))
+                grads_and_vars = [
+                    list(zip(tape.gradient(loss, variables), variables))
+                    for loss in losses
+                ]

             if log_once("grad_vars"):
-                for _, v in grads_and_vars:
-                    logger.info("Optimizing variable {}".format(v.name))
+                for g_and_v in grads_and_vars:
+                    for g, v in g_and_v:
+                        if g is not None:
+                            logger.info(f"Optimizing variable {v.name}")

-            grads = [g for g, v in grads_and_vars]
+            # `grads_and_vars` is returned as a list (len=num optimizers/losses)
+            # of lists of (grad, var) tuples.
+            if self.config["_tf_policy_handles_more_than_one_loss"]:
+                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
+            # `grads_and_vars` is returned as a list of (grad, var) tuples.
+            else:
+                grads_and_vars = grads_and_vars[0]
+                grads = [g for g, _ in grads_and_vars]
+
             stats = self._stats(self, samples, grads)
             return grads_and_vars, stats

+        def _apply_gradients(self, grads_and_vars):
+            if apply_gradients_fn:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
+                else:
+                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
+            else:
+                if self.config["_tf_policy_handles_more_than_one_loss"]:
+                    for i, o in enumerate(self._optimizers):
+                        o.apply_gradients([(g, v) for g, v in grads_and_vars[i]
+                                           if g is not None])
+                else:
+                    self._optimizer.apply_gradients(
+                        [(g, v) for g, v in grads_and_vars if g is not None])
+
         def _stats(self, outputs, samples, grads):
             fetches = {}
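Taken together, the new `_compute_gradients` and `_apply_gradients` implement a one-tape, per-loss gradient pattern. A self-contained sketch of that pattern outside RLlib (model, losses, and learning rates are illustrative only):

    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(4, activation="relu"), tf.keras.layers.Dense(1)])
    model.build(input_shape=(None, 8))
    variables = model.trainable_variables

    # One optimizer per loss term, mirroring self._optimizers.
    optimizers = [tf.keras.optimizers.Adam(5e-4), tf.keras.optimizers.Adam(2e-4)]

    x = tf.random.normal((16, 8))
    y = tf.random.normal((16, 1))

    # One persistent tape; gradients are pulled once per loss term.
    with tf.GradientTape(persistent=True) as tape:
        out = model(x)
        losses = [tf.reduce_mean((out - y) ** 2),   # e.g. the "policy" loss
                  tf.reduce_mean(tf.abs(out - y))]  # e.g. the "value" loss

    grads_and_vars = [list(zip(tape.gradient(loss, variables), variables))
                      for loss in losses]

    # Apply each loss term's gradients with its own optimizer.
    for opt, g_and_v in zip(optimizers, grads_and_vars):
        opt.apply_gradients([(g, v) for g, v in g_and_v if g is not None])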