[RLlib] Solve PyTorch/TF-eager A3C async race condition between calling model and its value function. (#13467)
This commit is contained in:
parent 516eb77080
commit 1f00f834ac

7 changed files with 63 additions and 9 deletions
@@ -53,11 +53,9 @@ def get_policy_class(config):
 
 def validate_config(config):
     if config["entropy_coeff"] < 0:
-        raise DeprecationWarning("`entropy_coeff` must be >= 0")
-    if config["sample_async"] and config["framework"] == "torch":
-        config["sample_async"] = False
-        logger.warning("`sample_async=True` is not supported for PyTorch! "
-                       "Multithreading can lead to crashes.")
+        raise ValueError("`entropy_coeff` must be >= 0.0!")
+    if config["num_workers"] <= 0 and config["sample_async"]:
+        raise ValueError("`num_workers` for A3C must be >= 1!")
 
 
 def execution_plan(workers, config):
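
The behavioral change here: an invalid A3C setup is no longer silently patched (the old code force-disabled `sample_async` under PyTorch and only logged a warning); it now fails fast with a `ValueError`. A minimal sketch of the new behavior, using the trainer and config objects that appear in the tests below (the exact import path is an assumption):

    import ray.rllib.agents.a3c as a3c

    config = a3c.DEFAULT_CONFIG.copy()
    config["num_workers"] = 0       # no async rollout workers ...
    config["sample_async"] = True   # ... but async sampling requested

    # validate_config() now raises during trainer construction:
    #   ValueError: `num_workers` for A3C must be >= 1!
    trainer = a3c.A3CTrainer(config=config, env="CartPole-v0")
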
@@ -25,7 +25,6 @@ class TestA2C(unittest.TestCase):
 
         # Test against all frameworks.
         for fw in framework_iterator(config):
-            config["sample_async"] = fw in ["tf", "tfe", "tf2"]
             for env in ["PongDeterministic-v0"]:
                 trainer = a3c.A2CTrainer(config=config, env=env)
                 for i in range(num_iterations):
@@ -24,8 +24,7 @@ class TestA3C(unittest.TestCase):
         num_iterations = 1
 
         # Test against all frameworks.
-        for fw in framework_iterator(config):
-            config["sample_async"] = fw == "tf"
+        for _ in framework_iterator(config):
             for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
                 print("env={}".format(env))
                 trainer = a3c.A3CTrainer(config=config, env=env)
@@ -4,6 +4,7 @@ It supports both traced and non-traced eager execution modes."""
 
 import functools
 import logging
+import threading
 
 from ray.util.debug import log_once
 from ray.rllib.models.catalog import ModelCatalog
@@ -15,6 +16,7 @@ from ray.rllib.utils import add_mixins, force_list
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.tf_ops import convert_to_non_tf_type
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
 
 tf1, tf, tfv = try_import_tf()
@@ -255,6 +257,13 @@ def build_eager_tf_policy(name,
                 config["model"],
                 framework=self.framework,
             )
+            # Lock used for locking some methods on the object-level.
+            # This prevents possible race conditions when calling the model
+            # first, then its value function (e.g. in a loss function), in
+            # between of which another model call is made (e.g. to compute an
+            # action).
+            self._lock = threading.RLock()
 
             # Auto-update model's inference view requirements, if recurrent.
             self._update_model_view_requirements_from_init_state()
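
To make the new comment concrete: RLlib models cache state from their last forward pass, and `value_function()` reads that cache. With `sample_async=True`, a sampler thread could run a forward pass between the loss function's forward pass and its `value_function()` call, silently pairing the loss with the wrong value output. A self-contained toy sketch of the interleaving being prevented (illustrative only, not RLlib code):

    import threading

    class ToyModel:
        """Caches its last forward pass, as RLlib models do."""

        def __init__(self):
            self._last_obs = None

        def forward(self, obs):
            self._last_obs = obs       # state that the race can corrupt
            return obs * 2

        def value_function(self):
            return self._last_obs + 1  # reads the *last* forward pass

    model = ToyModel()
    lock = threading.RLock()           # what self._lock provides

    def loss_thread():
        with lock:                     # what @with_lock wraps around methods
            out = model.forward(10)
            v = model.value_function() # guaranteed to pair with forward(10)
            assert v == 11

    def sampler_thread():
        with lock:
            model.forward(99)          # can no longer sneak in between

    threads = [threading.Thread(target=loss_thread),
               threading.Thread(target=sampler_thread)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
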
@@ -305,6 +314,7 @@ def build_eager_tf_policy(name,
                 episode)
             return sample_batch
 
+        @with_lock
         @override(Policy)
         def learn_on_batch(self, postprocessed_batch):
             # Callback handling.
@@ -351,6 +361,7 @@ def build_eager_tf_policy(name,
             grads = [g for g, v in grads_and_vars]
             return grads, stats
 
+        @with_lock
         @override(Policy)
         @convert_eager_inputs
         @convert_eager_outputs
@@ -448,6 +459,7 @@ def build_eager_tf_policy(name,
 
             return actions, state_out, extra_fetches
 
+        @with_lock
         @override(Policy)
         def compute_log_likelihoods(self,
                                     actions,
@@ -593,6 +605,7 @@ def build_eager_tf_policy(name,
             self._optimizer.apply_gradients(
                 [(g, v) for g, v in grads_and_vars if g is not None])
 
+        @with_lock
         def _compute_gradients(self, samples):
            """Computes and returns grads as eager tensors."""
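
Note the placement: `@with_lock` is listed above `@override(Policy)` and the eager conversion decorators, which makes it the outermost wrapper, so the lock is already held while inputs are converted and is only released after outputs are converted back. A small sketch of why ordering matters (the `preprocess` decorator is a stand-in, not RLlib code):

    import threading

    from ray.rllib.utils.threading import with_lock

    def preprocess(func):              # stand-in for @convert_eager_inputs
        def wrapper(self, *args, **kwargs):
            args = tuple(a for a in args)  # imagine tensor conversion here
            return func(self, *args, **kwargs)
        return wrapper

    class MyPolicy:
        def __init__(self):
            self._lock = threading.RLock()

        @with_lock        # outermost: acquired before preprocess runs
        @preprocess
        def learn_on_batch(self, batch):
            return len(batch)

    assert MyPolicy().learn_on_batch([1, 2, 3]) == 3
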
@@ -3,6 +3,7 @@ import gym
 import logging
 import numpy as np
 import time
+import threading
 from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 from ray.rllib.models.modelv2 import ModelV2
@@ -15,6 +16,7 @@ from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override, DeveloperAPI
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
+from ray.rllib.utils.threading import with_lock
 from ray.rllib.utils.torch_ops import convert_to_non_torch_type, \
     convert_to_torch_tensor
 from ray.rllib.utils.tracking_dict import UsageTrackingDict
@ -110,6 +112,14 @@ class TorchPolicy(Policy):
|
||||||
logger.info("TorchPolicy running on CPU.")
|
logger.info("TorchPolicy running on CPU.")
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
self.model = model.to(self.device)
|
self.model = model.to(self.device)
|
||||||
|
|
||||||
|
# Lock used for locking some methods on the object-level.
|
||||||
|
# This prevents possible race conditions when calling the model
|
||||||
|
# first, then its value function (e.g. in a loss function), in
|
||||||
|
# between of which another model call is made (e.g. to compute an
|
||||||
|
# action).
|
||||||
|
self._lock = threading.RLock()
|
||||||
|
|
||||||
self._state_inputs = self.model.get_initial_state()
|
self._state_inputs = self.model.get_initial_state()
|
||||||
self._is_recurrent = len(self._state_inputs) > 0
|
self._is_recurrent = len(self._state_inputs) > 0
|
||||||
# Auto-update model's inference view requirements, if recurrent.
|
# Auto-update model's inference view requirements, if recurrent.
|
||||||
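
Both policies use `threading.RLock()` rather than a plain `threading.Lock()`. That choice matters because the locked methods live on the same object and one may call another (in the eager policy above, `learn_on_batch` and `_compute_gradients` are both locked): a re-entrant lock lets the owning thread re-acquire it instead of deadlocking on itself. A minimal sketch of the difference (method names chosen for illustration):

    import threading

    class LockedPolicy:
        def __init__(self):
            self._lock = threading.RLock()  # re-entrant for the owning thread

        def learn_on_batch(self, batch):
            with self._lock:
                # Nested acquire by the same thread: fine with RLock,
                # a deadlock with a plain threading.Lock().
                return self._compute_gradients(batch)

        def _compute_gradients(self, batch):
            with self._lock:
                return [2 * x for x in batch]

    print(LockedPolicy().learn_on_batch([1, 2, 3]))  # [2, 4, 6]
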
@@ -197,6 +207,7 @@ class TorchPolicy(Policy):
             return self._compute_action_helper(input_dict, state_batches,
                                                seq_lens, explore, timestep)
 
+    @with_lock
     def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                                explore, timestep):
         """Shared forward pass logic (w/ and w/o trajectory view API).
@@ -206,6 +217,7 @@ class TorchPolicy(Policy):
         - actions, state_out, extra_fetches, logp.
         """
         self._is_recurrent = state_batches is not None and state_batches != []
 
         # Switch to eval mode.
         if self.model:
             self.model.eval()
@@ -274,6 +286,7 @@ class TorchPolicy(Policy):
 
         return convert_to_non_torch_type((actions, state_out, extra_fetches))
 
+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def compute_log_likelihoods(
@@ -325,12 +338,15 @@ class TorchPolicy(Policy):
 
         action_dist = dist_class(dist_inputs, self.model)
         log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])
 
         return log_likelihoods
 
+    @with_lock
     @override(Policy)
     @DeveloperAPI
     def learn_on_batch(
             self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
 
         # Set Model to train mode.
         if self.model:
             self.model.train()
@ -348,8 +364,10 @@ class TorchPolicy(Policy):
|
||||||
|
|
||||||
if self.model:
|
if self.model:
|
||||||
fetches["model"] = self.model.metrics()
|
fetches["model"] = self.model.metrics()
|
||||||
|
|
||||||
return fetches
|
return fetches
|
||||||
|
|
||||||
|
@with_lock
|
||||||
@override(Policy)
|
@override(Policy)
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def compute_gradients(self,
|
def compute_gradients(self,
|
||||||
|
|
|
@ -92,7 +92,7 @@ class TestExplorations(unittest.TestCase):
|
||||||
do_test_explorations(
|
do_test_explorations(
|
||||||
a3c.A2CTrainer,
|
a3c.A2CTrainer,
|
||||||
"CartPole-v0",
|
"CartPole-v0",
|
||||||
a3c.DEFAULT_CONFIG,
|
a3c.a2c.A2C_DEFAULT_CONFIG,
|
||||||
np.array([0.0, 0.1, 0.0, 0.0]),
|
np.array([0.0, 0.1, 0.0, 0.0]),
|
||||||
prev_a=np.array(1))
|
prev_a=np.array(1))
|
||||||
|
|
||||||
|
|
rllib/utils/threading.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+from typing import Callable
+
+
+def with_lock(func: Callable):
+    """Use as decorator (@with_lock) around object methods that need locking.
+
+    Note: The object must have a self._lock = threading.Lock() property.
+    Locking thus works on the object level (no two locked methods of the same
+    object can be called asynchronously).
+
+    Args:
+        func (Callable): The function to decorate/wrap.
+
+    Returns:
+        Callable: The wrapped (object-level locked) function.
+    """
+
+    def wrapper(self, *a, **k):
+        try:
+            with self._lock:
+                return func(self, *a, **k)
+        except AttributeError:
+            raise AttributeError(
+                "Object {} must have a `self._lock` property (assigned to a "
+                "threading.Lock() object in its constructor)!".format(self))
+
+    return wrapper
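
Usage mirrors the pattern adopted in the two policy classes above: assign `self._lock` in the constructor, then decorate any method that must not interleave with other locked methods on the same object. A minimal standalone example (assuming the module is importable as added in this commit):

    import threading

    from ray.rllib.utils.threading import with_lock

    class Counter:
        def __init__(self):
            self._lock = threading.RLock()
            self.n = 0

        @with_lock
        def incr(self):
            self.n += 1  # read-modify-write, unsafe without the lock

    c = Counter()
    threads = [threading.Thread(target=c.incr) for _ in range(100)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert c.n == 100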