[RLlib] Issues: 17397, 17425, 16715, 17174. When on driver, Torch|TFPolicy should not use ray.get_gpu_ids() (b/c no GPUs assigned by ray). (#17444)

Sven Mika 2021-08-02 17:29:59 -04:00 committed by GitHub
parent af880378da
commit 8a844ff840
11 changed files with 265 additions and 57 deletions

View file

@ -86,9 +86,30 @@ In an example below, we train A2C by specifying 8 workers through the config fla
Specifying Resources
~~~~~~~~~~~~~~~~~~~~
You can control the degree of parallelism used by setting the ``num_workers``
hyperparameter for most algorithms. The Trainer will create that many
"remote worker" instances (`see RolloutWorker class <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__),
which run as ``ray.remote`` actors, plus exactly one "local worker", a ``RolloutWorker`` object that is not a
ray actor, but lives directly inside the Trainer.
For most algorithms, learning updates are performed on the local worker and sample collection from
one or more environments is performed by the remote workers (in parallel).
For example, setting ``num_workers=0`` will only create the local worker, in which case both
sample collection and training will be done by the local worker.
On the other hand, setting ``num_workers=5`` will create the local worker (responsible for training updates)
and 5 remote workers (responsible for sample collection).
Since learning is usually done on the local worker, it may help to provide one or more
GPUs to that worker via the ``num_gpus`` setting.
Similarly, the resource allocation to remote workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``.
The number of GPUs can be a fractional quantity (e.g. 0.5), allocating only a fraction
of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting
``num_gpus: 0.2``. Also check out `this fractional GPU example <https://github.com/ray-project/ray/blob/master/rllib/examples/fractional_gpus.py>`__,
which demonstrates how environments (running on the remote workers) that
require a GPU can benefit from the ``num_gpus_per_worker`` setting.
For synchronous algorithms like PPO and A2C, the driver and workers can make use of
the same GPU. To do this with ``n`` GPUs:
.. code-block:: python
@ -99,6 +120,11 @@ For synchronous algorithms like PPO and A2C, the driver and workers can make use
.. Original image: https://docs.google.com/drawings/d/14QINFvx3grVyJyjAnjggOCEVN-Iq6pYVJ3jA2S6j8z0/edit?usp=sharing
.. image:: rllib-config.svg
If you specify ``num_gpus`` and your machine does not have the required number of GPUs
available, a ``RuntimeError`` will be thrown by the respective worker. On the other hand,
if you set ``num_gpus=0``, your policies will be built solely on the CPU, even if
GPUs are available on the machine.
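
As a quick, hedged illustration (the numbers below are examples only, not tuned
recommendations), the resource settings described above could be combined like this:

.. code-block:: python

    from ray import tune

    tune.run(
        "PPO",
        config={
            "env": "CartPole-v0",
            # 4 remote workers collect samples in parallel; the local worker
            # (living inside the Trainer) performs the learning updates.
            "num_workers": 4,
            # Give the local (learning) worker half a GPU.
            "num_gpus": 0.5,
            # Remote (sampling) workers run on CPU only.
            "num_gpus_per_worker": 0,
        },
        stop={"training_iteration": 1},
    )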
Scaling Guide
~~~~~~~~~~~~~

View file

@ -1554,6 +1554,13 @@ py_test(
srcs = ["tests/test_filters.py"]
)
py_test(
name = "tests/test_gpus",
tags = ["tests_dir", "tests_dir_G"],
size = "large",
srcs = ["tests/test_gpus.py"]
)
#py_test(
# name = "tests/test_ignore_worker_failure",
# tags = ["tests_dir", "tests_dir_I"],

View file

@ -17,6 +17,7 @@ class TestApexDQN(unittest.TestCase):
def test_apex_zero_workers(self):
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["timesteps_per_iteration"] = 100
@ -31,6 +32,7 @@ class TestApexDQN(unittest.TestCase):
"""Test whether an APEX-DQNTrainer can be built on all frameworks."""
config = apex.APEX_DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_gpus"] = 0
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["timesteps_per_iteration"] = 100

View file

@ -23,6 +23,7 @@ class TestIMPALA(unittest.TestCase):
def test_impala_compilation(self):
"""Test whether an ImpalaTrainer can be built with both frameworks."""
config = impala.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["model"]["lstm_use_prev_action"] = True
config["model"]["lstm_use_prev_reward"] = True
num_iterations = 1
@ -49,6 +50,7 @@ class TestIMPALA(unittest.TestCase):
def test_impala_lr_schedule(self):
config = impala.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
# Test whether we correctly ignore the "lr" setting.
# The first lr should be 0.0005.
config["lr"] = 0.1

View file

@ -427,9 +427,9 @@ class TestSAC(unittest.TestCase):
check(
tf_var,
np.transpose(torch_var.detach().cpu()),
atol=0.002)
atol=0.003)
else:
check(tf_var, torch_var, atol=0.002)
check(tf_var, torch_var, atol=0.003)
# And alpha.
check(policy.model.log_alpha,
tf_weights["default_policy/log_alpha"])
@ -444,9 +444,9 @@ class TestSAC(unittest.TestCase):
check(
tf_var,
np.transpose(torch_var.detach().cpu()),
atol=0.002)
atol=0.003)
else:
check(tf_var, torch_var, atol=0.002)
check(tf_var, torch_var, atol=0.003)
trainer.stop()
def _get_batch_helper(self, obs_size, actions, batch_size):

View file

@ -397,7 +397,7 @@ class RolloutWorker(ParallelIteratorWorker):
# Create an env for this worker.
if not (worker_index == 0 and num_workers > 0
and policy_config["create_env_on_driver"] is False):
and not policy_config.get("create_env_on_driver")):
# Run the `env_creator` function passing the EnvContext.
self.env = env_creator(env_context)
if self.env is not None:
@ -498,6 +498,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.env,
spaces=self.spaces,
policy_config=policy_config)
# List of IDs of those policies that should be trained.
# By default, these are all policies found in the policy_dict.
self.policies_to_train: List[PolicyID] = policies_to_train or list(
@ -547,6 +548,36 @@ class RolloutWorker(ParallelIteratorWorker):
elif tf1 and policy_config.get("framework") == "tfe":
tf1.set_random_seed(seed)
# Check available number of GPUs.
num_gpus = policy_config.get("num_gpus", 0) if \
self.worker_index == 0 else \
policy_config.get("num_gpus_per_worker", 0)
# Error if we don't find enough GPUs.
if ray.is_initialized() and \
ray.worker._mode() != ray.worker.LOCAL_MODE and \
not policy_config.get("_fake_gpus"):
if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
if len(get_tf_gpu_devices()) < num_gpus:
raise RuntimeError(
f"Not enough GPUs found for num_gpus={num_gpus}! "
f"Found only these IDs: {get_tf_gpu_devices()}.")
elif policy_config.get("framework") == "torch":
if torch.cuda.device_count() < num_gpus:
raise RuntimeError(
f"Not enough GPUs found ({torch.cuda.device_count()}) "
f"for num_gpus={num_gpus}!")
# Warn if running in local-mode and actual (not fake) GPUs are
# requested.
elif ray.is_initialized() and \
ray.worker._mode() == ray.worker.LOCAL_MODE and \
num_gpus > 0 and not policy_config.get("_fake_gpus"):
logger.warning(
"You are running ray with `local_mode=True`, but have "
f"configured {num_gpus} GPUs to be used! In local mode, "
f"Policies are placed on the CPU and the `num_gpus` setting "
f"is ignored.")
self._build_policy_map(
policy_dict,
policy_config,
@ -561,24 +592,6 @@ class RolloutWorker(ParallelIteratorWorker):
if not pol._model_init_state_automatically_added:
pol._update_model_view_requirements_from_init_state()
if (ray.is_initialized()
and ray.worker._mode() != ray.worker.LOCAL_MODE):
# Check available number of GPUs
if not ray.get_gpu_ids():
logger.debug("Creating policy evaluation worker {}".format(
worker_index) +
" on CPU (please ignore any CUDA init errors)")
elif (policy_config["framework"] in ["tf2", "tf", "tfe"] and
not get_tf_gpu_devices()) or \
(policy_config["framework"] == "torch" and
not torch.cuda.is_available()):
raise RuntimeError(
"GPUs were assigned to this worker by Ray, but "
"your DL framework ({}) reports GPU acceleration is "
"disabled. This could be due to a bad CUDA- or {} "
"installation.".format(policy_config["framework"],
policy_config["framework"]))
self.multiagent: bool = set(
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent and self.env is not None:
@ -1425,7 +1438,7 @@ def _determine_spaces_for_multi_agent_dict(
obs_space = spaces[pid][0]
elif env_obs_space is not None:
obs_space = env_obs_space
elif policy_config and "observation_space" in policy_config:
elif policy_config and policy_config.get("observation_space"):
obs_space = policy_config["observation_space"]
else:
raise ValueError(
@ -1441,7 +1454,7 @@ def _determine_spaces_for_multi_agent_dict(
act_space = spaces[pid][1]
elif env_act_space is not None:
act_space = env_act_space
elif policy_config and "action_space" in policy_config:
elif policy_config and policy_config.get("action_space"):
act_space = policy_config["action_space"]
else:
raise ValueError(
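
The GPU-availability check added to ``RolloutWorker`` above can be summarized by the
following simplified, torch-only sketch (not the actual RLlib code): the local worker
(``worker_index == 0``) validates ``num_gpus``, remote workers validate
``num_gpus_per_worker``, and local mode or ``_fake_gpus`` only warn instead of raising.

import logging

import torch

logger = logging.getLogger(__name__)


def check_gpus(num_gpus: float, local_mode: bool, fake_gpus: bool) -> None:
    # Simplified sketch of the availability check (torch-only, not RLlib code).
    available = torch.cuda.device_count()
    if local_mode:
        if num_gpus > 0 and not fake_gpus:
            logger.warning(
                "local_mode=True: `num_gpus=%s` is ignored; policies run on "
                "the CPU.", num_gpus)
        return
    if not fake_gpus and available < num_gpus:
        raise RuntimeError(
            f"Not enough GPUs found ({available}) for num_gpus={num_gpus}!")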

View file

@ -1,6 +1,7 @@
import errno
import gym
import logging
import math
import numpy as np
import os
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
@ -148,21 +149,46 @@ class TFPolicy(Policy):
# Get devices to build the graph on.
worker_idx = self.config.get("worker_index", 0)
num_gpus = config["num_gpus"] if worker_idx == 0 \
else config["num_gpus_per_worker"]
if not config["_fake_gpus"] and \
ray.worker._mode() == ray.worker.LOCAL_MODE:
num_gpus = 0
elif worker_idx == 0:
num_gpus = config["num_gpus"]
else:
num_gpus = config["num_gpus_per_worker"]
gpu_ids = get_gpu_devices()
# No GPU configured, fake GPUs, or none available.
if config["_fake_gpus"] or num_gpus == 0 or not get_gpu_devices():
# Place on one or more CPU(s) when either:
# - Fake GPU mode.
# - num_gpus=0 (either set by user or we are in local_mode=True).
# - no GPUs available.
if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
logger.info("TFPolicy (worker={}) running on {}.".format(
worker_idx
if worker_idx > 0 else "local", f"{num_gpus} fake-GPUs"
if config["_fake_gpus"] else "CPU"))
self.devices = ["/cpu:0" for _ in range(num_gpus or 1)]
# One or more actual GPUs (no fake GPUs).
self.devices = [
"/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)
]
# Place on one or more actual GPU(s), when:
# - num_gpus > 0 (set by user) AND
# - local_mode=False AND
# - actual GPUs available AND
# - non-fake GPU mode.
else:
logger.info("TFPolicy (worker={}) running on {} GPU(s).".format(
worker_idx if worker_idx > 0 else "local", num_gpus))
gpu_ids = ray.get_gpu_ids()
# We are a remote worker (WORKER_MODE=1):
# GPUs should be assigned to us by ray.
if ray.worker._mode() == ray.worker.WORKER_MODE:
gpu_ids = ray.get_gpu_ids()
if len(gpu_ids) < num_gpus:
raise ValueError(
"TFPolicy was not able to find enough GPU IDs! Found "
f"{gpu_ids}, but num_gpus={num_gpus}.")
self.devices = [
f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus
]
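
A brief note on the newly added ``math.ceil`` call above, shown as a tiny sketch
(not RLlib code): ``num_gpus`` may be fractional, so the number of CPU/fake devices
to create is rounded up, with a minimum of one.

import math


def num_device_slots(num_gpus: float) -> int:
    # Fractional GPU requests (e.g. 0.2 or 0.5) still need one device slot;
    # num_gpus=0 also falls back to a single CPU device.
    return int(math.ceil(num_gpus)) or 1


assert num_device_slots(0) == 1
assert num_device_slots(0.5) == 1
assert num_device_slots(2) == 2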

View file

@ -2,6 +2,7 @@ import copy
import functools
import gym
import logging
import math
import numpy as np
import os
import time
@ -115,11 +116,6 @@ class TorchPolicy(Policy):
self.framework = "torch"
super().__init__(observation_space, action_space, config)
# Log device and worker index.
from ray.rllib.evaluation.rollout_worker import get_global_worker
worker = get_global_worker()
worker_idx = worker.worker_index if worker else 0
# Create multi-GPU model towers, if necessary.
# - The central main model will be stored under self.model, residing
# on self.device.
@ -133,33 +129,60 @@ class TorchPolicy(Policy):
# parallelization will be done.
# TODO: (sven) implement data pre-loading and n loader buffers for
# torch.
if config["_fake_gpus"] or config["num_gpus"] == 0 or \
not torch.cuda.is_available():
# Get devices to build the graph on.
worker_idx = self.config.get("worker_index", 0)
if not config["_fake_gpus"] and \
ray.worker._mode() == ray.worker.LOCAL_MODE:
num_gpus = 0
elif worker_idx == 0:
num_gpus = config["num_gpus"]
else:
num_gpus = config["num_gpus_per_worker"]
gpu_ids = list(range(torch.cuda.device_count()))
# Place on one or more CPU(s) when either:
# - Fake GPU mode.
# - num_gpus=0 (either set by user or we are in local_mode=True).
# - no GPUs available.
if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
logger.info("TorchPolicy (worker={}) running on {}.".format(
worker_idx if worker_idx > 0 else "local",
"{} fake-GPUs".format(config["num_gpus"])
worker_idx
if worker_idx > 0 else "local", "{} fake-GPUs".format(num_gpus)
if config["_fake_gpus"] else "CPU"))
self.device = torch.device("cpu")
self.devices = [
self.device for _ in range(config["num_gpus"] or 1)
self.device for _ in range(int(math.ceil(num_gpus)) or 1)
]
self.model_gpu_towers = [
model if i == 0 else copy.deepcopy(model)
for i in range(config["num_gpus"] or 1)
for i in range(int(math.ceil(num_gpus)) or 1)
]
self.model = model
# Place on one or more actual GPU(s), when:
# - num_gpus > 0 (set by user) AND
# - local_mode=False AND
# - actual GPUs available AND
# - non-fake GPU mode.
else:
logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
worker_idx if worker_idx > 0 else "local", config["num_gpus"]))
gpu_ids = ray.get_gpu_ids()
worker_idx if worker_idx > 0 else "local", num_gpus))
# We are a remote worker (WORKER_MODE=1):
# GPUs should be assigned to us by ray.
if ray.worker._mode() == ray.worker.WORKER_MODE:
gpu_ids = ray.get_gpu_ids()
if len(gpu_ids) < num_gpus:
raise ValueError(
"TorchPolicy was not able to find enough GPU IDs! Found "
f"{gpu_ids}, but num_gpus={num_gpus}.")
self.devices = [
torch.device("cuda:{}".format(i))
for i, id_ in enumerate(gpu_ids) if i < config["num_gpus"]
for i, id_ in enumerate(gpu_ids) if i < num_gpus
]
self.device = self.devices[0]
ids = [
id_ for i, id_ in enumerate(gpu_ids) if i < config["num_gpus"]
]
ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
self.model_gpu_towers = []
for i, _ in enumerate(ids):
model_copy = copy.deepcopy(model)
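
The device-selection rule this commit introduces for both policy classes can be
summarized by the following hedged sketch (torch flavor, not the actual
``TorchPolicy`` code): the local worker/driver queries the framework directly,
because Ray does not assign any GPU IDs to the driver, while remote workers keep
relying on ``ray.get_gpu_ids()``.

import math

import ray
import torch


def select_torch_devices(num_gpus: float, is_remote_worker: bool):
    # Simplified sketch of the placement logic above.
    if num_gpus == 0 or not torch.cuda.is_available():
        # One CPU "device" per (rounded-up) requested GPU, at least one.
        return [torch.device("cpu")] * (int(math.ceil(num_gpus)) or 1)
    # Local worker/driver: Ray assigns no GPUs here, so ask torch directly.
    gpu_ids = list(range(torch.cuda.device_count()))
    if is_remote_worker:
        # Remote workers use the GPU IDs assigned to them by Ray.
        gpu_ids = ray.get_gpu_ids()
    if len(gpu_ids) < num_gpus:
        raise ValueError(f"Found GPU IDs {gpu_ids}, but num_gpus={num_gpus}.")
    return [
        torch.device(f"cuda:{i}") for i, _ in enumerate(gpu_ids) if i < num_gpus
    ]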

rllib/tests/test_gpus.py (new file, 108 lines)
View file

@ -0,0 +1,108 @@
import unittest
import ray
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune
torch, _ = try_import_torch()
class TestGPUs(unittest.TestCase):
def test_gpus_in_non_local_mode(self):
# Non-local mode.
ray.init(num_cpus=8)
actual_gpus = torch.cuda.device_count()
print(f"Actual GPUs found (by torch): {actual_gpus}")
config = DEFAULT_CONFIG.copy()
config["num_workers"] = 2
config["env"] = "CartPole-v0"
# Expect errors when we run a config w/ num_gpus>0 w/o a GPU
# and _fake_gpus=False.
for num_gpus in [0, 0.1, 1, actual_gpus + 4]:
# Only allow feasible num_gpus_per_worker values (so the test would
# not block infinitely due to an unschedulable worker).
per_worker = [0] if actual_gpus == 0 or actual_gpus < num_gpus \
else [0, 0.5, 1]
for num_gpus_per_worker in per_worker:
for fake_gpus in [False] + ([] if num_gpus == 0 else [True]):
config["num_gpus"] = num_gpus
config["num_gpus_per_worker"] = num_gpus_per_worker
config["_fake_gpus"] = fake_gpus
print(f"\n------------\nnum_gpus={num_gpus} "
f"num_gpus_per_worker={num_gpus_per_worker} "
f"_fake_gpus={fake_gpus}")
frameworks = ("tf", "torch") if num_gpus > 1 else \
("tf2", "tf", "torch")
for _ in framework_iterator(config, frameworks=frameworks):
# Expect that trainer creation causes a num_gpu error.
if actual_gpus < num_gpus + 2 * num_gpus_per_worker \
and not fake_gpus:
# "Direct" RLlib (create Trainer on the driver).
# Cannot run through ray.tune.run() as it would
# simply wait infinitely for the resources to
# become available.
print("direct RLlib")
self.assertRaisesRegex(
RuntimeError,
"Not enough GPUs found.+for "
f"num_gpus={num_gpus}",
lambda: PGTrainer(config, env="CartPole-v0"),
)
# If actual_gpus >= num_gpus or faked,
# expect no error.
else:
print("direct RLlib")
trainer = PGTrainer(config, env="CartPole-v0")
trainer.stop()
# Cannot run through ray.tune.run() w/ fake GPUs
# as it would simply wait infinitely for the
# resources to become available (even though we
# wouldn't really need them).
if num_gpus == 0:
print("via ray.tune.run()")
tune.run(
"PG",
config=config,
stop={"training_iteration": 0})
ray.shutdown()
def test_gpus_in_local_mode(self):
# Local mode.
ray.init(num_gpus=8, local_mode=True)
actual_gpus_available = torch.cuda.device_count()
config = DEFAULT_CONFIG.copy()
config["num_workers"] = 2
config["env"] = "CartPole-v0"
# Expect no errors in local mode.
for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]:
print(f"num_gpus={num_gpus}")
for fake_gpus in [False, True]:
print(f"_fake_gpus={fake_gpus}")
config["num_gpus"] = num_gpus
config["_fake_gpus"] = fake_gpus
frameworks = ("tf", "torch") if num_gpus > 1 else \
("tf2", "tf", "torch")
for _ in framework_iterator(config, frameworks=frameworks):
print("direct RLlib")
trainer = PGTrainer(config, env="CartPole-v0")
trainer.stop()
print("via ray.tune.run()")
tune.run(
"PG", config=config, stop={"training_iteration": 0})
ray.shutdown()
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

View file

@ -130,7 +130,7 @@ class TestExplorations(unittest.TestCase):
do_test_explorations(
impala.ImpalaTrainer,
"CartPole-v0",
impala.DEFAULT_CONFIG,
dict(impala.DEFAULT_CONFIG.copy(), num_gpus=0),
np.array([0.0, 0.1, 0.0, 0.0]),
prev_a=np.array(0))

View file

@ -44,14 +44,15 @@ def get_gpu_devices():
"""
if tfv == 1:
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == "GPU"]
devices = device_lib.list_local_devices()
else:
try:
gpus = tf.config.list_physical_devices("GPU")
devices = tf.config.list_physical_devices()
except Exception:
gpus = tf.config.experimental.list_physical_devices("GPU")
return gpus
devices = tf.config.experimental.list_physical_devices()
# Expect "GPU", but also stuff like: "XLA_GPU".
return [d.name for d in devices if "GPU" in d.device_type]
def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
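
As a usage sketch for the updated ``get_gpu_devices()`` helper above (the import path
below is an assumption, since the diff does not show the file name; the helper sits
next to ``get_placeholder`` in RLlib's TF utility module):

# Hypothetical usage; the module path is an assumption, not shown in this diff.
from ray.rllib.utils.tf_ops import get_gpu_devices

gpu_names = get_gpu_devices()
# Now matches any device whose type contains "GPU", e.g. "XLA_GPU" as well.
# Returns [] on a CPU-only machine; the name format depends on the TF version.
print(f"TF sees {len(gpu_names)} GPU device(s): {gpu_names}")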