[RLlib] Solve PyTorch/TF-eager A3C async race condition between calling model and its value function. (#13467)

Sven Mika 2021-01-18 19:29:03 +01:00 committed by GitHub
parent 516eb77080
commit 1f00f834ac
7 changed files with 63 additions and 9 deletions

rllib/agents/a3c/a3c.py

@@ -53,11 +53,9 @@ def get_policy_class(config):
def validate_config(config):
if config["entropy_coeff"] < 0:
raise DeprecationWarning("`entropy_coeff` must be >= 0")
if config["sample_async"] and config["framework"] == "torch":
config["sample_async"] = False
logger.warning("`sample_async=True` is not supported for PyTorch! "
"Multithreading can lead to crashes.")
raise ValueError("`entropy_coeff` must be >= 0.0!")
if config["num_workers"] <= 0 and config["sample_async"]:
raise ValueError("`num_workers` for A3C must be >= 1!")
def execution_plan(workers, config):
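
Illustrative note (not part of the diff): the torch-specific branch that silently forced `sample_async=False` is gone because the policies patched below are now thread-safe; instead, validate_config fails fast if async sampling is requested without any rollout workers. A minimal sketch of the new behavior, assuming the standard ray.rllib.agents.a3c entry point:

    from ray.rllib.agents import a3c

    config = a3c.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0      # no rollout workers ...
    config["sample_async"] = True  # ... but async sampling requested
    # Building the trainer now raises
    # ValueError("`num_workers` for A3C must be >= 1!") instead of silently
    # flipping `sample_async` off for the torch framework.
    # trainer = a3c.A3CTrainer(config=config, env="CartPole-v0")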

rllib/agents/a3c/tests/test_a2c.py

@@ -25,7 +25,6 @@ class TestA2C(unittest.TestCase):
# Test against all frameworks.
for fw in framework_iterator(config):
config["sample_async"] = fw in ["tf", "tfe", "tf2"]
for env in ["PongDeterministic-v0"]:
trainer = a3c.A2CTrainer(config=config, env=env)
for i in range(num_iterations):

rllib/agents/a3c/tests/test_a3c.py

@@ -24,8 +24,7 @@ class TestA3C(unittest.TestCase):
num_iterations = 1
# Test against all frameworks.
for fw in framework_iterator(config):
config["sample_async"] = fw == "tf"
for _ in framework_iterator(config):
for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
print("env={}".format(env))
trainer = a3c.A3CTrainer(config=config, env=env)

rllib/policy/eager_tf_policy.py

@@ -4,6 +4,7 @@ It supports both traced and non-traced eager execution modes."""
import functools
import logging
import threading
from ray.util.debug import log_once
from ray.rllib.models.catalog import ModelCatalog
@@ -15,6 +16,7 @@ from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_ops import convert_to_non_tf_type
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.tracking_dict import UsageTrackingDict
tf1, tf, tfv = try_import_tf()
@@ -255,6 +257,13 @@ def build_eager_tf_policy(name,
config["model"],
framework=self.framework,
)
# Lock used for locking some methods on the object-level.
# This prevents possible race conditions when calling the model
# first, then its value function (e.g. in a loss function), in
# between of which another model call is made (e.g. to compute an
# action).
self._lock = threading.RLock()
# Auto-update model's inference view requirements, if recurrent.
self._update_model_view_requirements_from_init_state()
@@ -305,6 +314,7 @@ def build_eager_tf_policy(name,
episode)
return sample_batch
@with_lock
@override(Policy)
def learn_on_batch(self, postprocessed_batch):
# Callback handling.
@@ -351,6 +361,7 @@ def build_eager_tf_policy(name,
grads = [g for g, v in grads_and_vars]
return grads, stats
@with_lock
@override(Policy)
@convert_eager_inputs
@convert_eager_outputs
@@ -448,6 +459,7 @@ def build_eager_tf_policy(name,
return actions, state_out, extra_fetches
@with_lock
@override(Policy)
def compute_log_likelihoods(self,
actions,
@@ -593,6 +605,7 @@ def build_eager_tf_policy(name,
self._optimizer.apply_gradients(
[(g, v) for g, v in grads_and_vars if g is not None])
@with_lock
def _compute_gradients(self, samples):
"""Computes and returns grads as eager tensors."""

rllib/policy/torch_policy.py

@@ -3,6 +3,7 @@ import gym
import logging
import numpy as np
import time
import threading
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from ray.rllib.models.modelv2 import ModelV2
@@ -15,6 +16,7 @@ from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
convert_to_torch_tensor
from ray.rllib.utils.tracking_dict import UsageTrackingDict
@@ -110,6 +112,14 @@ class TorchPolicy(Policy):
logger.info("TorchPolicy running on CPU.")
self.device = torch.device("cpu")
self.model = model.to(self.device)
# Lock used for locking some methods on the object-level.
# This prevents possible race conditions when calling the model
# first, then its value function (e.g. in a loss function), in
# between of which another model call is made (e.g. to compute an
# action).
self._lock = threading.RLock()
self._state_inputs = self.model.get_initial_state()
self._is_recurrent = len(self._state_inputs) > 0
# Auto-update model's inference view requirements, if recurrent.
@@ -197,6 +207,7 @@ class TorchPolicy(Policy):
return self._compute_action_helper(input_dict, state_batches,
seq_lens, explore, timestep)
@with_lock
def _compute_action_helper(self, input_dict, state_batches, seq_lens,
explore, timestep):
"""Shared forward pass logic (w/ and w/o trajectory view API).
@@ -206,6 +217,7 @@ class TorchPolicy(Policy):
- actions, state_out, extra_fetches, logp.
"""
self._is_recurrent = state_batches is not None and state_batches != []
# Switch to eval mode.
if self.model:
self.model.eval()
@@ -274,6 +286,7 @@ class TorchPolicy(Policy):
return convert_to_non_torch_type((actions, state_out, extra_fetches))
@with_lock
@override(Policy)
@DeveloperAPI
def compute_log_likelihoods(
@@ -325,12 +338,15 @@ class TorchPolicy(Policy):
action_dist = dist_class(dist_inputs, self.model)
log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
return log_likelihoods
@with_lock
@override(Policy)
@DeveloperAPI
def learn_on_batch(
self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
# Set Model to train mode.
if self.model:
self.model.train()
@@ -348,8 +364,10 @@ class TorchPolicy(Policy):
if self.model:
fetches["model"] = self.model.metrics()
return fetches
@with_lock
@override(Policy)
@DeveloperAPI
def compute_gradients(self,
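
To make the race described in the comments above concrete (illustrative sketch with a hypothetical TinyModel, not RLlib code): the async sampler thread and the learner share one model whose value_function() depends on the most recent forward() call, so without the object-level lock an action's forward pass can slip in between the loss's forward pass and its value read:

    import threading


    class TinyModel:
        """Mimics the ModelV2 pattern: forward() caches features that
        value_function() reads afterwards."""

        def __init__(self):
            self._last_features = None

        def forward(self, obs):
            self._last_features = obs  # cached for the value branch
            return [o * 2.0 for o in obs]

        def value_function(self):
            # Only valid for the batch of the most recent forward() call.
            return [f * 0.5 for f in self._last_features]


    model = TinyModel()
    lock = threading.RLock()  # the object-level lock added in this commit


    def loss(train_obs):
        with lock:  # forward() and value_function() now happen atomically
            out = model.forward(train_obs)
            values = model.value_function()
        return out, values


    def compute_action(single_obs):
        with lock:  # the sampler thread can no longer interleave its call
            return model.forward(single_obs)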

rllib/utils/exploration/tests/test_explorations.py

@@ -92,7 +92,7 @@ class TestExplorations(unittest.TestCase):
do_test_explorations(
a3c.A2CTrainer,
"CartPole-v0",
a3c.DEFAULT_CONFIG,
a3c.a2c.A2C_DEFAULT_CONFIG,
np.array([0.0, 0.1, 0.0, 0.0]),
prev_a=np.array(1))

rllib/utils/threading.py (new file, 27 lines)

@@ -0,0 +1,27 @@
from typing import Callable


def with_lock(func: Callable):
    """Use as decorator (@with_lock) around object methods that need locking.

    Note: The object must have a `self._lock = threading.Lock()` property.
    Locking thus works on the object level (no two locked methods of the same
    object can be called asynchronously).

    Args:
        func (Callable): The function to decorate/wrap.

    Returns:
        Callable: The wrapped (object-level locked) function.
    """

    def wrapper(self, *a, **k):
        try:
            with self._lock:
                return func(self, *a, **k)
        except AttributeError:
            raise AttributeError(
                "Object {} must have a `self._lock` property (assigned to a "
                "threading.Lock() object in its constructor)!".format(self))

    return wrapper
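
A small usage sketch (illustrative only): any object that assigns `self._lock` in its constructor can serialize its methods with the decorator; if the attribute is missing, the wrapper raises the AttributeError shown above.

    import threading

    from ray.rllib.utils.threading import with_lock


    class Counter:
        def __init__(self):
            self._lock = threading.RLock()  # required by @with_lock
            self.value = 0

        @with_lock
        def increment(self):
            # The read-modify-write below is atomic per Counter instance.
            current = self.value
            self.value = current + 1


    counter = Counter()
    threads = [
        threading.Thread(
            target=lambda: [counter.increment() for _ in range(1000)])
        for _ in range(4)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(counter.value)  # 4000 with the lock; possibly less without it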