[RLlib] Issues: 17397, 17425, 16715, 17174. When on driver, Torch|TFPolicy should not use ray.get_gpu_ids() (b/c no GPUs assigned by ray). (#17444)

Parent: af880378da
Commit: 8a844ff840
11 changed files with 265 additions and 57 deletions
@@ -86,9 +86,30 @@ In an example below, we train A2C by specifying 8 workers through the config fla

Specifying Resources
~~~~~~~~~~~~~~~~~~~~

You can control the degree of parallelism used by setting the ``num_workers`` hyperparameter for most algorithms. The number of GPUs the driver should use can be set via the ``num_gpus`` option. Similarly, the resource allocation to workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``. The number of GPUs can be a fractional quantity to allocate only a fraction of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting ``num_gpus: 0.2``.
You can control the degree of parallelism used by setting the ``num_workers``
hyperparameter for most algorithms. The Trainer constructs that many
"remote worker" instances (`see RolloutWorker class <https://github.com/ray-project/ray/blob/master/rllib/evaluation/rollout_worker.py>`__),
which are ``ray.remote`` actors, plus exactly one "local worker", a ``RolloutWorker`` object that is not a
ray actor but lives directly inside the Trainer.
For most algorithms, learning updates are performed on the local worker and sample collection from
one or more environments is performed by the remote workers (in parallel).
For example, setting ``num_workers=0`` will only create the local worker, in which case both
sample collection and training will be done by the local worker.
On the other hand, setting ``num_workers=5`` will create the local worker (responsible for training updates)
and 5 remote workers (responsible for sample collection).

For synchronous algorithms like PPO and A2C, the driver and workers can make use of the same GPU. To do this for ``n`` GPUs:
Since learning is usually done on the local worker, it may help to provide one or more GPUs
to that worker via the ``num_gpus`` setting.
Similarly, the resource allocation to remote workers can be controlled via ``num_cpus_per_worker``, ``num_gpus_per_worker``, and ``custom_resources_per_worker``.

The number of GPUs can be a fractional quantity (e.g. 0.5), allocating only a fraction
of a GPU. For example, with DQN you can pack five trainers onto one GPU by setting
``num_gpus: 0.2``. Also check out `this fractional GPU example <https://github.com/ray-project/ray/blob/master/rllib/examples/fractional_gpus.py>`__,
which demonstrates how environments (running on the remote workers) that
require a GPU can benefit from the ``num_gpus_per_worker`` setting.

For synchronous algorithms like PPO and A2C, the driver and workers can make use of
the same GPU. To do this for ``n`` GPUs:

.. code-block:: python


@@ -99,6 +120,11 @@ For synchronous algorithms like PPO and A2C, the driver and workers can make use

.. Original image: https://docs.google.com/drawings/d/14QINFvx3grVyJyjAnjggOCEVN-Iq6pYVJ3jA2S6j8z0/edit?usp=sharing
.. image:: rllib-config.svg

If you specify ``num_gpus`` and your machine does not have the required number of GPUs
available, a ``RuntimeError`` will be raised by the respective worker. On the other hand,
if you set ``num_gpus=0``, your policies will be built solely on the CPU, even if
GPUs are available on the machine.
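As a quick, illustrative sketch of how these options fit together (an editor's example, not part of this commit; the ``ray.rllib.agents.ppo`` entry point, the fractional values, and ``CartPole-v0`` are assumptions chosen for illustration only):

.. code-block:: python

    import ray
    from ray.rllib.agents import ppo

    ray.init()
    config = ppo.DEFAULT_CONFIG.copy()
    config["num_workers"] = 4            # 4 remote RolloutWorkers collect samples in parallel.
    config["num_gpus"] = 0.5             # The local worker (learner) gets half a GPU.
    config["num_gpus_per_worker"] = 0.1  # Each remote worker gets a tenth of a GPU.
    config["num_cpus_per_worker"] = 1

    trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
    print(trainer.train())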
Scaling Guide
~~~~~~~~~~~~~
@@ -1554,6 +1554,13 @@ py_test(
    srcs = ["tests/test_filters.py"]
)

py_test(
    name = "tests/test_gpus",
    tags = ["tests_dir", "tests_dir_G"],
    size = "large",
    srcs = ["tests/test_gpus.py"]
)

#py_test(
#    name = "tests/test_ignore_worker_failure",
#    tags = ["tests_dir", "tests_dir_I"],
@@ -17,6 +17,7 @@ class TestApexDQN(unittest.TestCase):
    def test_apex_zero_workers(self):
        config = apex.APEX_DEFAULT_CONFIG.copy()
        config["num_workers"] = 0
        config["num_gpus"] = 0
        config["learning_starts"] = 1000
        config["prioritized_replay"] = True
        config["timesteps_per_iteration"] = 100

@@ -31,6 +32,7 @@ class TestApexDQN(unittest.TestCase):
        """Test whether an APEX-DQNTrainer can be built on all frameworks."""
        config = apex.APEX_DEFAULT_CONFIG.copy()
        config["num_workers"] = 3
        config["num_gpus"] = 0
        config["learning_starts"] = 1000
        config["prioritized_replay"] = True
        config["timesteps_per_iteration"] = 100
@@ -23,6 +23,7 @@ class TestIMPALA(unittest.TestCase):
    def test_impala_compilation(self):
        """Test whether an ImpalaTrainer can be built with both frameworks."""
        config = impala.DEFAULT_CONFIG.copy()
        config["num_gpus"] = 0
        config["model"]["lstm_use_prev_action"] = True
        config["model"]["lstm_use_prev_reward"] = True
        num_iterations = 1

@@ -49,6 +50,7 @@ class TestIMPALA(unittest.TestCase):

    def test_impala_lr_schedule(self):
        config = impala.DEFAULT_CONFIG.copy()
        config["num_gpus"] = 0
        # Test whether we correctly ignore the "lr" setting.
        # The first lr should be 0.0005.
        config["lr"] = 0.1
@@ -427,9 +427,9 @@ class TestSAC(unittest.TestCase):
                    check(
                        tf_var,
                        np.transpose(torch_var.detach().cpu()),
                        atol=0.002)
                        atol=0.003)
                else:
                    check(tf_var, torch_var, atol=0.002)
                    check(tf_var, torch_var, atol=0.003)
            # And alpha.
            check(policy.model.log_alpha,
                  tf_weights["default_policy/log_alpha"])

@@ -444,9 +444,9 @@ class TestSAC(unittest.TestCase):
                    check(
                        tf_var,
                        np.transpose(torch_var.detach().cpu()),
                        atol=0.002)
                        atol=0.003)
                else:
                    check(tf_var, torch_var, atol=0.002)
                    check(tf_var, torch_var, atol=0.003)
        trainer.stop()

    def _get_batch_helper(self, obs_size, actions, batch_size):
@@ -397,7 +397,7 @@ class RolloutWorker(ParallelIteratorWorker):

        # Create an env for this worker.
        if not (worker_index == 0 and num_workers > 0
                and policy_config["create_env_on_driver"] is False):
                and not policy_config.get("create_env_on_driver")):
            # Run the `env_creator` function passing the EnvContext.
            self.env = env_creator(env_context)
            if self.env is not None:

@@ -498,6 +498,7 @@ class RolloutWorker(ParallelIteratorWorker):
            self.env,
            spaces=self.spaces,
            policy_config=policy_config)

        # List of IDs of those policies, which should be trained.
        # By default, these are all policies found in the policy_dict.
        self.policies_to_train: List[PolicyID] = policies_to_train or list(

@@ -547,6 +548,36 @@ class RolloutWorker(ParallelIteratorWorker):
        elif tf1 and policy_config.get("framework") == "tfe":
            tf1.set_random_seed(seed)

        # Check available number of GPUs.
        num_gpus = policy_config.get("num_gpus", 0) if \
            self.worker_index == 0 else \
            policy_config.get("num_gpus_per_worker", 0)
        # Error if we don't find enough GPUs.
        if ray.is_initialized() and \
                ray.worker._mode() != ray.worker.LOCAL_MODE and \
                not policy_config.get("_fake_gpus"):

            if policy_config.get("framework") in ["tf2", "tf", "tfe"]:
                if len(get_tf_gpu_devices()) < num_gpus:
                    raise RuntimeError(
                        f"Not enough GPUs found for num_gpus={num_gpus}! "
                        f"Found only these IDs: {get_tf_gpu_devices()}.")
            elif policy_config.get("framework") == "torch":
                if torch.cuda.device_count() < num_gpus:
                    raise RuntimeError(
                        f"Not enough GPUs found ({torch.cuda.device_count()}) "
                        f"for num_gpus={num_gpus}!")
        # Warn, if running in local-mode and actual GPUs (not faked) are
        # requested.
        elif ray.is_initialized() and \
                ray.worker._mode() == ray.worker.LOCAL_MODE and \
                num_gpus > 0 and not policy_config.get("_fake_gpus"):
            logger.warning(
                "You are running ray with `local_mode=True`, but have "
                f"configured {num_gpus} GPUs to be used! In local mode, "
                f"Policies are placed on the CPU and the `num_gpus` setting "
                f"is ignored.")

        self._build_policy_map(
            policy_dict,
            policy_config,

@@ -561,24 +592,6 @@ class RolloutWorker(ParallelIteratorWorker):
            if not pol._model_init_state_automatically_added:
                pol._update_model_view_requirements_from_init_state()

        if (ray.is_initialized()
                and ray.worker._mode() != ray.worker.LOCAL_MODE):
            # Check available number of GPUs
            if not ray.get_gpu_ids():
                logger.debug("Creating policy evaluation worker {}".format(
                    worker_index) +
                    " on CPU (please ignore any CUDA init errors)")
            elif (policy_config["framework"] in ["tf2", "tf", "tfe"] and
                    not get_tf_gpu_devices()) or \
                    (policy_config["framework"] == "torch" and
                     not torch.cuda.is_available()):
                raise RuntimeError(
                    "GPUs were assigned to this worker by Ray, but "
                    "your DL framework ({}) reports GPU acceleration is "
                    "disabled. This could be due to a bad CUDA- or {} "
                    "installation.".format(policy_config["framework"],
                                           policy_config["framework"]))

        self.multiagent: bool = set(
            self.policy_map.keys()) != {DEFAULT_POLICY_ID}
        if self.multiagent and self.env is not None:

@@ -1425,7 +1438,7 @@ def _determine_spaces_for_multi_agent_dict(
            obs_space = spaces[pid][0]
        elif env_obs_space is not None:
            obs_space = env_obs_space
        elif policy_config and "observation_space" in policy_config:
        elif policy_config and policy_config.get("observation_space"):
            obs_space = policy_config["observation_space"]
        else:
            raise ValueError(

@@ -1441,7 +1454,7 @@ def _determine_spaces_for_multi_agent_dict(
            act_space = spaces[pid][1]
        elif env_act_space is not None:
            act_space = env_act_space
        elif policy_config and "action_space" in policy_config:
        elif policy_config and policy_config.get("action_space"):
            act_space = policy_config["action_space"]
        else:
            raise ValueError(
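As a rough usage sketch of the new check above (``PGTrainer`` and ``DEFAULT_CONFIG`` are the ones imported in this commit's ``rllib/tests/test_gpus.py``; the GPU-less machine and the exact error text are assumptions): on a machine without GPUs, requesting real GPUs now fails fast at Trainer construction, while ``_fake_gpus=True`` still builds everything on the CPU.

    import ray
    from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG

    ray.init(num_cpus=4)  # assumed: no GPUs on this machine, local_mode=False

    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 0
    config["num_gpus"] = 2

    try:
        PGTrainer(config, env="CartPole-v0")
    except RuntimeError as e:
        print(e)  # e.g. "Not enough GPUs found for num_gpus=2! ..."

    config["_fake_gpus"] = True  # debug multi-GPU code paths on a CPU-only box
    trainer = PGTrainer(config, env="CartPole-v0")
    trainer.stop()
    ray.shutdown()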
@@ -1,6 +1,7 @@
import errno
import gym
import logging
import math
import numpy as np
import os
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

@@ -148,21 +149,46 @@ class TFPolicy(Policy):

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        num_gpus = config["num_gpus"] if worker_idx == 0 \
            else config["num_gpus_per_worker"]
        if not config["_fake_gpus"] and \
                ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = get_gpu_devices()

        # No GPU configured, fake GPUs, or none available.
        if config["_fake_gpus"] or num_gpus == 0 or not get_gpu_devices():
        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TFPolicy (worker={}) running on {}.".format(
                worker_idx
                if worker_idx > 0 else "local", f"{num_gpus} fake-GPUs"
                if config["_fake_gpus"] else "CPU"))
            self.devices = ["/cpu:0" for _ in range(num_gpus or 1)]
        # One or more actual GPUs (no fake GPUs).
            self.devices = [
                "/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TFPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))
            gpu_ids = ray.get_gpu_ids()

            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus
            ]
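The branching above reduces to a small rule: in local mode place on CPU, on the driver use ``num_gpus``, on remote workers use ``num_gpus_per_worker``, and fall back to CPU when faking GPUs or when none are visible. A condensed standalone sketch (illustrative function and argument names, not RLlib code):

    import math
    from typing import List

    def resolve_tf_devices(config: dict, worker_index: int, local_mode: bool,
                           available_gpus: List[str]) -> List[str]:
        # In local mode (and not faking GPUs), never use a GPU.
        if not config["_fake_gpus"] and local_mode:
            num_gpus = 0
        elif worker_index == 0:            # local worker / driver
            num_gpus = config["num_gpus"]
        else:                              # remote worker
            num_gpus = config["num_gpus_per_worker"]

        if config["_fake_gpus"] or num_gpus == 0 or not available_gpus:
            # One "/cpu:0" entry per (possibly faked) GPU, at least one.
            return ["/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)]
        if len(available_gpus) < num_gpus:
            raise ValueError(f"Not enough GPUs: found {available_gpus}, "
                             f"but num_gpus={num_gpus}.")
        return [f"/gpu:{i}" for i, _ in enumerate(available_gpus) if i < num_gpus]

    # A remote worker with num_gpus_per_worker=1 and one visible GPU:
    assert resolve_tf_devices(
        {"_fake_gpus": False, "num_gpus": 2, "num_gpus_per_worker": 1},
        worker_index=3, local_mode=False,
        available_gpus=["GPU:0"]) == ["/gpu:0"]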
@@ -2,6 +2,7 @@ import copy
import functools
import gym
import logging
import math
import numpy as np
import os
import time

@@ -115,11 +116,6 @@ class TorchPolicy(Policy):
        self.framework = "torch"
        super().__init__(observation_space, action_space, config)

        # Log device and worker index.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        worker = get_global_worker()
        worker_idx = worker.worker_index if worker else 0

        # Create multi-GPU model towers, if necessary.
        # - The central main model will be stored under self.model, residing
        #   on self.device.

@@ -133,33 +129,60 @@ class TorchPolicy(Policy):
        #   parallelization will be done.
        # TODO: (sven) implement data pre-loading and n loader buffers for
        #  torch.
        if config["_fake_gpus"] or config["num_gpus"] == 0 or \
                not torch.cuda.is_available():

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and \
                ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = list(range(torch.cuda.device_count()))

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TorchPolicy (worker={}) running on {}.".format(
                worker_idx if worker_idx > 0 else "local",
                "{} fake-GPUs".format(config["num_gpus"])
                worker_idx
                if worker_idx > 0 else "local", "{} fake-GPUs".format(num_gpus)
                if config["_fake_gpus"] else "CPU"))
            self.device = torch.device("cpu")
            self.devices = [
                self.device for _ in range(config["num_gpus"] or 1)
                self.device for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model_gpu_towers = [
                model if i == 0 else copy.deepcopy(model)
                for i in range(config["num_gpus"] or 1)
                for i in range(int(math.ceil(num_gpus)) or 1)
            ]
            self.model = model
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", config["num_gpus"]))
            gpu_ids = ray.get_gpu_ids()
                worker_idx if worker_idx > 0 else "local", num_gpus))
            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TorchPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                torch.device("cuda:{}".format(i))
                for i, id_ in enumerate(gpu_ids) if i < config["num_gpus"]
                for i, id_ in enumerate(gpu_ids) if i < num_gpus
            ]
            self.device = self.devices[0]
            ids = [
                id_ for i, id_ in enumerate(gpu_ids) if i < config["num_gpus"]
            ]
            ids = [id_ for i, id_ in enumerate(gpu_ids) if i < num_gpus]
            self.model_gpu_towers = []
            for i, _ in enumerate(ids):
                model_copy = copy.deepcopy(model)
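The Torch side follows the same rule but materializes ``torch.device`` objects plus one model copy ("tower") per device. A rough standalone sketch (simplified; the ``.to(device)`` placement of the copies is an assumption, since the hunk is truncated before that detail):

    import copy
    import math

    import torch
    import torch.nn as nn

    def build_torch_towers(model: nn.Module, num_gpus: float, fake_gpus: bool):
        # Sketch of the device/tower construction in this diff (simplified).
        gpu_ids = list(range(torch.cuda.device_count()))
        if fake_gpus or num_gpus == 0 or not gpu_ids:
            device = torch.device("cpu")
            n = int(math.ceil(num_gpus)) or 1
            devices = [device] * n
            towers = [model if i == 0 else copy.deepcopy(model) for i in range(n)]
        else:
            devices = [torch.device(f"cuda:{i}")
                       for i, _ in enumerate(gpu_ids) if i < num_gpus]
            towers = [copy.deepcopy(model).to(d) for d in devices]  # placement assumed
        return devices, towers

    devices, towers = build_torch_towers(nn.Linear(4, 2), num_gpus=0, fake_gpus=False)
    print(devices)  # [device(type='cpu')]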
rllib/tests/test_gpus.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import unittest

import ray
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune

torch, _ = try_import_torch()


class TestGPUs(unittest.TestCase):
    def test_gpus_in_non_local_mode(self):
        # Non-local mode.
        ray.init(num_cpus=8)

        actual_gpus = torch.cuda.device_count()
        print(f"Actual GPUs found (by torch): {actual_gpus}")

        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"

        # Expect errors when we run a config w/ num_gpus>0 w/o a GPU
        # and _fake_gpus=False.
        for num_gpus in [0, 0.1, 1, actual_gpus + 4]:
            # Only allow possible num_gpus_per_worker (so test would not
            # block infinitely due to a down worker).
            per_worker = [0] if actual_gpus == 0 or actual_gpus < num_gpus \
                else [0, 0.5, 1]
            for num_gpus_per_worker in per_worker:
                for fake_gpus in [False] + ([] if num_gpus == 0 else [True]):
                    config["num_gpus"] = num_gpus
                    config["num_gpus_per_worker"] = num_gpus_per_worker
                    config["_fake_gpus"] = fake_gpus

                    print(f"\n------------\nnum_gpus={num_gpus} "
                          f"num_gpus_per_worker={num_gpus_per_worker} "
                          f"_fake_gpus={fake_gpus}")

                    frameworks = ("tf", "torch") if num_gpus > 1 else \
                        ("tf2", "tf", "torch")
                    for _ in framework_iterator(config, frameworks=frameworks):
                        # Expect that trainer creation causes a num_gpu error.
                        if actual_gpus < num_gpus + 2 * num_gpus_per_worker \
                                and not fake_gpus:
                            # "Direct" RLlib (create Trainer on the driver).
                            # Cannot run through ray.tune.run() as it would
                            # simply wait infinitely for the resources to
                            # become available.
                            print("direct RLlib")
                            self.assertRaisesRegex(
                                RuntimeError,
                                "Not enough GPUs found.+for "
                                f"num_gpus={num_gpus}",
                                lambda: PGTrainer(config, env="CartPole-v0"),
                            )
                        # If actual_gpus >= num_gpus or faked,
                        # expect no error.
                        else:
                            print("direct RLlib")
                            trainer = PGTrainer(config, env="CartPole-v0")
                            trainer.stop()
                            # Cannot run through ray.tune.run() w/ fake GPUs
                            # as it would simply wait infinitely for the
                            # resources to become available (even though, we
                            # wouldn't really need them).
                            if num_gpus == 0:
                                print("via ray.tune.run()")
                                tune.run(
                                    "PG",
                                    config=config,
                                    stop={"training_iteration": 0})
        ray.shutdown()

    def test_gpus_in_local_mode(self):
        # Local mode.
        ray.init(num_gpus=8, local_mode=True)

        actual_gpus_available = torch.cuda.device_count()

        config = DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"

        # Expect no errors in local mode.
        for num_gpus in [0, 0.1, 1, actual_gpus_available + 4]:
            print(f"num_gpus={num_gpus}")
            for fake_gpus in [False, True]:
                print(f"_fake_gpus={fake_gpus}")
                config["num_gpus"] = num_gpus
                config["_fake_gpus"] = fake_gpus
                frameworks = ("tf", "torch") if num_gpus > 1 else \
                    ("tf2", "tf", "torch")
                for _ in framework_iterator(config, frameworks=frameworks):
                    print("direct RLlib")
                    trainer = PGTrainer(config, env="CartPole-v0")
                    trainer.stop()
                    print("via ray.tune.run()")
                    tune.run(
                        "PG", config=config, stop={"training_iteration": 0})
        ray.shutdown()


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
@@ -130,7 +130,7 @@ class TestExplorations(unittest.TestCase):
        do_test_explorations(
            impala.ImpalaTrainer,
            "CartPole-v0",
            impala.DEFAULT_CONFIG,
            dict(impala.DEFAULT_CONFIG.copy(), num_gpus=0),
            np.array([0.0, 0.1, 0.0, 0.0]),
            prev_a=np.array(0))
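The ``dict(base.copy(), key=value)`` pattern used here builds a new config with a single key overridden, leaving the shared default untouched; for example:

    DEFAULT = {"num_gpus": 2, "num_workers": 4}   # stand-in for impala.DEFAULT_CONFIG
    config = dict(DEFAULT.copy(), num_gpus=0)     # override one key on a fresh copy
    assert config == {"num_gpus": 0, "num_workers": 4}
    assert DEFAULT["num_gpus"] == 2               # original default is unchanged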
@@ -44,14 +44,15 @@ def get_gpu_devices():
    """
    if tfv == 1:
        from tensorflow.python.client import device_lib
        local_device_protos = device_lib.list_local_devices()
        return [x.name for x in local_device_protos if x.device_type == "GPU"]
        devices = device_lib.list_local_devices()
    else:
        try:
            gpus = tf.config.list_physical_devices("GPU")
            devices = tf.config.list_physical_devices()
        except Exception:
            gpus = tf.config.experimental.list_physical_devices("GPU")
        return gpus
            devices = tf.config.experimental.list_physical_devices()

    # Expect "GPU", but also stuff like: "XLA_GPU".
    return [d.name for d in devices if "GPU" in d.device_type]


def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
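An illustration of the filtering rule above, runnable without TensorFlow installed (the namedtuple is a stand-in for TF's ``PhysicalDevice``): any device whose type merely contains "GPU" (e.g. ``XLA_GPU``) is now included.

    from collections import namedtuple

    Device = namedtuple("Device", ["name", "device_type"])  # stand-in for tf PhysicalDevice
    devices = [Device("/physical_device:CPU:0", "CPU"),
               Device("/physical_device:GPU:0", "GPU"),
               Device("/physical_device:XLA_GPU:0", "XLA_GPU")]

    gpu_names = [d.name for d in devices if "GPU" in d.device_type]
    print(gpu_names)  # ['/physical_device:GPU:0', '/physical_device:XLA_GPU:0']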