[RLlib] Solve PyTorch/TF-eager A3C async race condition between calling model and its value function. (#13467)

Sven Mika 2021-01-18 19:29:03 +01:00 committed by GitHub
parent 516eb77080
commit 1f00f834ac
7 changed files with 63 additions and 9 deletions


@@ -53,11 +53,9 @@ def get_policy_class(config):

 def validate_config(config):
     if config["entropy_coeff"] < 0:
-        raise DeprecationWarning("`entropy_coeff` must be >= 0")
-    if config["sample_async"] and config["framework"] == "torch":
-        config["sample_async"] = False
-        logger.warning("`sample_async=True` is not supported for PyTorch! "
-                       "Multithreading can lead to crashes.")
+        raise ValueError("`entropy_coeff` must be >= 0.0!")
+    if config["num_workers"] <= 0 and config["sample_async"]:
+        raise ValueError("`num_workers` for A3C must be >= 1!")


 def execution_plan(workers, config):
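
For context, the change above replaces the old PyTorch-only workaround (silently forcing `sample_async=False`) with a hard config error. A minimal sketch of how this surfaces to a user, assuming this commit is applied and that A3C's default config still sets `sample_async: True` (the env name and `local_mode` flag are only for illustration):

import ray
from ray.rllib.agents import a3c

ray.init(local_mode=True)

config = a3c.DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # A3C now requires at least one async sampling worker.

try:
    trainer = a3c.A3CTrainer(config=config, env="CartPole-v0")
except ValueError as e:
    print(e)  # "`num_workers` for A3C must be >= 1!"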


@@ -25,7 +25,6 @@ class TestA2C(unittest.TestCase):

         # Test against all frameworks.
         for fw in framework_iterator(config):
-            config["sample_async"] = fw in ["tf", "tfe", "tf2"]
             for env in ["PongDeterministic-v0"]:
                 trainer = a3c.A2CTrainer(config=config, env=env)
                 for i in range(num_iterations):


@@ -24,8 +24,7 @@ class TestA3C(unittest.TestCase):
         num_iterations = 1

         # Test against all frameworks.
-        for fw in framework_iterator(config):
-            config["sample_async"] = fw == "tf"
+        for _ in framework_iterator(config):
             for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
                 print("env={}".format(env))
                 trainer = a3c.A3CTrainer(config=config, env=env)


@@ -4,6 +4,7 @@ It supports both traced and non-traced eager execution modes."""

 import functools
 import logging
+import threading

 from ray.util.debug import log_once
 from ray.rllib.models.catalog import ModelCatalog
@@ -15,6 +16,7 @@ from ray.rllib.utils import add_mixins, force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import convert_to_non_tf_type
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.tracking_dict import UsageTrackingDict

 tf1, tf, tfv = try_import_tf()
@@ -255,6 +257,13 @@ def build_eager_tf_policy(name,
                 config["model"],
                 framework=self.framework,
             )
+            # Lock used for locking some methods on the object-level.
+            # This prevents possible race conditions when calling the model
+            # first, then its value function (e.g. in a loss function), in
+            # between of which another model call is made (e.g. to compute an
+            # action).
+            self._lock = threading.RLock()
+
             # Auto-update model's inference view requirements, if recurrent.
             self._update_model_view_requirements_from_init_state()
@@ -305,6 +314,7 @@ def build_eager_tf_policy(name,
                 episode)
             return sample_batch

+        @with_lock
         @override(Policy)
         def learn_on_batch(self, postprocessed_batch):
             # Callback handling.
@@ -351,6 +361,7 @@ def build_eager_tf_policy(name,
             grads = [g for g, v in grads_and_vars]
             return grads, stats

+        @with_lock
         @override(Policy)
         @convert_eager_inputs
         @convert_eager_outputs
@@ -448,6 +459,7 @@ def build_eager_tf_policy(name,
             return actions, state_out, extra_fetches

+        @with_lock
         @override(Policy)
         def compute_log_likelihoods(self,
                                     actions,
@@ -593,6 +605,7 @@ def build_eager_tf_policy(name,
             self._optimizer.apply_gradients(
                 [(g, v) for g, v in grads_and_vars if g is not None])

+        @with_lock
         def _compute_gradients(self, samples):
             """Computes and returns grads as eager tensors."""


@@ -3,6 +3,7 @@ import gym
 import logging
 import numpy as np
 import time
+import threading
 from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 from ray.rllib.models.modelv2 import ModelV2
@@ -15,6 +16,7 @@ from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
     convert_to_torch_tensor
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
@@ -110,6 +112,14 @@ class TorchPolicy(Policy):
             logger.info("TorchPolicy running on CPU.")
             self.device = torch.device("cpu")
         self.model = model.to(self.device)
+        # Lock used for locking some methods on the object-level.
+        # This prevents possible race conditions when calling the model
+        # first, then its value function (e.g. in a loss function), in
+        # between of which another model call is made (e.g. to compute an
+        # action).
+        self._lock = threading.RLock()
+
         self._state_inputs = self.model.get_initial_state()
         self._is_recurrent = len(self._state_inputs) > 0
         # Auto-update model's inference view requirements, if recurrent.
@@ -197,6 +207,7 @@ class TorchPolicy(Policy):
         return self._compute_action_helper(input_dict, state_batches,
                                            seq_lens, explore, timestep)

+    @with_lock
     def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                                explore, timestep):
         """Shared forward pass logic (w/ and w/o trajectory view API).
@@ -206,6 +217,7 @@ class TorchPolicy(Policy):
            - actions, state_out, extra_fetches, logp.
         """
+        self._is_recurrent = state_batches is not None and state_batches != []

         # Switch to eval mode.
         if self.model:
             self.model.eval()
@@ -274,6 +286,7 @@ class TorchPolicy(Policy):

         return convert_to_non_torch_type((actions, state_out, extra_fetches))

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def compute_log_likelihoods(
@@ -325,12 +338,15 @@ class TorchPolicy(Policy):
             action_dist = dist_class(dist_inputs, self.model)

         log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

         return log_likelihoods

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def learn_on_batch(
             self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
         # Set Model to train mode.
         if self.model:
             self.model.train()
@@ -348,8 +364,10 @@ class TorchPolicy(Policy):

         if self.model:
             fetches["model"] = self.model.metrics()
+
         return fetches

+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def compute_gradients(self,
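
One detail worth noting from both policy classes: the lock is a threading.RLock rather than a plain Lock, because a @with_lock method can call into another @with_lock method on the same object (in the eager policy above, learn_on_batch ultimately reaches _compute_gradients, and both are decorated). A tiny hypothetical sketch of that reentrancy; MiniPolicy and its methods are made up:

import threading

from ray.rllib.utils.threading import with_lock


class MiniPolicy:
    """Hypothetical class showing why a reentrant lock is needed."""

    def __init__(self):
        # An RLock lets a thread that already holds the lock acquire it
        # again; a plain threading.Lock would deadlock below.
        self._lock = threading.RLock()

    @with_lock
    def learn_on_batch(self, batch):
        # Re-enters the object-level lock already held by this call.
        return self.compute_gradients(batch)

    @with_lock
    def compute_gradients(self, batch):
        return [x * 0.1 for x in batch]


print(MiniPolicy().learn_on_batch([1.0, 2.0]))  # -> [0.1, 0.2]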


@@ -92,7 +92,7 @@ class TestExplorations(unittest.TestCase):

         do_test_explorations(
             a3c.A2CTrainer,
             "CartPole-v0",
-            a3c.DEFAULT_CONFIG,
+            a3c.a2c.A2C_DEFAULT_CONFIG,
             np.array([0.0, 0.1, 0.0, 0.0]),
             prev_a=np.array(1))

rllib/utils/threading.py (new file, 27 lines)

@@ -0,0 +1,27 @@
from typing import Callable


def with_lock(func: Callable):
    """Use as decorator (@with_lock) around object methods that need locking.

    Note: The object must have a `self._lock = threading.Lock()` property.
    Locking thus works on the object level (no two locked methods of the same
    object can be called asynchronously).

    Args:
        func (Callable): The function to decorate/wrap.

    Returns:
        Callable: The wrapped (object-level locked) function.
    """

    def wrapper(self, *a, **k):
        try:
            with self._lock:
                return func(self, *a, **k)
        except AttributeError:
            raise AttributeError(
                "Object {} must have a `self._lock` property (assigned to a "
                "threading.Lock() object in its constructor)!".format(self))

    return wrapper
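
A small, hypothetical usage example of the new helper (the Counter class is made up): any object that assigns self._lock in its constructor can serialize selected methods with @with_lock, and calling a decorated method on an object without a _lock attribute raises the AttributeError constructed above.

import threading

from ray.rllib.utils.threading import with_lock


class Counter:
    def __init__(self):
        self._lock = threading.RLock()
        self.value = 0

    @with_lock
    def increment(self, n=1):
        # Only one thread at a time can be inside any @with_lock method
        # of this particular object.
        self.value += n


c = Counter()
threads = [threading.Thread(target=c.increment) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert c.value == 8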