Why are these changes needed?

Also: add validation to make sure multi-GPU and micro-batching are not used together; update the A2C learning test to hit the microbatching branch; minor comment updates.
parent 467430dd55
commit e6e10ce4cf

6 changed files with 53 additions and 18 deletions
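For illustration, this is the kind of configuration the new check rejects (hypothetical usage sketch; it reuses the old-style dict config and the A2C imports that the tests in this diff also use):

from ray.rllib.algorithms.a2c.a2c import A2C, A2C_DEFAULT_CONFIG

# Micro-batching and multi-GPU training are now mutually exclusive for A2C.
config = A2C_DEFAULT_CONFIG.copy()
config["microbatch_size"] = 64     # enables gradient accumulation over microbatches
config["train_batch_size"] = 256
config["num_gpus"] = 2             # more than one GPU ...

# ... so building the trainer now fails during config validation with
# AttributeError: "A2C does not support micro-batching and multiple GPUs at the same time."
# trainer = A2C(config, env="CartPole-v0")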
@@ -17,3 +17,5 @@ a2c-breakoutnoframeskip-v4:
      [0, 0.0007],
      [20000000, 0.000000000001],
    ]
+   train_batch_size: 256
+   microbatch_size: 64
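With these settings, each gradient computation sees a 64-sample microbatch, and an optimizer step is applied only once a full 256-sample train batch has been accumulated. A quick sanity check of the resulting microbatch count (plain Python, purely illustrative):

train_batch_size = 256
microbatch_size = 64

# Gradients from this many microbatches are accumulated (and averaged) per update.
num_microbatches = train_batch_size // microbatch_size
assert num_microbatches == 4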
rllib/BUILD (+10)
@@ -94,6 +94,16 @@ py_test(
    args = ["--yaml-dir=tuned_examples/a2c"]
)

+py_test(
+    name = "learning_tests_cartpole_a2c_fake_gpus",
+    main = "tests/run_regression_tests.py",
+    tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
+    size = "large",
+    srcs = ["tests/run_regression_tests.py"],
+    data = ["tuned_examples/a2c/cartpole-a2c-fake-gpus.yaml"],
+    args = ["--yaml-dir=tuned_examples/a2c"]
+)
+
# A3C

# py_test(
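The new target feeds the fake-GPU tuned example to RLlib's regression-test runner via --yaml-dir=tuned_examples/a2c. Roughly, such a run boils down to loading the YAML and handing the experiment to Tune; the sketch below is only an approximation of that flow (the env-handling detail is an assumption, not the actual contents of tests/run_regression_tests.py):

import yaml
from ray import tune

# Load the tuned example: it maps an experiment name to env / run / stop / config.
with open("rllib/tuned_examples/a2c/cartpole-a2c-fake-gpus.yaml") as f:
    experiments = yaml.safe_load(f)

for name, exp in experiments.items():
    # Fold the top-level `env` key into the algorithm config and launch via Tune.
    exp["config"]["env"] = exp.pop("env")
    tune.run(exp["run"], name=name, config=exp["config"], stop=exp["stop"])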
@@ -128,24 +128,35 @@ class A2C(A3C):
                "Otherwise, microbatches of desired size won't be achievable."
            )

            if config["num_gpus"] > 1:
                raise AttributeError(
                    "A2C does not support micro-batching and multiple GPUs "
                    "at the same time."
                )

    @override(Algorithm)
    def setup(self, config: PartialAlgorithmConfigDict):
        super().setup(config)

        # If this variable isn't set (microbatch_size == None), then by default
        # we should make the number of microbatches that gradients are
        # computed on 1.
        if not self.config.get("microbatch_size", None):
            self.config["microbatch_size"] = self.config["train_batch_size"]

        # Create a microbatch variable for collecting gradients on microbatches.
        # These gradients will be accumulated on-the-fly and applied at once (once train
        # batch size has been collected) to the model.
        self._microbatches_grads = None
        self._microbatches_counts = self._num_microbatches = 0
        if (
            self.config["_disable_execution_plan_api"] is True
            and self.config["microbatch_size"]
        ):
            self._microbatches_grads = None
            self._microbatches_counts = self._num_microbatches = 0

    @override(A3C)
    def training_step(self) -> ResultDict:
        # Fall back to Algorithm.training_step() and A3C policies (loss_fn etc.).
        # W/o microbatching: identical to Algorithm's default implementation.
        # The only difference to a default Algorithm is the value function loss term
        # and its value computations alongside each action.
        if self.config["microbatch_size"] is None:
            return Algorithm.training_step(self)

        # In microbatch mode, we want to compute gradients on experience
        # microbatches, average a number of these microbatches, and then
        # apply the averaged gradient in one SGD step. This conserves GPU
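The truncated comment above describes the microbatch path: gradients are computed on small microbatches and accumulated, and a single (averaged) optimizer step is applied once a full train batch's worth of samples has been processed, which conserves GPU memory. A minimal, framework-agnostic sketch of that accumulate-then-apply pattern (illustrative only; sample_microbatch, compute_gradients, and apply_gradients are hypothetical stand-ins, not RLlib's actual internals):

def microbatched_update(
    train_batch_size, microbatch_size, sample_microbatch, compute_gradients, apply_gradients
):
    accumulated = None       # running sum of per-microbatch gradients
    samples_collected = 0
    num_microbatches = 0

    while samples_collected < train_batch_size:
        batch = sample_microbatch(microbatch_size)   # small enough to fit in GPU memory
        grads = compute_gradients(batch)             # sequence of gradient arrays
        if accumulated is None:
            accumulated = list(grads)
        else:
            accumulated = [acc + g for acc, g in zip(accumulated, grads)]
        samples_collected += microbatch_size
        num_microbatches += 1

    # Average the accumulated gradients and apply them in one SGD step.
    apply_gradients([g / num_microbatches for g in accumulated])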
@@ -172,8 +172,6 @@ class A3C(Algorithm):
            raise ValueError("`entropy_coeff` must be >= 0.0!")
        if config["num_workers"] <= 0 and config["sample_async"]:
            raise ValueError("`num_workers` for A3C must be >= 1!")
-        if "_fake_gpus" in config:
-            assert not config["_fake_gpus"], "A3C/A2C do not support fake_gpus"

    @override(Algorithm)
    def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]:
@@ -1,7 +1,7 @@
import unittest

import ray
-from ray.rllib.algorithms.pg.pg import PG, DEFAULT_CONFIG
+from ray.rllib.algorithms.a2c.a2c import A2C, A2C_DEFAULT_CONFIG
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.test_utils import framework_iterator
from ray import tune
@@ -17,7 +17,7 @@ class TestGPUs(unittest.TestCase):
        actual_gpus = torch.cuda.device_count()
        print(f"Actual GPUs found (by torch): {actual_gpus}")

-        config = DEFAULT_CONFIG.copy()
+        config = A2C_DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"
@@ -58,13 +58,13 @@ class TestGPUs(unittest.TestCase):
                        self.assertRaisesRegex(
                            RuntimeError,
                            "Found 0 GPUs on your machine",
-                            lambda: PG(config, env="CartPole-v0"),
+                            lambda: A2C(config, env="CartPole-v0"),
                        )
                    # If actual_gpus >= num_gpus or faked,
                    # expect no error.
                    else:
                        print("direct RLlib")
-                        trainer = PG(config, env="CartPole-v0")
+                        trainer = A2C(config, env="CartPole-v0")
                        trainer.stop()
                        # Cannot run through ray.tune.run() w/ fake GPUs
                        # as it would simply wait infinitely for the
@@ -73,7 +73,7 @@ class TestGPUs(unittest.TestCase):
                    if num_gpus == 0:
                        print("via ray.tune.run()")
                        tune.run(
-                            "PG", config=config, stop={"training_iteration": 0}
+                            "A2C", config=config, stop={"training_iteration": 0}
                        )
        ray.shutdown()
@@ -83,7 +83,7 @@ class TestGPUs(unittest.TestCase):

        actual_gpus_available = torch.cuda.device_count()

-        config = DEFAULT_CONFIG.copy()
+        config = A2C_DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["env"] = "CartPole-v0"
@@ -97,10 +97,10 @@ class TestGPUs(unittest.TestCase):
        frameworks = ("tf", "torch") if num_gpus > 1 else ("tf2", "tf", "torch")
        for _ in framework_iterator(config, frameworks=frameworks):
            print("direct RLlib")
-            trainer = PG(config, env="CartPole-v0")
+            trainer = A2C(config, env="CartPole-v0")
            trainer.stop()
            print("via ray.tune.run()")
-            tune.run("PG", config=config, stop={"training_iteration": 0})
+            tune.run("A2C", config=config, stop={"training_iteration": 0})

        ray.shutdown()
rllib/tuned_examples/a2c/cartpole-a2c-fake-gpus.yaml (new file, +14)
@@ -0,0 +1,14 @@
+cartpole-a2c-fake-gpus:
+    env: CartPole-v0
+    run: A2C
+    stop:
+        episode_reward_mean: 150
+        training_iteration: 200
+    config:
+        # Works for both torch and tf.
+        framework: tf
+        num_workers: 0
+        lr: 0.001
+        # Fake 2 GPUs.
+        num_gpus: 2
+        _fake_gpus: true
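Fake GPUs let this learning test exercise GPU-related code paths (here, 2 fake GPUs) on CPU-only CI machines. A rough programmatic equivalent of the settings above, following the direct-construction style used in the GPU tests touched by this diff (illustrative sketch, not the regression-test code):

import ray
from ray.rllib.algorithms.a2c.a2c import A2C, A2C_DEFAULT_CONFIG

ray.init()

config = A2C_DEFAULT_CONFIG.copy()
config.update(
    {
        "framework": "tf",
        "num_workers": 0,
        "lr": 0.001,
        # Pretend 2 GPUs exist so the GPU-related code paths are exercised
        # even on a machine without real devices.
        "num_gpus": 2,
        "_fake_gpus": True,
    }
)

trainer = A2C(config, env="CartPole-v0")
for _ in range(3):
    print(trainer.train().get("episode_reward_mean"))
trainer.stop()
ray.shutdown()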