diff --git a/rllib/agents/alpha_star/distributed_learners.py b/rllib/agents/alpha_star/distributed_learners.py
index 217586d96..d30559a2a 100644
--- a/rllib/agents/alpha_star/distributed_learners.py
+++ b/rllib/agents/alpha_star/distributed_learners.py
@@ -6,6 +6,7 @@ from ray.actor import ActorHandle
 from ray.rllib.agents.trainer import Trainer
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.actors import create_colocated_actors
+from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
 from ray.rllib.utils.typing import PolicyID, TrainerConfigDict
@@ -173,6 +174,10 @@ class _Shard:
 
         assert len(self.policy_actors) < self.max_num_policies
 
+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, cfg
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (
@@ -181,7 +186,7 @@ class _Shard:
                         num_gpus=self.num_gpus_per_policy
                         if not cfg["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, cfg),
                     # Policy c'tor kwargs={}.
@@ -207,6 +212,10 @@ class _Shard:
         assert self.replay_actor is None
         assert len(self.policy_actors) == 0
 
+        actual_policy_class = get_tf_eager_cls_if_necessary(
+            policy_spec.policy_class, config
+        )
+
         colocated = create_colocated_actors(
             actor_specs=[
                 (self.replay_actor_class, self.replay_actor_args, {}, 1),
@@ -218,7 +227,7 @@ class _Shard:
                         num_gpus=self.num_gpus_per_policy
                         if not config["_fake_gpus"]
                         else 0,
-                    )(policy_spec.policy_class),
+                    )(actual_policy_class),
                     # Policy c'tor args.
                     (policy_spec.observation_space, policy_spec.action_space, config),
                     # Policy c'tor kwargs={}.
diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py
index 1beba5714..c64f937af 100644
--- a/rllib/policy/eager_tf_policy.py
+++ b/rllib/policy/eager_tf_policy.py
@@ -129,6 +129,12 @@ def check_too_many_retraces(obj):
     return _func
 
 
+class EagerTFPolicy(Policy):
+    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""
+
+    pass
+
+
 def traced_eager_policy(eager_policy_cls):
     """Wrapper class that enables tracing for all eager policy methods.
 
@@ -237,6 +243,11 @@ def traced_eager_policy(eager_policy_cls):
             # `apply_gradients()` (which will call the traced helper).
             return super(TracedEagerPolicy, self).apply_gradients(grads)
 
+        @classmethod
+        def with_tracing(cls):
+            # Already traced -> Return same class.
+            return cls
+
     TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
     TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
     return TracedEagerPolicy
@@ -287,7 +298,7 @@ def build_eager_tf_policy(
 
     This has the same signature as build_tf_policy()."""
 
-    base = add_mixins(Policy, mixins)
+    base = add_mixins(EagerTFPolicy, mixins)
 
     if obs_include_prev_action_reward != DEPRECATED_VALUE:
         deprecation_warning(old="obs_include_prev_action_reward", error=False)
@@ -309,7 +320,7 @@ def build_eager_tf_policy(
             if not tf1.executing_eagerly():
                 tf1.enable_eager_execution()
             self.framework = config.get("framework", "tfe")
-            Policy.__init__(self, observation_space, action_space, config)
+            EagerTFPolicy.__init__(self, observation_space, action_space, config)
 
             # Global timestep should be a tensor.
             self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
@@ -594,7 +605,7 @@ def build_eager_tf_policy(
         ):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
-            sample_batch = Policy.postprocess_trajectory(self, sample_batch)
+            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
             if postprocess_fn:
                 return postprocess_fn(self, sample_batch, other_agent_batches, episode)
             return sample_batch
@@ -848,7 +859,11 @@ def build_eager_tf_policy(
 
             return actions, state_out, extra_fetches
 
-        def _learn_on_batch_helper(self, samples):
+        # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
+        # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
+        # It seems there may be a clash between the traced-by-tf function and the
+        # traced-by-ray functions (for making the policy class a ray actor).
+        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
             # Increase the tracing counter to make sure we don't re-trace too
             # often. If eager_tracing=True, this counter should only get
             # incremented during the @tf.function trace operations, never when
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index b0c6385a9..d5b6d3317 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -737,8 +737,11 @@ class Policy(metaclass=ABCMeta):
         """
         # Store the current global time step (sum over all policies' sample
         # steps).
-        # Make sure, we keep global_timestep as a Tensor.
-        if self.framework in ["tf2", "tfe"]:
+        # Make sure, we keep global_timestep as a Tensor for tf-eager
+        # (leads to memory leaks if not doing so).
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
+
+        if self.framework in ["tf2", "tfe"] and isinstance(self, EagerTFPolicy):
             self.global_timestep.assign(global_vars["timestep"])
         else:
             self.global_timestep = global_vars["timestep"]
diff --git a/rllib/utils/tf_utils.py b/rllib/utils/tf_utils.py
index c644f3dc2..0950cd2a0 100644
--- a/rllib/utils/tf_utils.py
+++ b/rllib/utils/tf_utils.py
@@ -233,28 +233,31 @@ def get_tf_eager_cls_if_necessary(
     """
     cls = orig_cls
     framework = config.get("framework", "tf")
-    if framework in ["tf2", "tf", "tfe"]:
-        if not tf1:
-            raise ImportError("Could not import tensorflow!")
-        if framework in ["tf2", "tfe"]:
-            assert tf1.executing_eagerly()
 
-            from ray.rllib.policy.tf_policy import TFPolicy
+    if framework in ["tf2", "tf", "tfe"] and not tf1:
+        raise ImportError("Could not import tensorflow!")
 
-            # Create eager-class.
-            if hasattr(orig_cls, "as_eager"):
-                cls = orig_cls.as_eager()
-                if config.get("eager_tracing"):
-                    cls = cls.with_tracing()
-            # Could be some other type of policy or already
-            # eager-ized.
-            elif not issubclass(orig_cls, TFPolicy):
-                pass
-            else:
-                raise ValueError(
-                    "This policy does not support eager "
-                    "execution: {}".format(orig_cls)
-                )
+    if framework in ["tf2", "tfe"]:
+        assert tf1.executing_eagerly()
+
+        from ray.rllib.policy.tf_policy import TFPolicy
+        from ray.rllib.policy.eager_tf_policy import EagerTFPolicy
+
+        # Create eager-class (if not already one).
+        if hasattr(orig_cls, "as_eager") and not issubclass(orig_cls, EagerTFPolicy):
+            cls = orig_cls.as_eager()
+        # Could be some other type of policy or already
+        # eager-ized.
+        elif not issubclass(orig_cls, TFPolicy):
+            pass
+        else:
+            raise ValueError(
+                "This policy does not support eager " "execution: {}".format(orig_cls)
+            )
+
+        # Now that we know, policy is an eager one, add tracing, if necessary.
+        if config.get("eager_tracing") and issubclass(cls, EagerTFPolicy):
+            cls = cls.with_tracing()
 
     return cls
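
Note on the intent of the tf_utils.py change above: the restructured get_tf_eager_cls_if_necessary is meant to be idempotent, i.e. a class that already derives from the new EagerTFPolicy marker (or is already traced, thanks to the added with_tracing() override returning cls) is not eagerized or traced a second time. The sketch below mimics that decision flow with mock classes so the double-wrapping guard can be exercised without TensorFlow or Ray installed; MockTFPolicy, MockEagerTFPolicy, and mock_resolve are illustrative stand-ins invented here, not RLlib APIs.

# Stand-alone sketch of the double-wrapping guard introduced by this patch.
# MockTFPolicy / MockEagerTFPolicy / mock_resolve are hypothetical stand-ins.


class MockEagerTFPolicy:
    """Marker base class, analogous to EagerTFPolicy in eager_tf_policy.py."""


class MockTFPolicy:
    """Graph-mode policy that can be eagerized, analogous to a TFPolicy."""

    @classmethod
    def as_eager(cls):
        # Build an eagerized variant that inherits from the marker base.
        class Eagerized(MockEagerTFPolicy):
            @classmethod
            def with_tracing(cls):
                class Traced(cls):
                    @classmethod
                    def with_tracing(cls):
                        # Already traced -> return the same class (as in the patch).
                        return cls

                return Traced

        return Eagerized


def mock_resolve(orig_cls, config):
    """Mimics the control flow of the new get_tf_eager_cls_if_necessary()."""
    cls = orig_cls
    if config.get("framework") in ["tf2", "tfe"]:
        # Only eagerize classes that are not already eagerized.
        if hasattr(orig_cls, "as_eager") and not issubclass(
            orig_cls, MockEagerTFPolicy
        ):
            cls = orig_cls.as_eager()
        # Only add tracing to eager classes; tracing twice is a no-op.
        if config.get("eager_tracing") and issubclass(cls, MockEagerTFPolicy):
            cls = cls.with_tracing()
    return cls


if __name__ == "__main__":
    config = {"framework": "tf2", "eager_tracing": True}
    once = mock_resolve(MockTFPolicy, config)
    twice = mock_resolve(once, config)
    # Resolving an already eagerized-and-traced class leaves it unchanged,
    # which is what lets AlphaStar call the helper again inside _Shard.
    assert twice is once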