[RLlib] Solve PyTorch/TF-eager A3C async race condition between calling model and its value function. (#13467)
This commit is contained in:
parent 516eb77080
commit 1f00f834ac
7 changed files with 63 additions and 9 deletions
@@ -53,11 +53,9 @@ def get_policy_class(config):

 def validate_config(config):
     if config["entropy_coeff"] < 0:
-        raise DeprecationWarning("`entropy_coeff` must be >= 0")
-    if config["sample_async"] and config["framework"] == "torch":
-        config["sample_async"] = False
-        logger.warning("`sample_async=True` is not supported for PyTorch! "
-                       "Multithreading can lead to crashes.")
+        raise ValueError("`entropy_coeff` must be >= 0.0!")
+    if config["num_workers"] <= 0 and config["sample_async"]:
+        raise ValueError("`num_workers` for A3C must be >= 1!")


 def execution_plan(workers, config):
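For reference, the patched check above now rejects both of the following cases. These are illustrative configs (only the keys that validate_config reads are shown), not complete trainer configs:

    # Illustrative only: configs run through the patched validate_config() above.
    config = {"entropy_coeff": -0.01, "sample_async": True, "num_workers": 2}
    # validate_config(config)  -> ValueError: `entropy_coeff` must be >= 0.0!

    config = {"entropy_coeff": 0.01, "sample_async": True, "num_workers": 0}
    # validate_config(config)  -> ValueError: `num_workers` for A3C must be >= 1!
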
@@ -25,7 +25,6 @@ class TestA2C(unittest.TestCase):

         # Test against all frameworks.
         for fw in framework_iterator(config):
-            config["sample_async"] = fw in ["tf", "tfe", "tf2"]
             for env in ["PongDeterministic-v0"]:
                 trainer = a3c.A2CTrainer(config=config, env=env)
                 for i in range(num_iterations):
@@ -24,8 +24,7 @@ class TestA3C(unittest.TestCase):
         num_iterations = 1

         # Test against all frameworks.
-        for fw in framework_iterator(config):
-            config["sample_async"] = fw == "tf"
+        for _ in framework_iterator(config):
             for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
                 print("env={}".format(env))
                 trainer = a3c.A3CTrainer(config=config, env=env)
@@ -4,6 +4,7 @@ It supports both traced and non-traced eager execution modes."""

 import functools
 import logging
+import threading

 from ray.util.debug import log_once
 from ray.rllib.models.catalog import ModelCatalog
@@ -15,6 +16,7 @@ from ray.rllib.utils import add_mixins, force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import convert_to_non_tf_type
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.tracking_dict import UsageTrackingDict

 tf1, tf, tfv = try_import_tf()
@@ -255,6 +257,13 @@ def build_eager_tf_policy(name,
                 config["model"],
                 framework=self.framework,
             )
+            # Lock used for locking some methods on the object-level.
+            # This prevents possible race conditions when calling the model
+            # first, then its value function (e.g. in a loss function), in
+            # between of which another model call is made (e.g. to compute an
+            # action).
+            self._lock = threading.RLock()
+
             # Auto-update model's inference view requirements, if recurrent.
             self._update_model_view_requirements_from_init_state()

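The comment block added above describes the hazard being locked against. The following is a minimal, self-contained sketch of that hazard in plain Python (ToyModel and ToyPolicy are hypothetical stand-ins, not RLlib classes): the forward pass and the subsequent value_function() call must not be interleaved with another model call.

    import threading


    class ToyModel:
        # Stand-in for a ModelV2: value_function() depends on state stored
        # by the most recent forward pass.
        def __init__(self):
            self._last_obs = None

        def forward(self, obs):
            self._last_obs = obs
            return obs * 2.0

        def value_function(self):
            return self._last_obs * 0.5


    class ToyPolicy:
        def __init__(self):
            self.model = ToyModel()
            # Same idea as the RLock added in this commit.
            self._lock = threading.RLock()

        def compute_actions(self, obs):
            # E.g. called from an async sampler thread.
            with self._lock:
                return self.model.forward(obs)

        def loss(self, obs):
            # E.g. called from the learner thread.
            with self._lock:
                out = self.model.forward(obs)
                # Without the lock, a concurrent compute_actions() could run
                # between these two calls and overwrite the stored obs.
                vf = self.model.value_function()
                assert vf == obs * 0.5, "value function saw a different obs!"
                return out - vf


    policy = ToyPolicy()
    sampler = threading.Thread(
        target=lambda: [policy.compute_actions(1.0) for _ in range(100000)])
    learner = threading.Thread(
        target=lambda: [policy.loss(2.0) for _ in range(100000)])
    sampler.start()
    learner.start()
    sampler.join()
    learner.join()
    print("no race observed")

Removing the two `with self._lock:` blocks makes the assertion fire intermittently. An RLock (rather than a plain Lock) lets one locked method call another locked method on the same object without deadlocking.
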
@@ -305,6 +314,7 @@ def build_eager_tf_policy(name,
                                                    episode)
             return sample_batch

+        @with_lock
         @override(Policy)
         def learn_on_batch(self, postprocessed_batch):
             # Callback handling.
@@ -351,6 +361,7 @@ def build_eager_tf_policy(name,
             grads = [g for g, v in grads_and_vars]
             return grads, stats

+        @with_lock
         @override(Policy)
         @convert_eager_inputs
         @convert_eager_outputs
@@ -448,6 +459,7 @@ def build_eager_tf_policy(name,

             return actions, state_out, extra_fetches

+        @with_lock
         @override(Policy)
         def compute_log_likelihoods(self,
                                     actions,
@@ -593,6 +605,7 @@ def build_eager_tf_policy(name,
             self._optimizer.apply_gradients(
                 [(g, v) for g, v in grads_and_vars if g is not None])

+        @with_lock
         def _compute_gradients(self, samples):
             """Computes and returns grads as eager tensors."""

@@ -3,6 +3,7 @@ import gym
 import logging
 import numpy as np
 import time
+import threading
 from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 from ray.rllib.models.modelv2 import ModelV2
@@ -15,6 +16,7 @@ from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
     convert_to_torch_tensor
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
@@ -110,6 +112,14 @@ class TorchPolicy(Policy):
             logger.info("TorchPolicy running on CPU.")
             self.device = torch.device("cpu")
         self.model = model.to(self.device)
+
+        # Lock used for locking some methods on the object-level.
+        # This prevents possible race conditions when calling the model
+        # first, then its value function (e.g. in a loss function), in
+        # between of which another model call is made (e.g. to compute an
+        # action).
+        self._lock = threading.RLock()
+
         self._state_inputs = self.model.get_initial_state()
         self._is_recurrent = len(self._state_inputs) > 0
         # Auto-update model's inference view requirements, if recurrent.
@@ -197,6 +207,7 @@ class TorchPolicy(Policy):
             return self._compute_action_helper(input_dict, state_batches,
                                                seq_lens, explore, timestep)

+    @with_lock
     def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                                explore, timestep):
         """Shared forward pass logic (w/ and w/o trajectory view API).
@@ -206,6 +217,7 @@ class TorchPolicy(Policy):
             - actions, state_out, extra_fetches, logp.
         """
+        self._is_recurrent = state_batches is not None and state_batches != []

         # Switch to eval mode.
         if self.model:
             self.model.eval()
@@ -274,6 +286,7 @@ class TorchPolicy(Policy):

         return convert_to_non_torch_type((actions, state_out, extra_fetches))

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def compute_log_likelihoods(
@@ -325,12 +338,15 @@ class TorchPolicy(Policy):

         action_dist = dist_class(dist_inputs, self.model)
         log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

         return log_likelihoods

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def learn_on_batch(
             self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
+
+        # Set Model to train mode.
         if self.model:
             self.model.train()
@@ -348,8 +364,10 @@ class TorchPolicy(Policy):

         if self.model:
             fetches["model"] = self.model.metrics()
+
         return fetches

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def compute_gradients(self,
@@ -92,7 +92,7 @@ class TestExplorations(unittest.TestCase):

         do_test_explorations(
             a3c.A2CTrainer,
             "CartPole-v0",
-            a3c.DEFAULT_CONFIG,
+            a3c.a2c.A2C_DEFAULT_CONFIG,
             np.array([0.0, 0.1, 0.0, 0.0]),
             prev_a=np.array(1))
rllib/utils/threading.py (new file, 27 lines)

@@ -0,0 +1,27 @@
+from typing import Callable
+
+
+def with_lock(func: Callable):
+    """Use as decorator (@withlock) around object methods that need locking.
+
+    Note: The object must have a self._lock = threading.Lock() property.
+    Locking thus works on the object level (no two locked methods of the same
+    object can be called asynchronously).
+
+    Args:
+        func (Callable): The function to decorate/wrap.
+
+    Returns:
+        Callable: The wrapped (object-level locked) function.
+    """
+
+    def wrapper(self, *a, **k):
+        try:
+            with self._lock:
+                return func(self, *a, **k)
+        except AttributeError:
+            raise AttributeError(
+                "Object {} must have a `self._lock` property (assigned to a "
+                "threading.Lock() object in its constructor)!".format(self))
+
+    return wrapper
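Usage follows directly from the docstring above. A small sketch with an illustrative (non-RLlib) Counter class, assuming the new rllib/utils/threading.py module shown above is importable:

    import threading

    from ray.rllib.utils.threading import with_lock


    class Counter:
        def __init__(self):
            # Required by with_lock; use threading.RLock() instead if locked
            # methods may call each other on the same object.
            self._lock = threading.Lock()
            self.value = 0

        @with_lock
        def increment(self):
            # At most one thread can be inside any @with_lock method of this
            # object at a time.
            self.value += 1


    counter = Counter()
    threads = [threading.Thread(target=counter.increment) for _ in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(counter.value)  # 8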