mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Fix AlphaStar for tf2+tracing; smaller cleanups to avoid wrapping a TFPolicy as_eager() or with_tracing() more than once. (#24271)
This commit is contained in:
parent 576addf9ca
commit 6551922c21

4 changed files with 58 additions and 28 deletions
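The core of the change is to make the eager/tracing wrappers idempotent: a policy class that is already eager-ized or already traced must never be wrapped a second time. A minimal, self-contained sketch of the resolution flow the commit enforces; EagerMarker and resolve_policy_cls are simplified stand-ins for RLlib's EagerTFPolicy and get_tf_eager_cls_if_necessary:

    # Simplified stand-ins; the real logic lives in
    # ray.rllib.utils.tf_utils.get_tf_eager_cls_if_necessary.

    class EagerMarker:
        """Stand-in for the new EagerTFPolicy marker base class."""

        pass

    def resolve_policy_cls(orig_cls, config):
        cls = orig_cls
        # Eager-ize only classes that are not already eager.
        if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerMarker):
            cls = orig_cls.as_eager()
        # Trace only eager classes, and only when requested.
        if config.get("eager_tracing") and issubclass(cls, EagerMarker):
            cls = cls.with_tracing()
        return cls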
rllib/agents/alpha_star/distributed_learners.py

@@ -6,6 +6,7 @@ from ray.actor import ActorHandle
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.actors import create_colocated_actors
+from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
 from ray.rllib.utils.typing import PolicyID, TrainerConfigDict
@@ -173,6 +174,10 @@ class _Shard:
         assert len(self.policy_actors) < self.max_num_policies

+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, cfg
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (
@@ -181,7 +186,7 @@ class _Shard:
                         num_gpus=self.num_gpus_per_policy
                         if not cfg["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, cfg),
                     # Policy c'tor kwargs={}.
@@ -207,6 +212,10 @@ class _Shard:
         assert self.replay_actor is None
         assert len(self.policy_actors) == 0

+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, config
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (self.replay_actor_class, self.replay_actor_args, {}, 1),
@@ -218,7 +227,7 @@ class _Shard:
                         num_gpus=self.num_gpus_per_policy
                         if not config["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, config),
                     # Policy c'tor kwargs={}.
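Both _Shard code paths above now resolve the policy class exactly once, up front, and reuse actual_policy_class for every actor spec instead of eager-wrapping policy_spec.policy_class at each use site. A self-contained sketch of why resolving once matters; fake_get_eager_cls is a hypothetical stand-in for get_tf_eager_cls_if_necessary:

    def fake_get_eager_cls(cls, config):
        # Stand-in: wrappers like as_eager()/with_tracing() can generate a
        # brand-new class object on every call.
        if config.get("framework") == "tf2":
            return type(cls.__name__ + "_eager", (cls,), {})
        return cls

    class MyTFPolicy:
        pass

    cfg = {"framework": "tf2"}
    # Naive repeated wrapping yields a different class each time:
    assert fake_get_eager_cls(MyTFPolicy, cfg) is not fake_get_eager_cls(MyTFPolicy, cfg)

    # Hence the diff resolves once and reuses the result for all actor specs.
    actual_policy_class = fake_get_eager_cls(MyTFPolicy, cfg)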
rllib/policy/eager_tf_policy.py

@@ -129,6 +129,12 @@ def check_too_many_retraces(obj):
     return _func


+class EagerTFPolicy(Policy):
+    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""
+
+    pass
+
+
 def traced_eager_policy(eager_policy_cls):
     """Wrapper class that enables tracing for all eager policy methods.
@@ -237,6 +243,11 @@ def traced_eager_policy(eager_policy_cls):
             # `apply_gradients()` (which will call the traced helper).
             return super(TracedEagerPolicy, self).apply_gradients(grads)

+        @classmethod
+        def with_tracing(cls):
+            # Already traced -> Return same class.
+            return cls
+
     TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
     TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
     return TracedEagerPolicy
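Overriding with_tracing() on the generated TracedEagerPolicy makes re-tracing a no-op: the already-traced class simply returns itself. A hedged, self-contained demo of the pattern; FakeEager is hypothetical, not an RLlib class:

    class FakeEager:
        @classmethod
        def with_tracing(cls):
            # First call: wrap into a traced subclass.
            class Traced(cls):
                @classmethod
                def with_tracing(cls):
                    # Already traced -> return the same class.
                    return cls

            Traced.__name__ = cls.__name__ + "_traced"
            return Traced

    TracedCls = FakeEager.with_tracing()
    assert TracedCls.with_tracing() is TracedCls  # re-tracing is a no-op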
@@ -287,7 +298,7 @@ def build_eager_tf_policy(

    This has the same signature as build_tf_policy()."""

-    base = add_mixins(Policy, mixins)
+    base = add_mixins(EagerTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=False)
@@ -309,7 +320,7 @@
             if not tf1.executing_eagerly():
                 tf1.enable_eager_execution()
             self.framework = config.get("framework", "tfe")
-            Policy.__init__(self, observation_space, action_space, config)
+            EagerTFPolicy.__init__(self, observation_space, action_space, config)

             # Global timestep should be a tensor.
             self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
@@ -594,7 +605,7 @@
         ):
             assert tf.executing_eagerly()
             # Call super's postprocess_trajectory first.
-            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
+            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
             if postprocess_fn:
                 return postprocess_fn(self, sample_batch, other_agent_batches, episode)
             return sample_batch
@@ -848,7 +859,11 @@

             return actions, state_out, extra_fetches

-        def _learn_on_batch_helper(self, samples):
+        # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
+        # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
+        # It seems there may be a clash between the traced-by-tf function and the
+        # traced-by-ray functions (for making the policy class a ray actor).
+        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
             # Increase the tracing counter to make sure we don't re-trace too
             # often. If eager_tracing=True, this counter should only get
             # incremented during the @tf.function trace operations, never when
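The extra _ray_trace_ctx=None keyword is a defensive workaround: when Ray's tracing instrumentation is active, actor method calls can receive a _ray_trace_ctx keyword argument, and a tf-traced method whose signature does not accept it would fail. A hedged sketch of the signature pattern only (not RLlib's actual learner code); requires TensorFlow 2:

    import tensorflow as tf

    class LearnerStub:
        # Accepting an explicit `_ray_trace_ctx` keyword lets the tf-traced
        # method tolerate the extra argument Ray's tracing machinery may
        # inject into actor method calls (the working theory in the TODO above).
        @tf.function
        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
            return tf.reduce_sum(samples)

    stub = LearnerStub()
    print(stub._learn_on_batch_helper(tf.constant([1.0, 2.0])))  # tf.Tensor(3.0, ...)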
rllib/policy/policy.py

@@ -737,8 +737,11 @@ class Policy(metaclass=ABCMeta):
         """
         # Store the current global time step (sum over all policies' sample
         # steps).
-        # Make sure, we keep global_timestep as a Tensor.
-        if self.framework in ["tf2", "tfe"]:
+        # Make sure, we keep global_timestep as a Tensor for tf-eager
+        # (leads to memory leaks if not doing so).
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
+
+        if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
             self.global_timestep.assign(global_vars["timestep"])
         else:
             self.global_timestep = global_vars["timestep"]
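The added isinstance check matters because only eager TF policies hold global_timestep as a tf.Variable; those must be updated in place with assign() (the comment above notes that rebinding leaks memory under tf-eager), while all other policies just rebind a plain integer. A minimal sketch of the in-place update, assuming TensorFlow 2:

    import tensorflow as tf

    # Same shape as the eager policy's counter: a non-trainable int64 Variable.
    global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)

    # In-place update: the Variable object is reused, which plays well with
    # @tf.function tracing (no retracing, no stale tensor references).
    global_timestep.assign(1000)
    assert int(global_timestep.numpy()) == 1000

    # The non-eager branch instead rebinds a plain int:
    # self.global_timestep = global_vars["timestep"]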
rllib/utils/tf_utils.py

@@ -233,28 +233,31 @@ def get_tf_eager_cls_if_necessary(
     """
     cls = orig_cls
     framework = config.get("framework", "tf")
-    if framework in ["tf2", "tf", "tfe"]:
-        if not tf1:
-            raise ImportError("Could not import tensorflow!")
+
+    if framework in ["tf2", "tf", "tfe"] and not tf1:
+        raise ImportError("Could not import tensorflow!")

     if framework in ["tf2", "tfe"]:
         assert tf1.executing_eagerly()

         from ray.rllib.policy.tf_policy import TFPolicy
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy

-        # Create eager-class.
-        if hasattr(orig_cls, "as_eager"):
+        # Create eager-class (if not already one).
+        if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
             cls = orig_cls.as_eager()
-            if config.get("eager_tracing"):
-                cls = cls.with_tracing()
         # Could be some other type of policy or already
         # eager-ized.
         elif not issubclass(orig_cls, TFPolicy):
             pass
         else:
             raise ValueError(
-                "This policy does not support eager "
-                "execution: {}".format(orig_cls)
+                "This policy does not support eager " "execution: {}".format(orig_cls)
             )

+        # Now that we know, policy is an eager one, add tracing, if necessary.
+        if config.get("eager_tracing") and issubclass(cls, EagerTFPolicy):
+            cls = cls.with_tracing()
+
     return cls
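With these guards, get_tf_eager_cls_if_necessary() is now safe to call repeatedly on the same class: once a class is an EagerTFPolicy subclass, a second call neither re-eagerizes it (the hasattr branch is skipped) nor re-wraps tracing (TracedEagerPolicy.with_tracing() returns the class itself). A hedged usage sketch; the trivial build_tf_policy() policy with a zero loss is illustrative only, and tf2 eager mode is assumed active:

    from ray.rllib.policy.tf_policy_template import build_tf_policy
    from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary

    # A throwaway TFPolicy class, built only to exercise the resolver.
    SomeTFPolicy = build_tf_policy(
        name="SomeTFPolicy",
        loss_fn=lambda policy, model, dist_class, train_batch: 0.0,
    )

    config = {"framework": "tf2", "eager_tracing": True}

    cls1 = get_tf_eager_cls_if_necessary(SomeTFPolicy, config)
    cls2 = get_tf_eager_cls_if_necessary(cls1, config)
    assert cls2 is cls1  # already eager-ized and traced -> returned unchanged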