mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] More Trainer -> Algorithm renaming cleanups. (#25869)
This commit is contained in:
parent
e13cc4088a
commit
96693055bd
39 changed files with 166 additions and 166 deletions
|
@ -123,24 +123,24 @@
|
||||||
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
||||||
rllib/...
|
rllib/...
|
||||||
|
|
||||||
- label: ":brain: RLlib: Trainer Tests (generic)"
|
- label: ":brain: RLlib: Algorithm Tests (generic)"
|
||||||
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
|
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
|
||||||
commands:
|
commands:
|
||||||
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
|
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
|
||||||
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
|
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
|
||||||
# Test all tests in the `agents` (soon to be "trainers") dir:
|
# Test all tests in the `algorithms` dir:
|
||||||
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
||||||
--build_tests_only
|
--build_tests_only
|
||||||
--test_tag_filters=algorithms_dir_generic,-multi_gpu
|
--test_tag_filters=algorithms_dir_generic,-multi_gpu
|
||||||
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
|
||||||
rllib/...
|
rllib/...
|
||||||
|
|
||||||
- label: ":brain: RLlib: Trainer Tests (specific algos)"
|
- label: ":brain: RLlib: Algorithm Tests (specific algos)"
|
||||||
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
|
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
|
||||||
commands:
|
commands:
|
||||||
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
|
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
|
||||||
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
|
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
|
||||||
# Test all tests in the `agents` (soon to be "trainers") dir:
|
# Test all tests in the `algorithms` dir:
|
||||||
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
- bazel test --config=ci $(./ci/run/bazel_export_options)
|
||||||
--build_tests_only
|
--build_tests_only
|
||||||
--test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu
|
--test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu
|
||||||
|
|
|
@ -740,7 +740,7 @@ Here is an example of the basic usage (for a more complete example, see `custom_
|
||||||
# NOTE: In order for this to work, your (custom) model needs to implement
|
# NOTE: In order for this to work, your (custom) model needs to implement
|
||||||
# the `import_from_h5` method.
|
# the `import_from_h5` method.
|
||||||
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
|
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
|
||||||
# for detailed examples for tf- and torch trainers/models.
|
# for detailed examples for tf- and torch policies/models.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
|
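The note above concerns `Algorithm.import_model()` used together with a custom model that implements `import_from_h5`. A minimal sketch of that flow, assuming such a custom model has been registered under the hypothetical name "my_keras_model" and that the weights file exists:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v0")
    # Assumption: "my_keras_model" was registered via ModelCatalog.register_custom_model()
    # and its class implements `import_from_h5()`.
    .training(model={"custom_model": "my_keras_model"})
)
algo = config.build()
# Pull pre-trained Keras weights into the policy's model (path is a placeholder).
algo.import_model("/tmp/weights.h5")
print(algo.train())
algo.stop()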
@ -1270,7 +1270,7 @@ Below are some examples of how the custom evaluation metrics are reported nested
|
||||||
Sample output for `python custom_eval.py --custom-eval`
|
Sample output for `python custom_eval.py --custom-eval`
|
||||||
------------------------------------------------------------------------
|
------------------------------------------------------------------------
|
||||||
|
|
||||||
INFO trainer.py:631 -- Running custom eval function <function ...>
|
INFO algorithm.py:631 -- Running custom eval function <function ...>
|
||||||
Update corridor length to 4
|
Update corridor length to 4
|
||||||
Update corridor length to 7
|
Update corridor length to 7
|
||||||
Custom evaluation round 1
|
Custom evaluation round 1
|
||||||
|
|
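The log excerpt above comes from the `custom_eval.py` example. A condensed sketch of wiring a custom evaluation function is shown below; the `(algorithm, eval_workers)` signature follows that example, while the `custom_evaluation_function` and `evaluation_num_workers` argument names are assumed from the AlgorithmConfig.evaluation() API and the metric name is illustrative.

import ray
from ray.rllib.algorithms.ppo import PPOConfig

def custom_eval_function(algorithm, eval_workers):
    # Sample one round of episodes on the remote evaluation workers.
    batches = ray.get([w.sample.remote() for w in eval_workers.remote_workers()])
    # Anything returned here is reported under the "evaluation" key of the results.
    return {"eval_env_steps_this_round": sum(b.count for b in batches)}

config = (
    PPOConfig()
    .environment("CartPole-v0")
    .evaluation(
        evaluation_interval=1,
        evaluation_num_workers=1,
        custom_evaluation_function=custom_eval_function,
    )
)
algo = config.build()
print(algo.train()["evaluation"])
algo.stop()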
39
rllib/BUILD
|
@ -15,7 +15,7 @@
|
||||||
# actions vs continuous actions.
|
# actions vs continuous actions.
|
||||||
# -- "fake_gpus": Tests that run using 2 fake GPUs.
|
# -- "fake_gpus": Tests that run using 2 fake GPUs.
|
||||||
|
|
||||||
# - Quick agent compilation/tune-train tests, tagged "quick_train".
|
# - Quick algo compilation/tune-train tests, tagged "quick_train".
|
||||||
# NOTE: These should be obsoleted in favor of "algorithms_dir" tests as
|
# NOTE: These should be obsoleted in favor of "algorithms_dir" tests as
|
||||||
# they cover the same functionality.
|
# they cover the same functionality.
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
# - `policy` directory tests.
|
# - `policy` directory tests.
|
||||||
# - `utils` directory tests.
|
# - `utils` directory tests.
|
||||||
|
|
||||||
# - Trainer ("agents") tests, tagged "algorithms_dir".
|
# - Algorithm tests, tagged "algorithms_dir".
|
||||||
|
|
||||||
# - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
|
# - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
|
||||||
# "tests_dir_[A-Z]"
|
# "tests_dir_[A-Z]"
|
||||||
|
@ -65,7 +65,7 @@
|
||||||
load("//bazel:python.bzl", "py_test_module_list")
|
load("//bazel:python.bzl", "py_test_module_list")
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# Agents learning regression tests.
|
# Algorithms learning regression tests.
|
||||||
#
|
#
|
||||||
# Tag: learning_tests
|
# Tag: learning_tests
|
||||||
#
|
#
|
||||||
|
@ -685,40 +685,41 @@ py_test(
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# Agents (Compilation, Losses, simple agent functionality tests)
|
# Algorithms (Compilation, Losses, simple functionality tests)
|
||||||
# rllib/algorithms/
|
# rllib/algorithms/
|
||||||
#
|
#
|
||||||
# Tag: algorithms_dir
|
# Tag: algorithms_dir
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
# Generic (all Trainers)
|
# Generic (all Algorithms)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "test_algorithm",
|
||||||
|
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
|
||||||
|
size = "large",
|
||||||
|
srcs = ["algorithms/tests/test_algorithm.py"]
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_callbacks",
|
name = "test_callbacks",
|
||||||
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
|
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["agents/tests/test_callbacks.py"]
|
srcs = ["algorithms/tests/test_callbacks.py"]
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_memory_leaks_generic",
|
name = "test_memory_leaks_generic",
|
||||||
main = "agents/tests/test_memory_leaks.py",
|
main = "algorithms/tests/test_memory_leaks.py",
|
||||||
tags = ["team:rllib", "algorithms_dir"],
|
tags = ["team:rllib", "algorithms_dir"],
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["agents/tests/test_memory_leaks.py"]
|
srcs = ["algorithms/tests/test_memory_leaks.py"]
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "test_trainer",
|
|
||||||
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
|
|
||||||
size = "large",
|
|
||||||
srcs = ["agents/tests/test_trainer.py"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "tests/test_worker_failures",
|
name = "tests/test_worker_failures",
|
||||||
tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
|
tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
|
||||||
size = "large",
|
size = "large",
|
||||||
srcs = ["agents/tests/test_worker_failures.py"]
|
srcs = ["algorithms/tests/test_worker_failures.py"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Specific Algorithms
|
# Specific Algorithms
|
||||||
|
@ -809,7 +810,7 @@ py_test(
|
||||||
py_test(
|
py_test(
|
||||||
name = "test_cql",
|
name = "test_cql",
|
||||||
tags = ["team:rllib", "algorithms_dir"],
|
tags = ["team:rllib", "algorithms_dir"],
|
||||||
size = "medium",
|
size = "large",
|
||||||
srcs = ["algorithms/cql/tests/test_cql.py"]
|
srcs = ["algorithms/cql/tests/test_cql.py"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -982,7 +983,7 @@ py_test(
|
||||||
)
|
)
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# contrib Agents
|
# contrib Algorithms
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
|
@ -1071,7 +1072,7 @@ py_test(
|
||||||
)
|
)
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# Agents (quick training test iterations via `rllib train`)
|
# Algorithms (quick training test iterations via `rllib train`)
|
||||||
#
|
#
|
||||||
# Tag: quick_train
|
# Tag: quick_train
|
||||||
#
|
#
|
||||||
|
|
|
@ -30,11 +30,12 @@ class TestAlphaZero(unittest.TestCase):
|
||||||
|
|
||||||
# Only working for torch right now.
|
# Only working for torch right now.
|
||||||
for _ in framework_iterator(config, frameworks="torch"):
|
for _ in framework_iterator(config, frameworks="torch"):
|
||||||
trainer = config.build()
|
algo = config.build()
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
print(results)
|
print(results)
|
||||||
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
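All of the test hunks in this part of the diff apply the same mechanical rename: `trainer = config.build()` becomes `algo = config.build()`, followed by `algo.train()` and `algo.stop()`. As a standalone reference, the renamed pattern looks like this (PPO on CartPole chosen purely for illustration):

from ray.rllib.algorithms.ppo import PPOConfig

# Build an Algorithm (formerly "Trainer") from a config, train a bit, then stop.
config = PPOConfig().environment("CartPole-v0").rollouts(num_rollout_workers=1)
algo = config.build()
for _ in range(2):
    results = algo.train()
    print(results["episode_reward_mean"])
algo.stop()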
@ -28,23 +28,23 @@ class TestAPPO(unittest.TestCase):
|
||||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||||
print("w/o v-trace")
|
print("w/o v-trace")
|
||||||
config.vtrace = False
|
config.vtrace = False
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
print(results)
|
print(results)
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
print("w/ v-trace")
|
print("w/ v-trace")
|
||||||
config.vtrace = True
|
config.vtrace = True
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
print(results)
|
print(results)
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_appo_compilation_use_kl_loss(self):
|
def test_appo_compilation_use_kl_loss(self):
|
||||||
"""Test whether APPO can be built with kl_loss enabled."""
|
"""Test whether APPO can be built with kl_loss enabled."""
|
||||||
|
@ -54,13 +54,13 @@ class TestAPPO(unittest.TestCase):
|
||||||
num_iterations = 2
|
num_iterations = 2
|
||||||
|
|
||||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
print(results)
|
print(results)
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_appo_two_tf_optimizers(self):
|
def test_appo_two_tf_optimizers(self):
|
||||||
# Not explicitly setting this should cause a warning, but not fail.
|
# Not explicitly setting this should cause a warning, but not fail.
|
||||||
|
@ -78,13 +78,13 @@ class TestAPPO(unittest.TestCase):
|
||||||
|
|
||||||
# Only supported for tf so far.
|
# Only supported for tf so far.
|
||||||
for _ in framework_iterator(config, frameworks=("tf2", "tf")):
|
for _ in framework_iterator(config, frameworks=("tf2", "tf")):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
print(results)
|
print(results)
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_appo_entropy_coeff_schedule(self):
|
def test_appo_entropy_coeff_schedule(self):
|
||||||
# Initial lr, doesn't really matter because of the schedule below.
|
# Initial lr, doesn't really matter because of the schedule below.
|
||||||
|
@ -113,33 +113,33 @@ class TestAPPO(unittest.TestCase):
|
||||||
# which entropy coeff depends on, is updated after each worker rollout.
|
# which entropy coeff depends on, is updated after each worker rollout.
|
||||||
config.min_time_s_per_iteration = 0
|
config.min_time_s_per_iteration = 0
|
||||||
|
|
||||||
def _step_n_times(trainer, n: int):
|
def _step_n_times(algo, n: int):
|
||||||
"""Step trainer n times.
|
"""Step Algorithm n times.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
learning rate at the end of the execution.
|
learning rate at the end of the execution.
|
||||||
"""
|
"""
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
print(trainer.workers.local_worker().global_vars)
|
print(algo.workers.local_worker().global_vars)
|
||||||
print(results)
|
print(results)
|
||||||
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
|
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
|
||||||
"entropy_coeff"
|
"entropy_coeff"
|
||||||
]
|
]
|
||||||
|
|
||||||
for _ in framework_iterator(config):
|
for _ in framework_iterator(config):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
|
|
||||||
coeff = _step_n_times(trainer, 10) # 200 timesteps
|
coeff = _step_n_times(algo, 10) # 200 timesteps
|
||||||
# Should be close to the starting coeff of 0.01.
|
# Should be close to the starting coeff of 0.01.
|
||||||
self.assertLessEqual(coeff, 0.01)
|
self.assertLessEqual(coeff, 0.01)
|
||||||
self.assertGreaterEqual(coeff, 0.001)
|
self.assertGreaterEqual(coeff, 0.001)
|
||||||
|
|
||||||
coeff = _step_n_times(trainer, 20) # 400 timesteps
|
coeff = _step_n_times(algo, 20) # 400 timesteps
|
||||||
# Should have annealed to the final coeff of 0.0001.
|
# Should have annealed to the final coeff of 0.0001.
|
||||||
self.assertLessEqual(coeff, 0.001)
|
self.assertLessEqual(coeff, 0.001)
|
||||||
|
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
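For context on the schedule the test above asserts against, here is a hedged sketch of configuring an entropy-coefficient schedule on APPO. The piecewise [timestep, value] list is RLlib's usual schedule format; the breakpoints are assumptions chosen to mirror the 0.01 -> 0.0001 annealing checked in the test, and the availability of `entropy_coeff_schedule` as a `training()` argument is assumed from the IMPALA/APPO config.

from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY

config = (
    APPOConfig()
    .environment("CartPole-v0")
    .rollouts(num_rollout_workers=1)
    .training(
        entropy_coeff=0.01,
        # Anneal linearly from 0.01 at t=0 down to 0.0001 at t=500 (breakpoints assumed).
        entropy_coeff_schedule=[[0, 0.01], [500, 0.0001]],
    )
)
algo = config.build()
results = algo.train()
# Same lookup path the test above uses to read the current coefficient.
print(results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY]["entropy_coeff"])
algo.stop()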
@ -33,13 +33,13 @@ class TestARS(unittest.TestCase):
|
||||||
num_iterations = 2
|
num_iterations = 2
|
||||||
|
|
||||||
for _ in framework_iterator(config):
|
for _ in framework_iterator(config):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -29,13 +29,13 @@ class TestES(unittest.TestCase):
|
||||||
|
|
||||||
for _ in framework_iterator(config):
|
for _ in framework_iterator(config):
|
||||||
for env in ["CartPole-v0", "Pendulum-v1"]:
|
for env in ["CartPole-v0", "Pendulum-v1"]:
|
||||||
trainer = config.build(env=env)
|
algo = config.build(env=env)
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
check_compute_single_action(trainer)
|
check_compute_single_action(algo)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
ray.shutdown()
|
ray.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,8 +37,8 @@ class MARWILConfig(AlgorithmConfig):
|
||||||
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
||||||
>>> print(config.to_dict())
|
>>> print(config.to_dict())
|
||||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||||
>>> trainer = config.build()
|
>>> algo = config.build()
|
||||||
>>> trainer.train()
|
>>> algo.train()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.algorithms.marwil import MARWILConfig
|
>>> from ray.rllib.algorithms.marwil import MARWILConfig
|
||||||
|
|
|
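The MARWIL docstring above, assembled into one runnable sketch. The JSON path is the one the docstring itself references; point it at a local copy of the offline data, and the environment/learning-rate choices are illustrative.

from ray.rllib.algorithms.marwil import MARWILConfig

config = (
    MARWILConfig()
    .environment("CartPole-v0")
    .training(lr=0.0003)
    # Path taken from the docstring above; adjust to your own offline dataset.
    .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
)
algo = config.build()
print(algo.train())
algo.stop()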
@ -30,9 +30,9 @@ class R2D2Config(DQNConfig):
|
||||||
>>> .resources(num_gpus=1)\
|
>>> .resources(num_gpus=1)\
|
||||||
>>> .rollouts(num_rollout_workers=30)\
|
>>> .rollouts(num_rollout_workers=30)\
|
||||||
>>> .environment("CartPole-v1")
|
>>> .environment("CartPole-v1")
|
||||||
>>> trainer = R2D2(config=config)
|
>>> algo = R2D2(config=config)
|
||||||
>>> while True:
|
>>> while True:
|
||||||
>>> trainer.train()
|
>>> algo.train()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
|
>>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
|
||||||
|
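Likewise for R2D2, a scaled-down version of the docstring example above (0 GPUs and 1 worker instead of 1 GPU / 30 workers). The recurrent-model setting is an assumption, since R2D2 expects an LSTM-style model.

from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config

config = (
    R2D2Config()
    .environment("CartPole-v1")
    .training(lr=0.0003, model={"use_lstm": True, "lstm_cell_size": 64})
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=1)
)
algo = config.build()   # formerly: trainer = R2D2(config=config)
print(algo.train())
algo.stop()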
@ -170,8 +170,6 @@ class R2D2Config(DQNConfig):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# Build an R2D2 trainer, which uses the framework specific Policy
|
|
||||||
# determined in `get_policy_class()` above.
|
|
||||||
class R2D2(DQN):
|
class R2D2(DQN):
|
||||||
"""Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2).
|
"""Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2).
|
||||||
|
|
||||||
|
|
|
@ -78,14 +78,14 @@ class TestR2D2(unittest.TestCase):
|
||||||
|
|
||||||
# Test building an R2D2 agent in all frameworks.
|
# Test building an R2D2 agent in all frameworks.
|
||||||
for _ in framework_iterator(config, with_eager_tracing=True):
|
for _ in framework_iterator(config, with_eager_tracing=True):
|
||||||
trainer = config.build(env="CartPole-v0")
|
algo = config.build(env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
check_train_results(results)
|
check_train_results(results)
|
||||||
check_batch_sizes(results)
|
check_batch_sizes(results)
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
check_compute_single_action(trainer, include_state=True)
|
check_compute_single_action(algo, include_state=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -29,8 +29,8 @@ class SACConfig(AlgorithmConfig):
|
||||||
... .rollouts(num_rollout_workers=4)
|
... .rollouts(num_rollout_workers=4)
|
||||||
>>> print(config.to_dict())
|
>>> print(config.to_dict())
|
||||||
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
>>> # Build an Algorithm object from the config and run 1 training iteration.
|
||||||
>>> trainer = config.build(env="CartPole-v1")
|
>>> algo = config.build(env="CartPole-v1")
|
||||||
>>> trainer.train()
|
>>> algo.train()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, algo_class=None):
|
def __init__(self, algo_class=None):
|
||||||
|
|
|
@ -18,8 +18,8 @@ class TD3Config(DDPGConfig):
|
||||||
>>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
|
>>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
|
||||||
>>> print(config.to_dict())
|
>>> print(config.to_dict())
|
||||||
>>> # Build an Algorithm object from the config and run one training iteration.
|
>>> # Build an Algorithm object from the config and run one training iteration.
|
||||||
>>> trainer = config.build(env="Pendulum-v1")
|
>>> algo = config.build(env="Pendulum-v1")
|
||||||
>>> trainer.train()
|
>>> algo.train()
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
|
>>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
|
||||||
|
|
|
@ -38,10 +38,10 @@ class TestAlgorithm(unittest.TestCase):
|
||||||
algo = pg.PG(env="CartPole-v0", config=standard_config)
|
algo = pg.PG(env="CartPole-v0", config=standard_config)
|
||||||
|
|
||||||
# When (we validate config 2 times).
|
# When (we validate config 2 times).
|
||||||
# Try deprecated `Trainer._validate_config()` method (static).
|
# Try deprecated `Algorithm._validate_config()` method (static).
|
||||||
algo._validate_config(standard_config, algo)
|
algo._validate_config(standard_config, algo)
|
||||||
config_v1 = copy.deepcopy(standard_config)
|
config_v1 = copy.deepcopy(standard_config)
|
||||||
# Try new method: `Trainer.validate_config()` (non-static).
|
# Try new method: `Algorithm.validate_config()` (non-static).
|
||||||
algo.validate_config(standard_config)
|
algo.validate_config(standard_config)
|
||||||
config_v2 = copy.deepcopy(standard_config)
|
config_v2 = copy.deepcopy(standard_config)
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ class TestAlgorithm(unittest.TestCase):
|
||||||
algo_wo_env_on_driver.stop()
|
algo_wo_env_on_driver.stop()
|
||||||
|
|
||||||
# Try again using `create_env_on_driver=True`.
|
# Try again using `create_env_on_driver=True`.
|
||||||
# This force-adds the env on the local-worker, so this Trainer
|
# This force-adds the env on the local-worker, so this Algorithm
|
||||||
# can `evaluate` even though it doesn't have an evaluation-worker
|
# can `evaluate` even though it doesn't have an evaluation-worker
|
||||||
# set.
|
# set.
|
||||||
config.create_env_on_local_worker = True
|
config.create_env_on_local_worker = True
|
|
@ -47,13 +47,13 @@ class TestCallbacks(unittest.TestCase):
|
||||||
config = dict(base_config, callbacks=callbacks)
|
config = dict(base_config, callbacks=callbacks)
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
||||||
trainer = dqn.DQN(config=config)
|
algo = dqn.DQN(config=config)
|
||||||
# Fake the counter on the local worker (doesn't have an env) and
|
# Fake the counter on the local worker (doesn't have an env) and
|
||||||
# set it to -1 so the below `foreach_worker()` won't fail.
|
# set it to -1 so the below `foreach_worker()` won't fail.
|
||||||
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
|
algo.workers.local_worker().sum_sub_env_vector_indices = -1
|
||||||
|
|
||||||
# Get sub-env vector index sums from the 2 remote workers:
|
# Get sub-env vector index sums from the 2 remote workers:
|
||||||
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
|
sum_sub_env_vector_indices = algo.workers.foreach_worker(
|
||||||
lambda w: w.sum_sub_env_vector_indices
|
lambda w: w.sum_sub_env_vector_indices
|
||||||
)
|
)
|
||||||
# Local worker has no environments -> Expect the -1 special
|
# Local worker has no environments -> Expect the -1 special
|
||||||
|
@ -63,7 +63,7 @@ class TestCallbacks(unittest.TestCase):
|
||||||
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
|
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
|
||||||
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
|
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
|
||||||
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
|
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_on_sub_environment_created_with_remote_envs(self):
|
def test_on_sub_environment_created_with_remote_envs(self):
|
||||||
base_config = {
|
base_config = {
|
||||||
|
@ -84,13 +84,13 @@ class TestCallbacks(unittest.TestCase):
|
||||||
config = dict(base_config, callbacks=callbacks)
|
config = dict(base_config, callbacks=callbacks)
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
for _ in framework_iterator(config, frameworks=("tf", "torch")):
|
||||||
trainer = dqn.DQN(config=config)
|
algo = dqn.DQN(config=config)
|
||||||
# Fake the counter on the local worker (doesn't have an env) and
|
# Fake the counter on the local worker (doesn't have an env) and
|
||||||
# set it to -1 so the below `foreach_worker()` won't fail.
|
# set it to -1 so the below `foreach_worker()` won't fail.
|
||||||
trainer.workers.local_worker().sum_sub_env_vector_indices = -1
|
algo.workers.local_worker().sum_sub_env_vector_indices = -1
|
||||||
|
|
||||||
# Get sub-env vector index sums from the 2 remote workers:
|
# Get sub-env vector index sums from the 2 remote workers:
|
||||||
sum_sub_env_vector_indices = trainer.workers.foreach_worker(
|
sum_sub_env_vector_indices = algo.workers.foreach_worker(
|
||||||
lambda w: w.sum_sub_env_vector_indices
|
lambda w: w.sum_sub_env_vector_indices
|
||||||
)
|
)
|
||||||
# Local worker has no environments -> Expect the -1 special
|
# Local worker has no environments -> Expect the -1 special
|
||||||
|
@ -100,7 +100,7 @@ class TestCallbacks(unittest.TestCase):
|
||||||
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
|
# of 6 (sum of vector indices: 0 + 1 + 2 + 3).
|
||||||
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
|
self.assertTrue(sum_sub_env_vector_indices[1] == 6)
|
||||||
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
|
self.assertTrue(sum_sub_env_vector_indices[2] == 6)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
|
@ -30,10 +30,10 @@ class TestMemoryLeaks(unittest.TestCase):
|
||||||
config["env_config"] = {
|
config["env_config"] = {
|
||||||
"static_samples": True,
|
"static_samples": True,
|
||||||
}
|
}
|
||||||
trainer = ppo.PPO(config=config)
|
algo = ppo.PPO(config=config)
|
||||||
results = check_memory_leaks(trainer, to_check={"env"}, repeats=150)
|
results = check_memory_leaks(algo, to_check={"env"}, repeats=150)
|
||||||
assert results["env"]
|
assert results["env"]
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_leaky_policy(self):
|
def test_leaky_policy(self):
|
||||||
"""Tests, whether our diagnostics tools can detect leaks in a policy."""
|
"""Tests, whether our diagnostics tools can detect leaks in a policy."""
|
||||||
|
@ -45,10 +45,10 @@ class TestMemoryLeaks(unittest.TestCase):
|
||||||
config["multiagent"]["policies"] = {
|
config["multiagent"]["policies"] = {
|
||||||
"default_policy": PolicySpec(policy_class=MemoryLeakingPolicy),
|
"default_policy": PolicySpec(policy_class=MemoryLeakingPolicy),
|
||||||
}
|
}
|
||||||
trainer = dqn.DQN(config=config)
|
algo = dqn.DQN(config=config)
|
||||||
results = check_memory_leaks(trainer, to_check={"policy"}, repeats=300)
|
results = check_memory_leaks(algo, to_check={"policy"}, repeats=300)
|
||||||
assert results["policy"]
|
assert results["policy"]
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
|
@ -12,7 +12,7 @@ from ray.rllib.connectors.connector import (
|
||||||
from ray.rllib.utils.annotations import DeveloperAPI
|
from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
from ray.rllib.utils.typing import (
|
from ray.rllib.utils.typing import (
|
||||||
ActionConnectorDataType,
|
ActionConnectorDataType,
|
||||||
TrainerConfigDict,
|
AlgorithmConfigDict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,8 +50,8 @@ register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_action_connectors_from_trainer_config(
|
def get_action_connectors_from_algorithm_config(
|
||||||
config: TrainerConfigDict, action_space: gym.Space
|
config: AlgorithmConfigDict, action_space: gym.Space
|
||||||
) -> ActionConnectorPipeline:
|
) -> ActionConnectorPipeline:
|
||||||
connectors = []
|
connectors = []
|
||||||
return ActionConnectorPipeline(connectors)
|
return ActionConnectorPipeline(connectors)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
from ray.rllib.utils.typing import (
|
from ray.rllib.utils.typing import (
|
||||||
ActionConnectorDataType,
|
ActionConnectorDataType,
|
||||||
AgentConnectorDataType,
|
AgentConnectorDataType,
|
||||||
TrainerConfigDict,
|
AlgorithmConfigDict,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,7 +67,7 @@ register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
|
||||||
# TODO(jungong) : finish this.
|
# TODO(jungong) : finish this.
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_agent_connectors_from_config(
|
def get_agent_connectors_from_config(
|
||||||
config: TrainerConfigDict, obs_space: gym.Space
|
config: AlgorithmConfigDict, obs_space: gym.Space
|
||||||
) -> AgentConnectorPipeline:
|
) -> AgentConnectorPipeline:
|
||||||
connectors = [FlattenDataAgentConnector()]
|
connectors = [FlattenDataAgentConnector()]
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI
|
||||||
from ray.rllib.utils.typing import (
|
from ray.rllib.utils.typing import (
|
||||||
ActionConnectorDataType,
|
ActionConnectorDataType,
|
||||||
AgentConnectorDataType,
|
AgentConnectorDataType,
|
||||||
|
AlgorithmConfigDict,
|
||||||
TensorType,
|
TensorType,
|
||||||
TrainerConfigDict,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -34,7 +34,7 @@ class ConnectorContext:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TrainerConfigDict = None,
|
config: AlgorithmConfigDict = None,
|
||||||
model_initial_states: List[TensorType] = None,
|
model_initial_states: List[TensorType] = None,
|
||||||
observation_space: gym.Space = None,
|
observation_space: gym.Space = None,
|
||||||
action_space: gym.Space = None,
|
action_space: gym.Space = None,
|
||||||
|
|
2
rllib/env/multi_agent_env.py
|
@ -30,7 +30,7 @@ class MultiAgentEnv(gym.Env):
|
||||||
"""An environment that hosts multiple independent agents.
|
"""An environment that hosts multiple independent agents.
|
||||||
|
|
||||||
Agents are identified by (string) agent ids. Note that these "agents" here
|
Agents are identified by (string) agent ids. Note that these "agents" here
|
||||||
are not to be confused with RLlib Trainers, which are also sometimes
|
are not to be confused with RLlib Algorithms, which are also sometimes
|
||||||
referred to as "agents" or "RL agents".
|
referred to as "agents" or "RL agents".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -168,16 +168,16 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||||
config["env_config"] = {"config": {"start_at_t": 1}} # first obs is [1.0]
|
config["env_config"] = {"config": {"start_at_t": 1}} # first obs is [1.0]
|
||||||
|
|
||||||
for _ in framework_iterator(config, frameworks="tf2"):
|
for _ in framework_iterator(config, frameworks="tf2"):
|
||||||
trainer = ppo.PPO(
|
algo = ppo.PPO(
|
||||||
config,
|
config,
|
||||||
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
|
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
|
||||||
)
|
)
|
||||||
rw = trainer.workers.local_worker()
|
rw = algo.workers.local_worker()
|
||||||
sample = rw.sample()
|
sample = rw.sample()
|
||||||
assert sample.count == trainer.config["rollout_fragment_length"]
|
assert sample.count == algo.config["rollout_fragment_length"]
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
assert results["timesteps_total"] == config["train_batch_size"]
|
assert results["timesteps_total"] == config["train_batch_size"]
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_traj_view_next_action(self):
|
def test_traj_view_next_action(self):
|
||||||
action_space = Discrete(2)
|
action_space = Discrete(2)
|
||||||
|
@ -341,10 +341,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||||
config["env_config"] = {"num_agents": num_agents}
|
config["env_config"] = {"num_agents": num_agents}
|
||||||
|
|
||||||
num_iterations = 2
|
num_iterations = 2
|
||||||
trainer = ppo.PPO(config=config)
|
algo = ppo.PPO(config=config)
|
||||||
results = None
|
results = None
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
results = trainer.train()
|
results = algo.train()
|
||||||
self.assertEqual(results["agent_timesteps_total"], results["timesteps_total"])
|
self.assertEqual(results["agent_timesteps_total"], results["timesteps_total"])
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
results["num_env_steps_trained"] * num_agents,
|
results["num_env_steps_trained"] * num_agents,
|
||||||
|
@ -358,7 +358,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||||
results["agent_timesteps_total"],
|
results["agent_timesteps_total"],
|
||||||
(num_iterations + 1) * config["train_batch_size"],
|
(num_iterations + 1) * config["train_batch_size"],
|
||||||
)
|
)
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
||||||
def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
|
def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
|
||||||
"""Test whether a SampleBatch produces the correct 1-step input dict."""
|
"""Test whether a SampleBatch produces the correct 1-step input dict."""
|
||||||
|
|
|
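The trajectory-view tests above repeatedly grab the local rollout worker off the built Algorithm. Extracted as a standalone sketch, that access pattern is:

from ray.rllib.algorithms.ppo import PPOConfig

# With zero remote workers the local worker owns an env and can sample directly.
config = PPOConfig().environment("CartPole-v0").rollouts(num_rollout_workers=0)
algo = config.build()

local_worker = algo.workers.local_worker()
batch = local_worker.sample()
print(batch.count, algo.config["rollout_fragment_length"])
algo.stop()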
@ -81,14 +81,14 @@ if __name__ == "__main__":
|
||||||
"episode_reward_mean": args.stop_reward,
|
"episode_reward_mean": args.stop_reward,
|
||||||
}
|
}
|
||||||
|
|
||||||
# To run the Trainer without tune.run, using our LSTM model and
|
# To run the Algorithm without tune.run, using our LSTM model and
|
||||||
# manual state-in handling, do the following:
|
# manual state-in handling, do the following:
|
||||||
|
|
||||||
# Example (use `config` from the above code):
|
# Example (use `config` from the above code):
|
||||||
# >> import numpy as np
|
# >> import numpy as np
|
||||||
# >> from ray.rllib.algorithms.ppo import PPO
|
# >> from ray.rllib.algorithms.ppo import PPO
|
||||||
# >>
|
# >>
|
||||||
# >> trainer = PPO(config)
|
# >> algo = PPO(config)
|
||||||
# >> lstm_cell_size = config["model"]["lstm_cell_size"]
|
# >> lstm_cell_size = config["model"]["lstm_cell_size"]
|
||||||
# >> env = StatelessCartPole()
|
# >> env = StatelessCartPole()
|
||||||
# >> obs = env.reset()
|
# >> obs = env.reset()
|
||||||
|
@ -101,7 +101,7 @@ if __name__ == "__main__":
|
||||||
# >> prev_r = 0.0
|
# >> prev_r = 0.0
|
||||||
# >>
|
# >>
|
||||||
# >> while True:
|
# >> while True:
|
||||||
# >> a, state_out, _ = trainer.compute_single_action(
|
# >> a, state_out, _ = algo.compute_single_action(
|
||||||
# .. obs, state, prev_a, prev_r)
|
# .. obs, state, prev_a, prev_r)
|
||||||
# >> obs, reward, done, _ = env.step(a)
|
# >> obs, reward, done, _ = env.step(a)
|
||||||
# >> if done:
|
# >> if done:
|
||||||
|
|
|
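The commented pseudo-code above, written out as a hedged, runnable sketch. The `StatelessCartPole` import path and the LSTM sizes are assumptions based on RLlib's example scripts; `compute_single_action` returns `(action, state_out, info)` when a recurrent state is passed in.

import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

lstm_cell_size = 256
config = (
    PPOConfig()
    .environment(StatelessCartPole)
    .rollouts(num_rollout_workers=0)
    .training(model={"use_lstm": True, "lstm_cell_size": lstm_cell_size})
)
algo = config.build()
algo.train()  # optionally train a little first

env = StatelessCartPole()
obs = env.reset()
# Initial LSTM state: two zero-filled tensors (c and h).
state = [np.zeros(lstm_cell_size, np.float32) for _ in range(2)]
prev_a, prev_r = 0, 0.0

done = False
while not done:
    a, state, _ = algo.compute_single_action(
        obs, state, prev_action=prev_a, prev_reward=prev_r
    )
    obs, reward, done, _ = env.step(a)
    prev_a, prev_r = a, reward
algo.stop()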
@ -92,8 +92,8 @@ MyTFPolicy = build_tf_policy(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Create a new Trainer using the Policy defined above.
|
# Create a new Algorithm using the Policy defined above.
|
||||||
class MyTrainer(Algorithm):
|
class MyAlgo(Algorithm):
|
||||||
def get_default_policy_class(self, config):
|
def get_default_policy_class(self, config):
|
||||||
return MyTFPolicy
|
return MyTFPolicy
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ if __name__ == "__main__":
|
||||||
"episode_reward_mean": args.stop_reward,
|
"episode_reward_mean": args.stop_reward,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = tune.run(MyTrainer, stop=stop, config=config, verbose=1)
|
results = tune.run(MyAlgo, stop=stop, config=config, verbose=1)
|
||||||
|
|
||||||
if args.as_test:
|
if args.as_test:
|
||||||
check_learning_achieved(results, args.stop_reward)
|
check_learning_achieved(results, args.stop_reward)
|
||||||
|
|
|
@ -83,11 +83,11 @@ if __name__ == "__main__":
|
||||||
min_reward = -300
|
min_reward = -300
|
||||||
|
|
||||||
# Test for torch framework (tf not implemented yet).
|
# Test for torch framework (tf not implemented yet).
|
||||||
trainer = cql.CQL(config=config)
|
algo = cql.CQL(config=config)
|
||||||
learnt = False
|
learnt = False
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
print(f"Iter {i}")
|
print(f"Iter {i}")
|
||||||
eval_results = trainer.train().get("evaluation")
|
eval_results = algo.train().get("evaluation")
|
||||||
if eval_results:
|
if eval_results:
|
||||||
print("... R={}".format(eval_results["episode_reward_mean"]))
|
print("... R={}".format(eval_results["episode_reward_mean"]))
|
||||||
# Learn until some reward is reached on an actual live env.
|
# Learn until some reward is reached on an actual live env.
|
||||||
|
@ -101,7 +101,7 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get policy, model, and replay-buffer.
|
# Get policy, model, and replay-buffer.
|
||||||
pol = trainer.get_policy()
|
pol = algo.get_policy()
|
||||||
cql_model = pol.model
|
cql_model = pol.model
|
||||||
from ray.rllib.algorithms.cql.cql import replay_buffer
|
from ray.rllib.algorithms.cql.cql import replay_buffer
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ if __name__ == "__main__":
|
||||||
final_q_values = torch.min(q_values, twin_q_values)
|
final_q_values = torch.min(q_values, twin_q_values)
|
||||||
print(final_q_values)
|
print(final_q_values)
|
||||||
|
|
||||||
# Example on how to do evaluation on the trained Trainer
|
# Example on how to do evaluation on the trained Algorithm.
|
||||||
# using the data from our buffer.
|
# using the data from our buffer.
|
||||||
# Get a sample (MultiAgentBatch).
|
# Get a sample (MultiAgentBatch).
|
||||||
multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"])
|
multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"])
|
||||||
|
@ -128,11 +128,10 @@ if __name__ == "__main__":
|
||||||
model_out, _ = cql_model({"obs": obs})
|
model_out, _ = cql_model({"obs": obs})
|
||||||
# The estimated Q-values from the (historic) actions in the batch.
|
# The estimated Q-values from the (historic) actions in the batch.
|
||||||
q_values_old = cql_model.get_q_values(model_out, torch.from_numpy(batch["actions"]))
|
q_values_old = cql_model.get_q_values(model_out, torch.from_numpy(batch["actions"]))
|
||||||
# The estimated Q-values for the new actions computed
|
# The estimated Q-values for the new actions computed by our policy.
|
||||||
# by our trainer policy.
|
|
||||||
actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
|
actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
|
||||||
q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new))
|
q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new))
|
||||||
print(f"Q-val batch={q_values_old}")
|
print(f"Q-val batch={q_values_old}")
|
||||||
print(f"Q-val policy={q_values_new}")
|
print(f"Q-val policy={q_values_new}")
|
||||||
|
|
||||||
trainer.stop()
|
algo.stop()
|
||||||
|
|
|
@ -58,10 +58,10 @@ class RandomParametricPolicy(Policy, ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class RandomParametricTrainer(Algorithm):
|
class RandomParametricAlgorithm(Algorithm):
|
||||||
"""Algo with Policy and config defined above and overriding `training_iteration`.
|
"""Algo with Policy and config defined above and overriding `training_step`.
|
||||||
|
|
||||||
Overrides the `training_iteration` method, which only runs a (dummy)
|
Overrides the `training_step` method, which only runs a (dummy)
|
||||||
rollout and performs no learning.
|
rollout and performs no learning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class RandomParametricTrainer(Algorithm):
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
|
register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
|
||||||
algo = RandomParametricTrainer(env="pa_cartpole")
|
algo = RandomParametricAlgorithm(env="pa_cartpole")
|
||||||
result = algo.train()
|
result = algo.train()
|
||||||
assert result["episode_reward_mean"] > 10, result
|
assert result["episode_reward_mean"] > 10, result
|
||||||
print("Test: OK")
|
print("Test: OK")
|
||||||
|
|
|
@ -75,10 +75,10 @@ def get_cli_args():
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
# The modified Trainer class we will use. This is the exact same
|
# The modified Algorithm class we will use:
|
||||||
# as a PPO, but with the additional default_resource_request
|
# Subclassing from PPO, our algo will only modify `default_resource_request`,
|
||||||
# override, telling tune that it's ok (not mandatory) to place our
|
# telling Ray Tune that it's ok (not mandatory) to place our n remote envs on a
|
||||||
# n remote envs on a different node (each env using 1 CPU).
|
# different node (each env using 1 CPU).
|
||||||
class PPORemoteInference(PPO):
|
class PPORemoteInference(PPO):
|
||||||
@classmethod
|
@classmethod
|
||||||
@override(Algorithm)
|
@override(Algorithm)
|
||||||
|
@ -145,7 +145,7 @@ if __name__ == "__main__":
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Run with Tune for auto env and trainer creation and TensorBoard.
|
# Run with Tune for auto env and algorithm creation and TensorBoard.
|
||||||
else:
|
else:
|
||||||
stop = {
|
stop = {
|
||||||
"training_iteration": args.stop_iters,
|
"training_iteration": args.stop_iters,
|
||||||
|
|
|
@ -64,12 +64,12 @@ parser.add_argument(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Define new Trainer with custom execution_plan/workflow.
|
# Define new Algorithm with custom execution_plan/workflow.
|
||||||
class MyTrainer(Algorithm):
|
class MyAlgo(Algorithm):
|
||||||
@classmethod
|
@classmethod
|
||||||
@override(Algorithm)
|
@override(Algorithm)
|
||||||
def get_default_config(cls) -> AlgorithmConfigDict:
|
def get_default_config(cls) -> AlgorithmConfigDict:
|
||||||
# Run this Trainer with new `training_iteration` API and set some PPO-specific
|
# Run this Algorithm with new `training_step` API and set some PPO-specific
|
||||||
# parameters.
|
# parameters.
|
||||||
return with_common_config(
|
return with_common_config(
|
||||||
{
|
{
|
||||||
|
@ -218,7 +218,7 @@ if __name__ == "__main__":
|
||||||
"episode_reward_mean": args.stop_reward,
|
"episode_reward_mean": args.stop_reward,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = tune.run(MyTrainer, config=config, stop=stop)
|
results = tune.run(MyAlgo, config=config, stop=stop)
|
||||||
|
|
||||||
if args.as_test:
|
if args.as_test:
|
||||||
check_learning_achieved(results, args.stop_reward)
|
check_learning_achieved(results, args.stop_reward)
|
||||||
|
|
|
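The hunk above mentions the new `training_step` API (the renamed `training_iteration`). A minimal, hedged sketch of overriding it: the subclass below just delegates to PPO's implementation and tags the returned result dict, while a real custom workflow would assemble its own sample/learn steps here instead. PPO is used as the base purely for illustration.

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.utils.annotations import override

class MyAlgo(PPO):
    @override(Algorithm)
    def training_step(self):
        # Run PPO's normal sampling + learning step ...
        results = super().training_step()
        # ... then attach a marker; it should surface in the iteration results.
        results["custom_workflow_ran"] = True
        return results

algo = MyAlgo(config=PPOConfig().environment("CartPole-v0").to_dict())
print(algo.train())
algo.stop()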
@ -17,7 +17,7 @@ parser.add_argument(
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Full path to a checkpoint file for restoring a previously saved "
|
help="Full path to a checkpoint file for restoring a previously saved "
|
||||||
"Trainer state.",
|
"Algorithm state.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--num-workers", type=int, default=0)
|
parser.add_argument("--num-workers", type=int, default=0)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
@ -27,7 +27,7 @@ def StandardMetricsReporting(
|
||||||
train_op: Operator for executing training steps.
|
train_op: Operator for executing training steps.
|
||||||
We ignore the output values.
|
We ignore the output values.
|
||||||
workers: Rollout workers to collect metrics from.
|
workers: Rollout workers to collect metrics from.
|
||||||
config: Trainer configuration, used to determine the frequency
|
config: Algorithm configuration, used to determine the frequency
|
||||||
of stats reporting.
|
of stats reporting.
|
||||||
selected_workers: Override the list of remote workers
|
selected_workers: Override the list of remote workers
|
||||||
to collect metrics from.
|
to collect metrics from.
|
||||||
|
|
|
@ -51,7 +51,7 @@ class TestOPE(unittest.TestCase):
|
||||||
.framework("torch")
|
.framework("torch")
|
||||||
.rollouts(batch_mode="complete_episodes")
|
.rollouts(batch_mode="complete_episodes")
|
||||||
)
|
)
|
||||||
cls.trainer = config.build()
|
cls.algo = config.build()
|
||||||
|
|
||||||
# Train DQN for evaluation policy
|
# Train DQN for evaluation policy
|
||||||
tune.run(
|
tune.run(
|
||||||
|
@ -80,7 +80,7 @@ class TestOPE(unittest.TestCase):
|
||||||
done = False
|
done = False
|
||||||
rewards = []
|
rewards = []
|
||||||
while not done:
|
while not done:
|
||||||
act = cls.trainer.compute_single_action(obs)
|
act = cls.algo.compute_single_action(obs)
|
||||||
obs, reward, done, _ = env.step(act)
|
obs, reward, done, _ = env.step(act)
|
||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
ret = 0
|
ret = 0
|
||||||
|
@ -105,7 +105,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "is"
|
name = "is"
|
||||||
estimator = ImportanceSampling(
|
estimator = ImportanceSampling(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
)
|
)
|
||||||
estimator.process(self.batch)
|
estimator.process(self.batch)
|
||||||
|
@ -118,7 +118,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "wis"
|
name = "wis"
|
||||||
estimator = WeightedImportanceSampling(
|
estimator = WeightedImportanceSampling(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
)
|
)
|
||||||
estimator.process(self.batch)
|
estimator.process(self.batch)
|
||||||
|
@ -131,7 +131,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "dm_qreg"
|
name = "dm_qreg"
|
||||||
estimator = DirectMethod(
|
estimator = DirectMethod(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
q_model_type="qreg",
|
q_model_type="qreg",
|
||||||
**self.model_config,
|
**self.model_config,
|
||||||
|
@ -146,7 +146,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "dm_fqe"
|
name = "dm_fqe"
|
||||||
estimator = DirectMethod(
|
estimator = DirectMethod(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
q_model_type="fqe",
|
q_model_type="fqe",
|
||||||
**self.model_config,
|
**self.model_config,
|
||||||
|
@ -161,7 +161,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "dr_qreg"
|
name = "dr_qreg"
|
||||||
estimator = DoublyRobust(
|
estimator = DoublyRobust(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
q_model_type="qreg",
|
q_model_type="qreg",
|
||||||
**self.model_config,
|
**self.model_config,
|
||||||
|
@ -176,7 +176,7 @@ class TestOPE(unittest.TestCase):
|
||||||
name = "dr_fqe"
|
name = "dr_fqe"
|
||||||
estimator = DoublyRobust(
|
estimator = DoublyRobust(
|
||||||
name=name,
|
name=name,
|
||||||
policy=self.trainer.get_policy(),
|
policy=self.algo.get_policy(),
|
||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
q_model_type="fqe",
|
q_model_type="fqe",
|
||||||
**self.model_config,
|
**self.model_config,
|
||||||
|
@ -187,7 +187,7 @@ class TestOPE(unittest.TestCase):
|
||||||
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
|
||||||
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
|
||||||
|
|
||||||
def test_ope_in_trainer(self):
|
def test_ope_in_algo(self):
|
||||||
# TODO (rohan): Add performance tests for off_policy_estimation_methods,
|
# TODO (rohan): Add performance tests for off_policy_estimation_methods,
|
||||||
# with fixed seeds and hyperparameters
|
# with fixed seeds and hyperparameters
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -294,7 +294,7 @@ def _build_eager_tf_policy(
|
||||||
much simpler, but has lower performance.
|
much simpler, but has lower performance.
|
||||||
|
|
||||||
You shouldn't need to call this directly. Rather, prefer to build a TF
|
You shouldn't need to call this directly. Rather, prefer to build a TF
|
||||||
graph policy and set {"framework": "tfe"} in the trainer config to have
|
graph policy and set {"framework": "tfe"} in the Algorithm's config to have
|
||||||
it automatically be converted to an eager policy.
|
it automatically be converted to an eager policy.
|
||||||
|
|
||||||
This has the same signature as build_tf_policy()."""
|
This has the same signature as build_tf_policy()."""
|
||||||
|
|
|
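As the docstring above advises, eager policies are not built directly; you pick the framework on the Algorithm's config instead. A small sketch follows ("tf2" with eager tracing is the commonly documented combination; "tfe" is the older alias):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v0")
    # Selecting "tf2" (optionally with tracing) makes RLlib convert the static-graph
    # TF policy into an eager policy automatically -- no direct eager-policy build needed.
    .framework("tf2", eager_tracing=True)
)
algo = config.build()
print(algo.train())
algo.stop()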
@ -78,7 +78,7 @@ class EntropyCoeffSchedule:
|
||||||
class KLCoeffMixin:
|
class KLCoeffMixin:
|
||||||
"""Assigns the `update_kl()` method to a TorchPolicy.
|
"""Assigns the `update_kl()` method to a TorchPolicy.
|
||||||
|
|
||||||
This is used by Trainers to update the KL coefficient
|
This is used by Algorithms to update the KL coefficient
|
||||||
after each learning step based on `config.kl_target` and
|
after each learning step based on `config.kl_target` and
|
||||||
the measured KL value (from the train_batch).
|
the measured KL value (from the train_batch).
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,7 +7,7 @@ if __name__ == "__main__":
|
||||||
# Do not import torch for testing purposes.
|
# Do not import torch for testing purposes.
|
||||||
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
|
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
|
||||||
|
|
||||||
# Test registering (includes importing) all Trainers.
|
# Test registering (includes importing) all Algorithms.
|
||||||
from ray.rllib import _register_all
|
from ray.rllib import _register_all
|
||||||
|
|
||||||
# This should surface any dependency on torch, e.g. inside function
|
# This should surface any dependency on torch, e.g. inside function
|
||||||
|
@ -19,7 +19,7 @@ if __name__ == "__main__":
|
||||||
assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!"
|
assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!"
|
||||||
|
|
||||||
# Note: No ray.init(), to test it works without Ray
|
# Note: No ray.init(), to test it works without Ray
|
||||||
trainer = A2C(
|
algo = A2C(
|
||||||
env="CartPole-v0",
|
env="CartPole-v0",
|
||||||
config={
|
config={
|
||||||
"framework": "tf",
|
"framework": "tf",
|
||||||
|
@ -31,7 +31,7 @@ if __name__ == "__main__":
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
trainer.train()
|
algo.train()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"torch" not in sys.modules
|
"torch" not in sys.modules
|
||||||
|
|
|
@ -57,10 +57,10 @@ class TestPlacementGroups(unittest.TestCase):
|
||||||
config["env"] = "CartPole-v0"
|
config["env"] = "CartPole-v0"
|
||||||
config["framework"] = "tf"
|
config["framework"] = "tf"
|
||||||
|
|
||||||
# Create a trainer with an overridden default_resource_request
|
# Create an Algorithm with an overridden default_resource_request
|
||||||
# method that returns a PlacementGroupFactory.
|
# method that returns a PlacementGroupFactory.
|
||||||
|
|
||||||
class MyTrainer(PG):
|
class MyAlgo(PG):
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_resource_request(cls, config):
|
def default_resource_request(cls, config):
|
||||||
head_bundle = {"CPU": 1, "GPU": 0}
|
head_bundle = {"CPU": 1, "GPU": 0}
|
||||||
|
@ -70,7 +70,7 @@ class TestPlacementGroups(unittest.TestCase):
|
||||||
strategy=config["placement_strategy"],
|
strategy=config["placement_strategy"],
|
||||||
)
|
)
|
||||||
|
|
||||||
tune.register_trainable("my_trainable", MyTrainer)
|
tune.register_trainable("my_trainable", MyAlgo)
|
||||||
|
|
||||||
global trial_executor
|
global trial_executor
|
||||||
trial_executor = RayTrialExecutor(reuse_actors=False)
|
trial_executor = RayTrialExecutor(reuse_actors=False)
|
||||||
|
|
|
@ -27,11 +27,11 @@ class TestTimeSteps(unittest.TestCase):
|
||||||
obs_batch = np.array([1])
|
obs_batch = np.array([1])
|
||||||
|
|
||||||
for _ in framework_iterator(config):
|
for _ in framework_iterator(config):
|
||||||
trainer = pg.PG(config=config, env=RandomEnv)
|
algo = pg.PG(config=config, env=RandomEnv)
|
||||||
policy = trainer.get_policy()
|
policy = algo.get_policy()
|
||||||
|
|
||||||
for i in range(1, 21):
|
for i in range(1, 21):
|
||||||
trainer.compute_single_action(obs)
|
algo.compute_single_action(obs)
|
||||||
check(policy.global_timestep, i)
|
check(policy.global_timestep, i)
|
||||||
for i in range(1, 21):
|
for i in range(1, 21):
|
||||||
policy.compute_actions(obs_batch)
|
policy.compute_actions(obs_batch)
|
||||||
|
@ -45,7 +45,8 @@ class TestTimeSteps(unittest.TestCase):
|
||||||
for i in range(1, 11):
|
for i in range(1, 11):
|
||||||
policy.compute_actions(obs_batch)
|
policy.compute_actions(obs_batch)
|
||||||
check(policy.global_timestep, i + crazy_timesteps)
|
check(policy.global_timestep, i + crazy_timesteps)
|
||||||
trainer.train()
|
algo.train()
|
||||||
|
algo.stop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -36,18 +36,18 @@ def PublicAPI(obj):
|
||||||
can expect these APIs to remain stable across RLlib releases.
|
can expect these APIs to remain stable across RLlib releases.
|
||||||
|
|
||||||
Subclasses that inherit from a ``@PublicAPI`` base class can be
|
Subclasses that inherit from a ``@PublicAPI`` base class can be
|
||||||
assumed part of the RLlib public API as well (e.g., all trainer classes
|
assumed part of the RLlib public API as well (e.g., all Algorithm classes
|
||||||
are in public API because Trainer is ``@PublicAPI``).
|
are in public API because Algorithm is ``@PublicAPI``).
|
||||||
|
|
||||||
In addition, you can assume all trainer configurations are part of their
|
In addition, you can assume all algo configurations are part of their
|
||||||
public API as well.
|
public API as well.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # Indicates that the `Trainer` class is exposed to end users
|
>>> # Indicates that the `Algorithm` class is exposed to end users
|
||||||
>>> # of RLlib and will remain stable across RLlib releases.
|
>>> # of RLlib and will remain stable across RLlib releases.
|
||||||
>>> from ray import tune
|
>>> from ray import tune
|
||||||
>>> @PublicAPI # doctest: +SKIP
|
>>> @PublicAPI # doctest: +SKIP
|
||||||
>>> class Trainer(tune.Trainable): # doctest: +SKIP
|
>>> class Algorithm(tune.Trainable): # doctest: +SKIP
|
||||||
... ... # doctest: +SKIP
|
... ... # doctest: +SKIP
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ def ExperimentalAPI(obj):
|
||||||
def OverrideToImplementCustomLogic(obj):
|
def OverrideToImplementCustomLogic(obj):
|
||||||
"""Users should override this in their sub-classes to implement custom logic.
|
"""Users should override this in their sub-classes to implement custom logic.
|
||||||
|
|
||||||
Used in Trainer and Policy to tag methods that need overriding, e.g.
|
Used in Algorithm and Policy to tag methods that need overriding, e.g.
|
||||||
`Policy.loss()`.
|
`Policy.loss()`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
@ -132,9 +132,9 @@ def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
|
||||||
Thereby, it is recommended (but not required) to call the super-class'
|
Thereby, it is recommended (but not required) to call the super-class'
|
||||||
corresponding method.
|
corresponding method.
|
||||||
|
|
||||||
Used in Trainer and Policy to tag methods that need overriding, but the
|
Used in Algorithm and Policy to tag methods that need overriding, but the
|
||||||
super class' method should still be called, e.g.
|
super class' method should still be called, e.g.
|
||||||
`Trainer.setup()`.
|
`Algorithm.setup()`.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from ray import tune
|
>>> from ray import tune
|
||||||
|
|
|
@ -36,7 +36,7 @@ Suspect = DeveloperAPI(
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def check_memory_leaks(
|
def check_memory_leaks(
|
||||||
trainer,
|
algorithm,
|
||||||
to_check: Optional[Set[str]] = None,
|
to_check: Optional[Set[str]] = None,
|
||||||
repeats: Optional[int] = None,
|
repeats: Optional[int] = None,
|
||||||
max_num_trials: int = 3,
|
max_num_trials: int = 3,
|
||||||
|
@ -49,7 +49,7 @@ def check_memory_leaks(
|
||||||
un-GC'd items to memory.
|
un-GC'd items to memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The Algorithm instance to test.
|
algorithm: The Algorithm instance to test.
|
||||||
to_check: Set of strings to identify components to test. Allowed strings
|
to_check: Set of strings to identify components to test. Allowed strings
|
||||||
are: "env", "policy", "model", "rollout_worker". By default, check all
|
are: "env", "policy", "model", "rollout_worker". By default, check all
|
||||||
of these.
|
of these.
|
||||||
|
@ -62,7 +62,7 @@ def check_memory_leaks(
|
||||||
A defaultdict(list) with keys being the `to_check` strings and values being
|
A defaultdict(list) with keys being the `to_check` strings and values being
|
||||||
lists of Suspect instances that were found.
|
lists of Suspect instances that were found.
|
||||||
"""
|
"""
|
||||||
local_worker = trainer.workers.local_worker()
|
local_worker = algorithm.workers.local_worker()
|
||||||
|
|
||||||
# Which components should we test?
|
# Which components should we test?
|
||||||
to_check = to_check or {"env", "model", "policy", "rollout_worker"}
|
to_check = to_check or {"env", "model", "policy", "rollout_worker"}
|
||||||
|
|
|
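For reference, a usage sketch of the renamed `check_memory_leaks(algorithm, ...)` helper, mirroring how the memory-leak tests earlier in this diff call it. The import path is an assumption and may differ.

from ray.rllib.algorithms.ppo import PPOConfig
# Import path assumed; the tests in this diff use `check_memory_leaks` from RLlib's
# debug utilities.
from ray.rllib.utils.debug.memory import check_memory_leaks

algo = PPOConfig().environment("CartPole-v0").rollouts(num_rollout_workers=0).build()
# Repeatedly exercises the selected components ("env" here) and returns a dict of
# Suspect lists for every component that appears to leak.
leaks = check_memory_leaks(algo, to_check={"env"}, repeats=50)
assert not leaks.get("env"), leaks
algo.stop()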
@ -12,7 +12,7 @@ NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
|
||||||
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
|
LAST_TARGET_UPDATE_TS = "last_target_update_ts"
|
||||||
NUM_TARGET_UPDATES = "num_target_updates"
|
NUM_TARGET_UPDATES = "num_target_updates"
|
||||||
|
|
||||||
# Performance timers (keys for Trainer._timers or metrics.timers).
|
# Performance timers (keys for Algorithm._timers or metrics.timers).
|
||||||
TRAINING_ITERATION_TIMER = "training_iteration"
|
TRAINING_ITERATION_TIMER = "training_iteration"
|
||||||
APPLY_GRADS_TIMER = "apply_grad"
|
APPLY_GRADS_TIMER = "apply_grad"
|
||||||
COMPUTE_GRADS_TIMER = "compute_grads"
|
COMPUTE_GRADS_TIMER = "compute_grads"
|
||||||
|
|