[RLlib] Fix AlphaStar for tf2+tracing; smaller cleanups to avoid wrapping a TFPolicy with as_eager() or with_tracing() more than once. (#24271)

Sven Mika 2022-04-28 13:43:21 +02:00 committed by GitHub
parent 576addf9ca
commit 6551922c21
4 changed files with 58 additions and 28 deletions
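
The core idea of the fix is that eager-class wrapping must be idempotent: a policy class that is already eagerized (and possibly traced) should pass through the wrapping helpers unchanged. A minimal, self-contained sketch of that pattern (toy classes only, not the RLlib implementation):

    class EagerPolicyBase:
        """Marker base class: anything 'already eagerized' inherits from this."""
        pass


    def with_tracing(policy_cls):
        # Already eagerized/traced -> return the class unchanged (no double wrap).
        if issubclass(policy_cls, EagerPolicyBase):
            return policy_cls

        class Traced(policy_cls, EagerPolicyBase):
            pass

        Traced.__name__ = policy_cls.__name__ + "_traced"
        return Traced


    class MyPolicy:
        pass


    Traced1 = with_tracing(MyPolicy)
    Traced2 = with_tracing(Traced1)
    assert Traced1 is Traced2  # wrapping twice is a no-op

The actual changes below implement this via a marker base class (EagerTFPolicy), a no-op with_tracing() on already-traced classes, and guards in get_tf_eager_cls_if_necessary.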

rllib/agents/alpha_star/distributed_learners.py

@@ -6,6 +6,7 @@ from ray.actor import ActorHandle
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.actors import create_colocated_actors
+from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
 from ray.rllib.utils.typing import PolicyID, TrainerConfigDict
@@ -173,6 +174,10 @@ class _Shard:
         assert len(self.policy_actors) < self.max_num_policies

+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, cfg
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (
@@ -181,7 +186,7 @@
                         num_gpus=self.num_gpus_per_policy
                         if not cfg["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, cfg),
                     # Policy c'tor kwargs={}.
@@ -207,6 +212,10 @@ class _Shard:
         assert self.replay_actor is None
         assert len(self.policy_actors) == 0

+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, config
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (self.replay_actor_class, self.replay_actor_args, {}, 1),
@@ -218,7 +227,7 @@
                         num_gpus=self.num_gpus_per_policy
                         if not config["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, config),
                     # Policy c'tor kwargs={}.
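
Both _Shard hunks follow the same ordering: resolve the eager (and, with eager_tracing, traced) class once, then hand the resolved class to ray.remote. A hedged usage sketch of that ordering (MyTFPolicy and the config dict are placeholders, not names from this commit; assumes TensorFlow 2.x and Ray with RLlib are installed):

    import ray
    from ray.rllib.policy.tf_policy_template import build_tf_policy
    from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary

    # Placeholder standing in for whatever policy_spec.policy_class holds.
    MyTFPolicy = build_tf_policy(name="MyTFPolicy", loss_fn=None)

    config = {"framework": "tf2", "eager_tracing": True}
    actual_policy_class = get_tf_eager_cls_if_necessary(MyTFPolicy, config)

    # The resolved class (not the raw MyTFPolicy) is what gets turned into a
    # remote learner actor and colocated with the shard's replay actor.
    remote_policy_cls = ray.remote(num_cpus=1, num_gpus=0)(actual_policy_class)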

rllib/policy/eager_tf_policy.py

@@ -129,6 +129,12 @@ def check_too_many_retraces(obj):
     return _func


+class EagerTFPolicy(Policy):
+    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""
+
+    pass
+
+
 def traced_eager_policy(eager_policy_cls):
     """Wrapper class that enables tracing for all eager policy methods.
@@ -237,6 +243,11 @@ def traced_eager_policy(eager_policy_cls):
             # `apply_gradients()` (which will call the traced helper).
             return super(TracedEagerPolicy, self).apply_gradients(grads)

+        @classmethod
+        def with_tracing(cls):
+            # Already traced -> Return same class.
+            return cls
+
     TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
     TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
     return TracedEagerPolicy
@@ -287,7 +298,7 @@ def build_eager_tf_policy(
     This has the same signature as build_tf_policy()."""

-    base = add_mixins(Policy, mixins)
+    base = add_mixins(EagerTFPolicy, mixins)

     if obs_include_prev_action_reward != DEPRECATED_VALUE:
         deprecation_warning(old="obs_include_prev_action_reward", error=False)
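
Rebasing the generated eager policy on EagerTFPolicy (via add_mixins) is what makes such checks possible: every class produced by build_eager_tf_policy now inherits from the marker. A toy version of the mixin-rebasing idea for illustration (not RLlib's add_mixins implementation):

    class EagerTFPolicyToy:
        """Stand-in for the marker base class."""
        pass


    class LearningRateScheduleToy:
        """Stand-in for an optional mixin."""
        pass


    def add_mixins_toy(base, mixins):
        # Build a new class whose bases are the mixins followed by the base class.
        return type("PolicyWithMixins", tuple(mixins) + (base,), {})


    policy_cls = add_mixins_toy(EagerTFPolicyToy, [LearningRateScheduleToy])
    assert issubclass(policy_cls, EagerTFPolicyToy)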
@@ -309,7 +320,7 @@
             if not tf1.executing_eagerly():
                 tf1.enable_eager_execution()
             self.framework = config.get("framework", "tfe")
-            Policy.__init__(self, observation_space, action_space, config)
+            EagerTFPolicy.__init__(self, observation_space, action_space, config)

             # Global timestep should be a tensor.
             self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
@@ -594,7 +605,7 @@
         ):
             assert tf.executing_eagerly()
             # Call super's postprocess_trajectory first.
-            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
+            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
             if postprocess_fn:
                 return postprocess_fn(self, sample_batch, other_agent_batches, episode)
             return sample_batch
@@ -848,7 +859,11 @@
             return actions, state_out, extra_fetches

-        def _learn_on_batch_helper(self, samples):
+        # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
+        #  AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
+        #  It seems there may be a clash between the traced-by-tf function and the
+        #  traced-by-ray functions (for making the policy class a ray actor).
+        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
             # Increase the tracing counter to make sure we don't re-trace too
             # often. If eager_tracing=True, this counter should only get
             # incremented during the @tf.function trace operations, never when
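
The extra _ray_trace_ctx=None keyword is the workaround described in the TODO above: something in the actor-call path injects such a keyword, and the tf.function-compiled helper must tolerate it. A hedged, self-contained illustration of why accepting and ignoring an injected keyword avoids the clash (toy code, not RLlib; assumes TensorFlow 2.x):

    import tensorflow as tf


    class ToyLearner:
        @tf.function
        def update(self, samples, _ray_trace_ctx=None):
            # The keyword is accepted only so callers that inject it (e.g. some
            # tracing middleware) don't change the traced signature unexpectedly.
            return tf.reduce_sum(samples)


    learner = ToyLearner()
    print(learner.update(tf.constant([1.0, 2.0])))
    print(learner.update(tf.constant([1.0, 2.0]), _ray_trace_ctx=None))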

rllib/policy/policy.py

@@ -737,8 +737,11 @@ class Policy(metaclass=ABCMeta):
         """
         # Store the current global time step (sum over all policies' sample
         # steps).
-        # Make sure, we keep global_timestep as a Tensor.
-        if self.framework in ["tf2", "tfe"]:
+        # Make sure, we keep global_timestep as a Tensor for tf-eager
+        # (leads to memory leaks if not doing so).
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
+
+        if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
             self.global_timestep.assign(global_vars["timestep"])
         else:
             self.global_timestep = global_vars["timestep"]
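
With the isinstance(self, EagerTFPolicy) guard, only genuinely eagerized policies take the tf.Variable assign() path; every other policy keeps global_timestep as a plain integer. A toy sketch of the two update paths (not RLlib code; assumes TensorFlow 2.x):

    import tensorflow as tf


    class EagerishPolicy:
        def __init__(self):
            self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)

        def on_global_var_update(self, global_vars):
            # In-place assign keeps the same Variable around (the diff's comment
            # notes that re-creating it under tf-eager leaked memory).
            self.global_timestep.assign(global_vars["timestep"])


    class GraphOrTorchPolicy:
        def __init__(self):
            self.global_timestep = 0

        def on_global_var_update(self, global_vars):
            self.global_timestep = global_vars["timestep"]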

rllib/utils/tf_utils.py

@@ -233,28 +233,31 @@ def get_tf_eager_cls_if_necessary(
     """
     cls = orig_cls
     framework = config.get("framework", "tf")
-    if framework in ["tf2", "tf", "tfe"]:
-        if not tf1:
-            raise ImportError("Could not import tensorflow!")
-        if framework in ["tf2", "tfe"]:
-            assert tf1.executing_eagerly()
-            from ray.rllib.policy.tf_policy import TFPolicy
+    if framework in ["tf2", "tf", "tfe"] and not tf1:
+        raise ImportError("Could not import tensorflow!")
+
-            # Create eager-class.
-            if hasattr(orig_cls, "as_eager"):
-                cls = orig_cls.as_eager()
-                if config.get("eager_tracing"):
-                    cls = cls.with_tracing()
-            # Could be some other type of policy or already
-            # eager-ized.
-            elif not issubclass(orig_cls, TFPolicy):
-                pass
-            else:
-                raise ValueError(
-                    "This policy does not support eager "
-                    "execution: {}".format(orig_cls)
-                )
+    if framework in ["tf2", "tfe"]:
+        assert tf1.executing_eagerly()
+
+        from ray.rllib.policy.tf_policy import TFPolicy
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
+
+        # Create eager-class (if not already one).
+        if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
+            cls = orig_cls.as_eager()
+        # Could be some other type of policy or already
+        # eager-ized.
+        elif not issubclass(orig_cls, TFPolicy):
+            pass
+        else:
+            raise ValueError(
+                "This policy does not support eager " "execution: {}".format(orig_cls)
+            )
+
+        # Now that we know, policy is an eager one, add tracing, if necessary.
+        if config.get("eager_tracing") and issubclass(cls, EagerTFPolicy):
+            cls = cls.with_tracing()
+
     return cls
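
Taken together with the EagerTFPolicy guard and the no-op with_tracing() above, resolving an already-resolved class is now a no-op. A hedged end-to-end check (MyTFPolicy is a placeholder built only for illustration; assumes TensorFlow 2.x and Ray with RLlib are installed):

    from ray.rllib.policy.tf_policy_template import build_tf_policy
    from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary

    MyTFPolicy = build_tf_policy(name="MyTFPolicy", loss_fn=None)  # placeholder

    config = {"framework": "tf2", "eager_tracing": True}
    eager_traced_cls = get_tf_eager_cls_if_necessary(MyTFPolicy, config)

    # A second resolve of the already eagerized+traced class returns it unchanged.
    assert get_tf_eager_cls_if_necessary(eager_traced_cls, config) is eager_traced_cls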