[RLlib] More Trainer -> Algorithm renaming cleanups. (#25869)

Sven Mika 2022-06-20 15:54:00 +02:00 committed by GitHub
parent e13cc4088a
commit 96693055bd
39 changed files with 166 additions and 166 deletions

View file

@ -123,24 +123,24 @@
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1 --test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
rllib/... rllib/...
- label: ":brain: RLlib: Trainer Tests (generic)" - label: ":brain: RLlib: Algorithm Tests (generic)"
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"] conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
commands: commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh - RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
# Test all tests in the `agents` (soon to be "trainers") dir: # Test all tests in the `algorithms` dir:
- bazel test --config=ci $(./ci/run/bazel_export_options) - bazel test --config=ci $(./ci/run/bazel_export_options)
--build_tests_only --build_tests_only
--test_tag_filters=algorithms_dir_generic,-multi_gpu --test_tag_filters=algorithms_dir_generic,-multi_gpu
--test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1 --test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
rllib/... rllib/...
- label: ":brain: RLlib: Trainer Tests (specific algos)" - label: ":brain: RLlib: Algorithm Tests (specific algos)"
conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"] conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
commands: commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh - RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
# Test all tests in the `agents` (soon to be "trainers") dir: # Test all tests in the `algorithms` dir:
- bazel test --config=ci $(./ci/run/bazel_export_options) - bazel test --config=ci $(./ci/run/bazel_export_options)
--build_tests_only --build_tests_only
--test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu --test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu

View file

@ -740,7 +740,7 @@ Here is an example of the basic usage (for a more complete example, see `custom_
# NOTE: In order for this to work, your (custom) model needs to implement # NOTE: In order for this to work, your (custom) model needs to implement
# the `import_from_h5` method. # the `import_from_h5` method.
# See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py # See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
# for detailed examples for tf- and torch trainers/models. # for detailed examples for tf- and torch policies/models.
.. note:: .. note::
@ -1270,7 +1270,7 @@ Below are some examples of how the custom evaluation metrics are reported nested
Sample output for `python custom_eval.py --custom-eval` Sample output for `python custom_eval.py --custom-eval`
------------------------------------------------------------------------ ------------------------------------------------------------------------
INFO trainer.py:631 -- Running custom eval function <function ...> INFO algorithm.py:631 -- Running custom eval function <function ...>
Update corridor length to 4 Update corridor length to 4
Update corridor length to 7 Update corridor length to 7
Custom evaluation round 1 Custom evaluation round 1
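The `Running custom eval function <function ...>` line above is emitted when a custom evaluation function is plugged in via the `custom_eval_function` config key. A rough sketch of such a function, loosely following the `custom_eval.py` example (the helper imports and the custom metric name are illustrative, not prescribed):

import ray
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes


def custom_eval_function(algorithm, eval_workers):
    """Runs one custom evaluation round and returns a metrics dict."""
    # Run one sampling round on all remote evaluation workers.
    ray.get([w.sample.remote() for w in eval_workers.remote_workers()])
    # Collect the finished episodes and turn them into the standard metrics.
    episodes, _ = collect_episodes(
        remote_workers=eval_workers.remote_workers(), timeout_seconds=99999
    )
    metrics = summarize_episodes(episodes)
    # Any additional keys show up under the result's "evaluation" entry.
    metrics["my_custom_metric"] = 1.0
    return metrics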

View file

@ -15,7 +15,7 @@
# actions vs continuous actions. # actions vs continuous actions.
# -- "fake_gpus": Tests that run using 2 fake GPUs. # -- "fake_gpus": Tests that run using 2 fake GPUs.
# - Quick agent compilation/tune-train tests, tagged "quick_train". # - Quick algo compilation/tune-train tests, tagged "quick_train".
# NOTE: These should be obsoleted in favor of "algorithms_dir" tests as # NOTE: These should be obsoleted in favor of "algorithms_dir" tests as
# they cover the same functionality. # they cover the same functionality.
@ -28,7 +28,7 @@
# - `policy` directory tests. # - `policy` directory tests.
# - `utils` directory tests. # - `utils` directory tests.
# - Trainer ("agents") tests, tagged "algorithms_dir". # - Algorithm tests, tagged "algorithms_dir".
# - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and # - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
# "tests_dir_[A-Z]" # "tests_dir_[A-Z]"
@ -65,7 +65,7 @@
load("//bazel:python.bzl", "py_test_module_list") load("//bazel:python.bzl", "py_test_module_list")
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Agents learning regression tests. # Algorithms learning regression tests.
# #
# Tag: learning_tests # Tag: learning_tests
# #
@ -685,40 +685,41 @@ py_test(
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Agents (Compilation, Losses, simple agent functionality tests) # Algorithms (Compilation, Losses, simple functionality tests)
# rllib/algorithms/ # rllib/algorithms/
# #
# Tag: algorithms_dir # Tag: algorithms_dir
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Generic (all Trainers) # Generic (all Algorithms)
py_test(
name = "test_algorithm",
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
size = "large",
srcs = ["algorithms/tests/test_algorithm.py"]
)
py_test( py_test(
name = "test_callbacks", name = "test_callbacks",
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"], tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
size = "medium", size = "medium",
srcs = ["agents/tests/test_callbacks.py"] srcs = ["algorithms/tests/test_callbacks.py"]
) )
py_test( py_test(
name = "test_memory_leaks_generic", name = "test_memory_leaks_generic",
main = "agents/tests/test_memory_leaks.py", main = "algorithms/tests/test_memory_leaks.py",
tags = ["team:rllib", "algorithms_dir"], tags = ["team:rllib", "algorithms_dir"],
size = "large", size = "large",
srcs = ["agents/tests/test_memory_leaks.py"] srcs = ["algorithms/tests/test_memory_leaks.py"]
)
py_test(
name = "test_trainer",
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
size = "large",
srcs = ["agents/tests/test_trainer.py"]
) )
py_test( py_test(
name = "tests/test_worker_failures", name = "tests/test_worker_failures",
tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"], tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
size = "large", size = "large",
srcs = ["agents/tests/test_worker_failures.py"] srcs = ["algorithms/tests/test_worker_failures.py"]
) )
# Specific Algorithms # Specific Algorithms
@ -809,7 +810,7 @@ py_test(
py_test( py_test(
name = "test_cql", name = "test_cql",
tags = ["team:rllib", "algorithms_dir"], tags = ["team:rllib", "algorithms_dir"],
size = "medium", size = "large",
srcs = ["algorithms/cql/tests/test_cql.py"] srcs = ["algorithms/cql/tests/test_cql.py"]
) )
@ -982,7 +983,7 @@ py_test(
) )
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# contrib Agents # contrib Algorithms
# -------------------------------------------------------------------- # --------------------------------------------------------------------
py_test( py_test(
@ -1071,7 +1072,7 @@ py_test(
) )
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# Agents (quick training test iterations via `rllib train`) # Algorithms (quick training test iterations via `rllib train`)
# #
# Tag: quick_train # Tag: quick_train
# #

View file

@ -30,11 +30,12 @@ class TestAlphaZero(unittest.TestCase):
# Only working for torch right now. # Only working for torch right now.
for _ in framework_iterator(config, frameworks="torch"): for _ in framework_iterator(config, frameworks="torch"):
trainer = config.build() algo = config.build()
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
algo.stop()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -28,23 +28,23 @@ class TestAPPO(unittest.TestCase):
for _ in framework_iterator(config, with_eager_tracing=True): for _ in framework_iterator(config, with_eager_tracing=True):
print("w/o v-trace") print("w/o v-trace")
config.vtrace = False config.vtrace = False
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
print("w/ v-trace") print("w/ v-trace")
config.vtrace = True config.vtrace = True
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
def test_appo_compilation_use_kl_loss(self): def test_appo_compilation_use_kl_loss(self):
"""Test whether APPO can be built with kl_loss enabled.""" """Test whether APPO can be built with kl_loss enabled."""
@ -54,13 +54,13 @@ class TestAPPO(unittest.TestCase):
num_iterations = 2 num_iterations = 2
for _ in framework_iterator(config, with_eager_tracing=True): for _ in framework_iterator(config, with_eager_tracing=True):
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
def test_appo_two_tf_optimizers(self): def test_appo_two_tf_optimizers(self):
# Not explicitly setting this should cause a warning, but not fail. # Not explicitly setting this should cause a warning, but not fail.
@ -78,13 +78,13 @@ class TestAPPO(unittest.TestCase):
# Only supported for tf so far. # Only supported for tf so far.
for _ in framework_iterator(config, frameworks=("tf2", "tf")): for _ in framework_iterator(config, frameworks=("tf2", "tf")):
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
def test_appo_entropy_coeff_schedule(self): def test_appo_entropy_coeff_schedule(self):
# Initial lr, doesn't really matter because of the schedule below. # Initial lr, doesn't really matter because of the schedule below.
@ -113,33 +113,33 @@ class TestAPPO(unittest.TestCase):
# which entropy coeff depends on, is updated after each worker rollout. # which entropy coeff depends on, is updated after each worker rollout.
config.min_time_s_per_iteration = 0 config.min_time_s_per_iteration = 0
def _step_n_times(trainer, n: int): def _step_n_times(algo, n: int):
"""Step trainer n times. """Step Algorithm n times.
Returns: Returns:
entropy coeff at the end of the execution. entropy coeff at the end of the execution.
""" """
for _ in range(n): for _ in range(n):
results = trainer.train() results = algo.train()
print(trainer.workers.local_worker().global_vars) print(algo.workers.local_worker().global_vars)
print(results) print(results)
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"entropy_coeff" "entropy_coeff"
] ]
for _ in framework_iterator(config): for _ in framework_iterator(config):
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
coeff = _step_n_times(trainer, 10) # 200 timesteps coeff = _step_n_times(algo, 10) # 200 timesteps
# Should be close to the starting coeff of 0.01. # Should be close to the starting coeff of 0.01.
self.assertLessEqual(coeff, 0.01) self.assertLessEqual(coeff, 0.01)
self.assertGreaterEqual(coeff, 0.001) self.assertGreaterEqual(coeff, 0.001)
coeff = _step_n_times(trainer, 20) # 400 timesteps coeff = _step_n_times(algo, 20) # 400 timesteps
# Should have annealed to the final coeff of 0.0001. # Should have annealed to the final coeff of 0.0001.
self.assertLessEqual(coeff, 0.001) self.assertLessEqual(coeff, 0.001)
trainer.stop() algo.stop()
if __name__ == "__main__": if __name__ == "__main__":
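The pattern exercised throughout these tests is unchanged by the renaming: build an Algorithm (formerly Trainer) from its AlgorithmConfig, train it, then stop it. A minimal sketch, assuming the config class and test utilities used above are importable the same way:

from ray.rllib.algorithms.appo import APPOConfig
from ray.rllib.utils.test_utils import (
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)

config = APPOConfig().rollouts(num_rollout_workers=1)

for _ in framework_iterator(config, with_eager_tracing=True):
    # `config.build()` returns an Algorithm instance.
    algo = config.build(env="CartPole-v0")
    for _ in range(2):
        results = algo.train()
        check_train_results(results)
    check_compute_single_action(algo)
    algo.stop()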

View file

@ -33,13 +33,13 @@ class TestARS(unittest.TestCase):
num_iterations = 2 num_iterations = 2
for _ in framework_iterator(config): for _ in framework_iterator(config):
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -29,13 +29,13 @@ class TestES(unittest.TestCase):
for _ in framework_iterator(config): for _ in framework_iterator(config):
for env in ["CartPole-v0", "Pendulum-v1"]: for env in ["CartPole-v0", "Pendulum-v1"]:
trainer = config.build(env=env) algo = config.build(env=env)
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
print(results) print(results)
check_compute_single_action(trainer) check_compute_single_action(algo)
trainer.stop() algo.stop()
ray.shutdown() ray.shutdown()

View file

@ -37,8 +37,8 @@ class MARWILConfig(AlgorithmConfig):
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"]) ... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
>>> print(config.to_dict()) >>> print(config.to_dict())
>>> # Build an Algorithm object from the config and run 1 training iteration. >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build() >>> algo = config.build()
>>> trainer.train() >>> algo.train()
Example: Example:
>>> from ray.rllib.algorithms.marwil import MARWILConfig >>> from ray.rllib.algorithms.marwil import MARWILConfig

View file

@ -30,9 +30,9 @@ class R2D2Config(DQNConfig):
>>> .resources(num_gpus=1)\ >>> .resources(num_gpus=1)\
>>> .rollouts(num_rollout_workers=30)\ >>> .rollouts(num_rollout_workers=30)\
>>> .environment("CartPole-v1") >>> .environment("CartPole-v1")
>>> trainer = R2D2(config=config) >>> algo = R2D2(config=config)
>>> while True: >>> while True:
>>> trainer.train() >>> algo.train()
Example: Example:
>>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
@ -170,8 +170,6 @@ class R2D2Config(DQNConfig):
return self return self
# Build an R2D2 trainer, which uses the framework specific Policy
# determined in `get_policy_class()` above.
class R2D2(DQN): class R2D2(DQN):
"""Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2). """Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2).

View file

@ -78,14 +78,14 @@ class TestR2D2(unittest.TestCase):
# Test building an R2D2 agent in all frameworks. # Test building an R2D2 agent in all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True): for _ in framework_iterator(config, with_eager_tracing=True):
trainer = config.build(env="CartPole-v0") algo = config.build(env="CartPole-v0")
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
check_train_results(results) check_train_results(results)
check_batch_sizes(results) check_batch_sizes(results)
print(results) print(results)
check_compute_single_action(trainer, include_state=True) check_compute_single_action(algo, include_state=True)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -29,8 +29,8 @@ class SACConfig(AlgorithmConfig):
... .rollouts(num_rollout_workers=4) ... .rollouts(num_rollout_workers=4)
>>> print(config.to_dict()) >>> print(config.to_dict())
>>> # Build an Algorithm object from the config and run 1 training iteration. >>> # Build an Algorithm object from the config and run 1 training iteration.
>>> trainer = config.build(env="CartPole-v1") >>> algo = config.build(env="CartPole-v1")
>>> trainer.train() >>> algo.train()
""" """
def __init__(self, algo_class=None): def __init__(self, algo_class=None):

View file

@ -18,8 +18,8 @@ class TD3Config(DDPGConfig):
>>> config = TD3Config().training(lr=0.01).resources(num_gpus=1) >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict()) >>> print(config.to_dict())
>>> # Build an Algorithm object from the config and run one training iteration. >>> # Build an Algorithm object from the config and run one training iteration.
>>> trainer = config.build(env="Pendulum-v1") >>> algo = config.build(env="Pendulum-v1")
>>> trainer.train() >>> algo.train()
Example: Example:
>>> from ray.rllib.algorithms.ddpg.td3 import TD3Config >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config

View file

@ -38,10 +38,10 @@ class TestAlgorithm(unittest.TestCase):
algo = pg.PG(env="CartPole-v0", config=standard_config) algo = pg.PG(env="CartPole-v0", config=standard_config)
# When (we validate config 2 times). # When (we validate config 2 times).
# Try deprecated `Trainer._validate_config()` method (static). # Try deprecated `Algorithm._validate_config()` method (static).
algo._validate_config(standard_config, algo) algo._validate_config(standard_config, algo)
config_v1 = copy.deepcopy(standard_config) config_v1 = copy.deepcopy(standard_config)
# Try new method: `Trainer.validate_config()` (non-static). # Try new method: `Algorithm.validate_config()` (non-static).
algo.validate_config(standard_config) algo.validate_config(standard_config)
config_v2 = copy.deepcopy(standard_config) config_v2 = copy.deepcopy(standard_config)
@ -239,7 +239,7 @@ class TestAlgorithm(unittest.TestCase):
algo_wo_env_on_driver.stop() algo_wo_env_on_driver.stop()
# Try again using `create_env_on_driver=True`. # Try again using `create_env_on_driver=True`.
# This force-adds the env on the local-worker, so this Trainer # This force-adds the env on the local-worker, so this Algorithm
# can `evaluate` even though it doesn't have an evaluation-worker # can `evaluate` even though it doesn't have an evaluation-worker
# set. # set.
config.create_env_on_local_worker = True config.create_env_on_local_worker = True

View file

@ -47,13 +47,13 @@ class TestCallbacks(unittest.TestCase):
config = dict(base_config, callbacks=callbacks) config = dict(base_config, callbacks=callbacks)
for _ in framework_iterator(config, frameworks=("tf", "torch")): for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQN(config=config) algo = dqn.DQN(config=config)
# Fake the counter on the local worker (doesn't have an env) and # Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail. # set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1 algo.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers: # Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker( sum_sub_env_vector_indices = algo.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices lambda w: w.sum_sub_env_vector_indices
) )
# Local worker has no environments -> Expect the -1 special # Local worker has no environments -> Expect the -1 special
@ -63,7 +63,7 @@ class TestCallbacks(unittest.TestCase):
# of 6 (sum of vector indices: 0 + 1 + 2 + 3). # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6) self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6) self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop() algo.stop()
def test_on_sub_environment_created_with_remote_envs(self): def test_on_sub_environment_created_with_remote_envs(self):
base_config = { base_config = {
@ -84,13 +84,13 @@ class TestCallbacks(unittest.TestCase):
config = dict(base_config, callbacks=callbacks) config = dict(base_config, callbacks=callbacks)
for _ in framework_iterator(config, frameworks=("tf", "torch")): for _ in framework_iterator(config, frameworks=("tf", "torch")):
trainer = dqn.DQN(config=config) algo = dqn.DQN(config=config)
# Fake the counter on the local worker (doesn't have an env) and # Fake the counter on the local worker (doesn't have an env) and
# set it to -1 so the below `foreach_worker()` won't fail. # set it to -1 so the below `foreach_worker()` won't fail.
trainer.workers.local_worker().sum_sub_env_vector_indices = -1 algo.workers.local_worker().sum_sub_env_vector_indices = -1
# Get sub-env vector index sums from the 2 remote workers: # Get sub-env vector index sums from the 2 remote workers:
sum_sub_env_vector_indices = trainer.workers.foreach_worker( sum_sub_env_vector_indices = algo.workers.foreach_worker(
lambda w: w.sum_sub_env_vector_indices lambda w: w.sum_sub_env_vector_indices
) )
# Local worker has no environments -> Expect the -1 special # Local worker has no environments -> Expect the -1 special
@ -100,7 +100,7 @@ class TestCallbacks(unittest.TestCase):
# of 6 (sum of vector indices: 0 + 1 + 2 + 3). # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
self.assertTrue(sum_sub_env_vector_indices[1] == 6) self.assertTrue(sum_sub_env_vector_indices[1] == 6)
self.assertTrue(sum_sub_env_vector_indices[2] == 6) self.assertTrue(sum_sub_env_vector_indices[2] == 6)
trainer.stop() algo.stop()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -30,10 +30,10 @@ class TestMemoryLeaks(unittest.TestCase):
config["env_config"] = { config["env_config"] = {
"static_samples": True, "static_samples": True,
} }
trainer = ppo.PPO(config=config) algo = ppo.PPO(config=config)
results = check_memory_leaks(trainer, to_check={"env"}, repeats=150) results = check_memory_leaks(algo, to_check={"env"}, repeats=150)
assert results["env"] assert results["env"]
trainer.stop() algo.stop()
def test_leaky_policy(self): def test_leaky_policy(self):
"""Tests, whether our diagnostics tools can detect leaks in a policy.""" """Tests, whether our diagnostics tools can detect leaks in a policy."""
@ -45,10 +45,10 @@ class TestMemoryLeaks(unittest.TestCase):
config["multiagent"]["policies"] = { config["multiagent"]["policies"] = {
"default_policy": PolicySpec(policy_class=MemoryLeakingPolicy), "default_policy": PolicySpec(policy_class=MemoryLeakingPolicy),
} }
trainer = dqn.DQN(config=config) algo = dqn.DQN(config=config)
results = check_memory_leaks(trainer, to_check={"policy"}, repeats=300) results = check_memory_leaks(algo, to_check={"policy"}, repeats=300)
assert results["policy"] assert results["policy"]
trainer.stop() algo.stop()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -12,7 +12,7 @@ from ray.rllib.connectors.connector import (
from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import ( from ray.rllib.utils.typing import (
ActionConnectorDataType, ActionConnectorDataType,
TrainerConfigDict, AlgorithmConfigDict,
) )
@ -50,8 +50,8 @@ register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)
@DeveloperAPI @DeveloperAPI
def get_action_connectors_from_trainer_config( def get_action_connectors_from_algorithm_config(
config: TrainerConfigDict, action_space: gym.Space config: AlgorithmConfigDict, action_space: gym.Space
) -> ActionConnectorPipeline: ) -> ActionConnectorPipeline:
connectors = [] connectors = []
return ActionConnectorPipeline(connectors) return ActionConnectorPipeline(connectors)

View file

@ -15,7 +15,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import ( from ray.rllib.utils.typing import (
ActionConnectorDataType, ActionConnectorDataType,
AgentConnectorDataType, AgentConnectorDataType,
TrainerConfigDict, AlgorithmConfigDict,
) )
@ -67,7 +67,7 @@ register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
# TODO(jungong) : finish this. # TODO(jungong) : finish this.
@DeveloperAPI @DeveloperAPI
def get_agent_connectors_from_config( def get_agent_connectors_from_config(
config: TrainerConfigDict, obs_space: gym.Space config: AlgorithmConfigDict, obs_space: gym.Space
) -> AgentConnectorPipeline: ) -> AgentConnectorPipeline:
connectors = [FlattenDataAgentConnector()] connectors = [FlattenDataAgentConnector()]

View file

@ -13,8 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import ( from ray.rllib.utils.typing import (
ActionConnectorDataType, ActionConnectorDataType,
AgentConnectorDataType, AgentConnectorDataType,
AlgorithmConfigDict,
TensorType, TensorType,
TrainerConfigDict,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,7 +34,7 @@ class ConnectorContext:
def __init__( def __init__(
self, self,
config: TrainerConfigDict = None, config: AlgorithmConfigDict = None,
model_initial_states: List[TensorType] = None, model_initial_states: List[TensorType] = None,
observation_space: gym.Space = None, observation_space: gym.Space = None,
action_space: gym.Space = None, action_space: gym.Space = None,

View file

@ -30,7 +30,7 @@ class MultiAgentEnv(gym.Env):
"""An environment that hosts multiple independent agents. """An environment that hosts multiple independent agents.
Agents are identified by (string) agent ids. Note that these "agents" here Agents are identified by (string) agent ids. Note that these "agents" here
are not to be confused with RLlib Trainers, which are also sometimes are not to be confused with RLlib Algorithms, which are also sometimes
referred to as "agents" or "RL agents". referred to as "agents" or "RL agents".
""" """

View file

@ -168,16 +168,16 @@ class TestTrajectoryViewAPI(unittest.TestCase):
config["env_config"] = {"config": {"start_at_t": 1}} # first obs is [1.0] config["env_config"] = {"config": {"start_at_t": 1}} # first obs is [1.0]
for _ in framework_iterator(config, frameworks="tf2"): for _ in framework_iterator(config, frameworks="tf2"):
trainer = ppo.PPO( algo = ppo.PPO(
config, config,
env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv", env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
) )
rw = trainer.workers.local_worker() rw = algo.workers.local_worker()
sample = rw.sample() sample = rw.sample()
assert sample.count == trainer.config["rollout_fragment_length"] assert sample.count == algo.config["rollout_fragment_length"]
results = trainer.train() results = algo.train()
assert results["timesteps_total"] == config["train_batch_size"] assert results["timesteps_total"] == config["train_batch_size"]
trainer.stop() algo.stop()
def test_traj_view_next_action(self): def test_traj_view_next_action(self):
action_space = Discrete(2) action_space = Discrete(2)
@ -341,10 +341,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
config["env_config"] = {"num_agents": num_agents} config["env_config"] = {"num_agents": num_agents}
num_iterations = 2 num_iterations = 2
trainer = ppo.PPO(config=config) algo = ppo.PPO(config=config)
results = None results = None
for i in range(num_iterations): for i in range(num_iterations):
results = trainer.train() results = algo.train()
self.assertEqual(results["agent_timesteps_total"], results["timesteps_total"]) self.assertEqual(results["agent_timesteps_total"], results["timesteps_total"])
self.assertEqual( self.assertEqual(
results["num_env_steps_trained"] * num_agents, results["num_env_steps_trained"] * num_agents,
@ -358,7 +358,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
results["agent_timesteps_total"], results["agent_timesteps_total"],
(num_iterations + 1) * config["train_batch_size"], (num_iterations + 1) * config["train_batch_size"],
) )
trainer.stop() algo.stop()
def test_get_single_step_input_dict_batch_repeat_value_larger_1(self): def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
"""Test whether a SampleBatch produces the correct 1-step input dict.""" """Test whether a SampleBatch produces the correct 1-step input dict."""

View file

@ -81,14 +81,14 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward, "episode_reward_mean": args.stop_reward,
} }
# To run the Trainer without tune.run, using our LSTM model and # To run the Algorithm without tune.run, using our LSTM model and
# manual state-in handling, do the following: # manual state-in handling, do the following:
# Example (use `config` from the above code): # Example (use `config` from the above code):
# >> import numpy as np # >> import numpy as np
# >> from ray.rllib.algorithms.ppo import PPO # >> from ray.rllib.algorithms.ppo import PPO
# >> # >>
# >> trainer = PPO(config) # >> algo = PPO(config)
# >> lstm_cell_size = config["model"]["lstm_cell_size"] # >> lstm_cell_size = config["model"]["lstm_cell_size"]
# >> env = StatelessCartPole() # >> env = StatelessCartPole()
# >> obs = env.reset() # >> obs = env.reset()
@ -101,7 +101,7 @@ if __name__ == "__main__":
# >> prev_r = 0.0 # >> prev_r = 0.0
# >> # >>
# >> while True: # >> while True:
# >> a, state_out, _ = trainer.compute_single_action( # >> a, state_out, _ = algo.compute_single_action(
# .. obs, state, prev_a, prev_r) # .. obs, state, prev_a, prev_r)
# >> obs, reward, done, _ = env.step(a) # >> obs, reward, done, _ = env.step(a)
# >> if done: # >> if done:
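Written out as a runnable sketch, the commented walkthrough above looks roughly as follows (assuming the StatelessCartPole example env; the config values, especially the LSTM settings, are illustrative):

import numpy as np

from ray.rllib.algorithms.ppo import PPO
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

config = {
    "env": StatelessCartPole,
    "framework": "tf",
    "num_workers": 0,
    "model": {
        "use_lstm": True,
        "lstm_cell_size": 32,
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
}

algo = PPO(config=config)
lstm_cell_size = config["model"]["lstm_cell_size"]

env = StatelessCartPole()
obs = env.reset()
# Initial RNN state: two zero vectors (c- and h-state) of the LSTM cell size.
state = [np.zeros(lstm_cell_size, np.float32) for _ in range(2)]
prev_action = 0
prev_reward = 0.0

done = False
while not done:
    action, state, _ = algo.compute_single_action(
        obs, state, prev_action=prev_action, prev_reward=prev_reward
    )
    obs, reward, done, _ = env.step(action)
    prev_action, prev_reward = action, reward

algo.stop()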

View file

@ -92,8 +92,8 @@ MyTFPolicy = build_tf_policy(
) )
# Create a new Trainer using the Policy defined above. # Create a new Algorithm using the Policy defined above.
class MyTrainer(Algorithm): class MyAlgo(Algorithm):
def get_default_policy_class(self, config): def get_default_policy_class(self, config):
return MyTFPolicy return MyTFPolicy
@ -117,7 +117,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward, "episode_reward_mean": args.stop_reward,
} }
results = tune.run(MyTrainer, stop=stop, config=config, verbose=1) results = tune.run(MyAlgo, stop=stop, config=config, verbose=1)
if args.as_test: if args.as_test:
check_learning_achieved(results, args.stop_reward) check_learning_achieved(results, args.stop_reward)

View file

@ -83,11 +83,11 @@ if __name__ == "__main__":
min_reward = -300 min_reward = -300
# Test for torch framework (tf not implemented yet). # Test for torch framework (tf not implemented yet).
trainer = cql.CQL(config=config) algo = cql.CQL(config=config)
learnt = False learnt = False
for i in range(num_iterations): for i in range(num_iterations):
print(f"Iter {i}") print(f"Iter {i}")
eval_results = trainer.train().get("evaluation") eval_results = algo.train().get("evaluation")
if eval_results: if eval_results:
print("... R={}".format(eval_results["episode_reward_mean"])) print("... R={}".format(eval_results["episode_reward_mean"]))
# Learn until some reward is reached on an actual live env. # Learn until some reward is reached on an actual live env.
@ -101,7 +101,7 @@ if __name__ == "__main__":
) )
# Get policy, model, and replay-buffer. # Get policy, model, and replay-buffer.
pol = trainer.get_policy() pol = algo.get_policy()
cql_model = pol.model cql_model = pol.model
from ray.rllib.algorithms.cql.cql import replay_buffer from ray.rllib.algorithms.cql.cql import replay_buffer
@ -116,7 +116,7 @@ if __name__ == "__main__":
final_q_values = torch.min(q_values, twin_q_values) final_q_values = torch.min(q_values, twin_q_values)
print(final_q_values) print(final_q_values)
# Example of how to do evaluation on the trained Trainer # Example of how to do evaluation on the trained Algorithm.
# using the data from our buffer. # using the data from our buffer.
# Get a sample (MultiAgentBatch). # Get a sample (MultiAgentBatch).
multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"]) multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"])
@ -128,11 +128,10 @@ if __name__ == "__main__":
model_out, _ = cql_model({"obs": obs}) model_out, _ = cql_model({"obs": obs})
# The estimated Q-values from the (historic) actions in the batch. # The estimated Q-values from the (historic) actions in the batch.
q_values_old = cql_model.get_q_values(model_out, torch.from_numpy(batch["actions"])) q_values_old = cql_model.get_q_values(model_out, torch.from_numpy(batch["actions"]))
# The estimated Q-values for the new actions computed # The estimated Q-values for the new actions computed by our policy.
# by our trainer policy.
actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0] actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new)) q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new))
print(f"Q-val batch={q_values_old}") print(f"Q-val batch={q_values_old}")
print(f"Q-val policy={q_values_new}") print(f"Q-val policy={q_values_new}")
trainer.stop() algo.stop()

View file

@ -58,10 +58,10 @@ class RandomParametricPolicy(Policy, ABC):
pass pass
class RandomParametricTrainer(Algorithm): class RandomParametricAlgorithm(Algorithm):
"""Algo with Policy and config defined above and overriding `training_iteration`. """Algo with Policy and config defined above and overriding `training_step`.
Overrides the `training_iteration` method, which only runs a (dummy) Overrides the `training_step` method, which only runs a (dummy)
rollout and performs no learning. rollout and performs no learning.
""" """
@ -79,7 +79,7 @@ class RandomParametricTrainer(Algorithm):
def main(): def main():
register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10)) register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
algo = RandomParametricTrainer(env="pa_cartpole") algo = RandomParametricAlgorithm(env="pa_cartpole")
result = algo.train() result = algo.train()
assert result["episode_reward_mean"] > 10, result assert result["episode_reward_mean"] > 10, result
print("Test: OK") print("Test: OK")

View file

@ -75,10 +75,10 @@ def get_cli_args():
return args return args
# The modified Trainer class we will use. This is the exact same # The modified Algorithm class we will use:
# as a PPO, but with the additional default_resource_request # Subclassing from PPO, our algo will only modify `default_resource_request`,
# override, telling tune that it's ok (not mandatory) to place our # telling Ray Tune that it's ok (not mandatory) to place our n remote envs on a
# n remote envs on a different node (each env using 1 CPU). # different node (each env using 1 CPU).
class PPORemoteInference(PPO): class PPORemoteInference(PPO):
@classmethod @classmethod
@override(Algorithm) @override(Algorithm)
@ -145,7 +145,7 @@ if __name__ == "__main__":
): ):
break break
# Run with Tune for auto env and trainer creation and TensorBoard. # Run with Tune for auto env and algorithm creation and TensorBoard.
else: else:
stop = { stop = {
"training_iteration": args.stop_iters, "training_iteration": args.stop_iters,

View file

@ -64,12 +64,12 @@ parser.add_argument(
) )
# Define new Trainer with custom execution_plan/workflow. # Define new Algorithm with custom execution_plan/workflow.
class MyTrainer(Algorithm): class MyAlgo(Algorithm):
@classmethod @classmethod
@override(Algorithm) @override(Algorithm)
def get_default_config(cls) -> AlgorithmConfigDict: def get_default_config(cls) -> AlgorithmConfigDict:
# Run this Trainer with new `training_iteration` API and set some PPO-specific # Run this Algorithm with new `training_step` API and set some PPO-specific
# parameters. # parameters.
return with_common_config( return with_common_config(
{ {
@ -218,7 +218,7 @@ if __name__ == "__main__":
"episode_reward_mean": args.stop_reward, "episode_reward_mean": args.stop_reward,
} }
results = tune.run(MyTrainer, config=config, stop=stop) results = tune.run(MyAlgo, config=config, stop=stop)
if args.as_test: if args.as_test:
check_learning_achieved(results, args.stop_reward) check_learning_achieved(results, args.stop_reward)

View file

@ -17,7 +17,7 @@ parser.add_argument(
type=str, type=str,
default=None, default=None,
help="Full path to a checkpoint file for restoring a previously saved " help="Full path to a checkpoint file for restoring a previously saved "
"Trainer state.", "Algorithm state.",
) )
parser.add_argument("--num-workers", type=int, default=0) parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument( parser.add_argument(

View file

@ -27,7 +27,7 @@ def StandardMetricsReporting(
train_op: Operator for executing training steps. train_op: Operator for executing training steps.
We ignore the output values. We ignore the output values.
workers: Rollout workers to collect metrics from. workers: Rollout workers to collect metrics from.
config: Trainer configuration, used to determine the frequency config: Algorithm configuration, used to determine the frequency
of stats reporting. of stats reporting.
selected_workers: Override the list of remote workers selected_workers: Override the list of remote workers
to collect metrics from. to collect metrics from.

View file

@ -51,7 +51,7 @@ class TestOPE(unittest.TestCase):
.framework("torch") .framework("torch")
.rollouts(batch_mode="complete_episodes") .rollouts(batch_mode="complete_episodes")
) )
cls.trainer = config.build() cls.algo = config.build()
# Train DQN for evaluation policy # Train DQN for evaluation policy
tune.run( tune.run(
@ -80,7 +80,7 @@ class TestOPE(unittest.TestCase):
done = False done = False
rewards = [] rewards = []
while not done: while not done:
act = cls.trainer.compute_single_action(obs) act = cls.algo.compute_single_action(obs)
obs, reward, done, _ = env.step(act) obs, reward, done, _ = env.step(act)
rewards.append(reward) rewards.append(reward)
ret = 0 ret = 0
@ -105,7 +105,7 @@ class TestOPE(unittest.TestCase):
name = "is" name = "is"
estimator = ImportanceSampling( estimator = ImportanceSampling(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
) )
estimator.process(self.batch) estimator.process(self.batch)
@ -118,7 +118,7 @@ class TestOPE(unittest.TestCase):
name = "wis" name = "wis"
estimator = WeightedImportanceSampling( estimator = WeightedImportanceSampling(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
) )
estimator.process(self.batch) estimator.process(self.batch)
@ -131,7 +131,7 @@ class TestOPE(unittest.TestCase):
name = "dm_qreg" name = "dm_qreg"
estimator = DirectMethod( estimator = DirectMethod(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
q_model_type="qreg", q_model_type="qreg",
**self.model_config, **self.model_config,
@ -146,7 +146,7 @@ class TestOPE(unittest.TestCase):
name = "dm_fqe" name = "dm_fqe"
estimator = DirectMethod( estimator = DirectMethod(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
q_model_type="fqe", q_model_type="fqe",
**self.model_config, **self.model_config,
@ -161,7 +161,7 @@ class TestOPE(unittest.TestCase):
name = "dr_qreg" name = "dr_qreg"
estimator = DoublyRobust( estimator = DoublyRobust(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
q_model_type="qreg", q_model_type="qreg",
**self.model_config, **self.model_config,
@ -176,7 +176,7 @@ class TestOPE(unittest.TestCase):
name = "dr_fqe" name = "dr_fqe"
estimator = DoublyRobust( estimator = DoublyRobust(
name=name, name=name,
policy=self.trainer.get_policy(), policy=self.algo.get_policy(),
gamma=self.gamma, gamma=self.gamma,
q_model_type="fqe", q_model_type="fqe",
**self.model_config, **self.model_config,
@ -187,7 +187,7 @@ class TestOPE(unittest.TestCase):
self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates]) self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates]) self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])
def test_ope_in_trainer(self): def test_ope_in_algo(self):
# TODO (rohan): Add performance tests for off_policy_estimation_methods, # TODO (rohan): Add performance tests for off_policy_estimation_methods,
# with fixed seeds and hyperparameters # with fixed seeds and hyperparameters
pass pass

View file

@ -294,7 +294,7 @@ def _build_eager_tf_policy(
much simpler, but has lower performance. much simpler, but has lower performance.
You shouldn't need to call this directly. Rather, prefer to build a TF You shouldn't need to call this directly. Rather, prefer to build a TF
graph policy and set {"framework": "tfe"} in the trainer config to have graph policy and set {"framework": "tfe"} in the Algorithm's config to have
it automatically be converted to an eager policy. it automatically be converted to an eager policy.
This has the same signature as build_tf_policy().""" This has the same signature as build_tf_policy()."""

View file

@ -78,7 +78,7 @@ class EntropyCoeffSchedule:
class KLCoeffMixin: class KLCoeffMixin:
"""Assigns the `update_kl()` method to a TorchPolicy. """Assigns the `update_kl()` method to a TorchPolicy.
This is used by Trainers to update the KL coefficient This is used by Algorithms to update the KL coefficient
after each learning step based on `config.kl_target` and after each learning step based on `config.kl_target` and
the measured KL value (from the train_batch). the measured KL value (from the train_batch).
""" """

View file

@ -7,7 +7,7 @@ if __name__ == "__main__":
# Do not import torch for testing purposes. # Do not import torch for testing purposes.
os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1" os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"
# Test registering (includes importing) all Trainers. # Test registering (includes importing) all Algorithms.
from ray.rllib import _register_all from ray.rllib import _register_all
# This should surface any dependency on torch, e.g. inside function # This should surface any dependency on torch, e.g. inside function
@ -19,7 +19,7 @@ if __name__ == "__main__":
assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!" assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!"
# Note: No ray.init(), to test it works without Ray # Note: No ray.init(), to test it works without Ray
trainer = A2C( algo = A2C(
env="CartPole-v0", env="CartPole-v0",
config={ config={
"framework": "tf", "framework": "tf",
@ -31,7 +31,7 @@ if __name__ == "__main__":
}, },
}, },
) )
trainer.train() algo.train()
assert ( assert (
"torch" not in sys.modules "torch" not in sys.modules

View file

@ -57,10 +57,10 @@ class TestPlacementGroups(unittest.TestCase):
config["env"] = "CartPole-v0" config["env"] = "CartPole-v0"
config["framework"] = "tf" config["framework"] = "tf"
# Create a trainer with an overridden default_resource_request # Create an Algorithm with an overridden default_resource_request
# method that returns a PlacementGroupFactory. # method that returns a PlacementGroupFactory.
class MyTrainer(PG): class MyAlgo(PG):
@classmethod @classmethod
def default_resource_request(cls, config): def default_resource_request(cls, config):
head_bundle = {"CPU": 1, "GPU": 0} head_bundle = {"CPU": 1, "GPU": 0}
@ -70,7 +70,7 @@ class TestPlacementGroups(unittest.TestCase):
strategy=config["placement_strategy"], strategy=config["placement_strategy"],
) )
tune.register_trainable("my_trainable", MyTrainer) tune.register_trainable("my_trainable", MyAlgo)
global trial_executor global trial_executor
trial_executor = RayTrialExecutor(reuse_actors=False) trial_executor = RayTrialExecutor(reuse_actors=False)

View file

@ -27,11 +27,11 @@ class TestTimeSteps(unittest.TestCase):
obs_batch = np.array([1]) obs_batch = np.array([1])
for _ in framework_iterator(config): for _ in framework_iterator(config):
trainer = pg.PG(config=config, env=RandomEnv) algo = pg.PG(config=config, env=RandomEnv)
policy = trainer.get_policy() policy = algo.get_policy()
for i in range(1, 21): for i in range(1, 21):
trainer.compute_single_action(obs) algo.compute_single_action(obs)
check(policy.global_timestep, i) check(policy.global_timestep, i)
for i in range(1, 21): for i in range(1, 21):
policy.compute_actions(obs_batch) policy.compute_actions(obs_batch)
@ -45,7 +45,8 @@ class TestTimeSteps(unittest.TestCase):
for i in range(1, 11): for i in range(1, 11):
policy.compute_actions(obs_batch) policy.compute_actions(obs_batch)
check(policy.global_timestep, i + crazy_timesteps) check(policy.global_timestep, i + crazy_timesteps)
trainer.train() algo.train()
algo.stop()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -36,18 +36,18 @@ def PublicAPI(obj):
can expect these APIs to remain stable across RLlib releases. can expect these APIs to remain stable across RLlib releases.
Subclasses that inherit from a ``@PublicAPI`` base class can be Subclasses that inherit from a ``@PublicAPI`` base class can be
assumed part of the RLlib public API as well (e.g., all trainer classes assumed part of the RLlib public API as well (e.g., all Algorithm classes
are in public API because Trainer is ``@PublicAPI``). are in public API because Algorithm is ``@PublicAPI``).
In addition, you can assume all trainer configurations are part of their In addition, you can assume all algo configurations are part of their
public API as well. public API as well.
Examples: Examples:
>>> # Indicates that the `Trainer` class is exposed to end users >>> # Indicates that the `Algorithm` class is exposed to end users
>>> # of RLlib and will remain stable across RLlib releases. >>> # of RLlib and will remain stable across RLlib releases.
>>> from ray import tune >>> from ray import tune
>>> @PublicAPI # doctest: +SKIP >>> @PublicAPI # doctest: +SKIP
>>> class Trainer(tune.Trainable): # doctest: +SKIP >>> class Algorithm(tune.Trainable): # doctest: +SKIP
... ... # doctest: +SKIP ... ... # doctest: +SKIP
""" """
@ -110,7 +110,7 @@ def ExperimentalAPI(obj):
def OverrideToImplementCustomLogic(obj): def OverrideToImplementCustomLogic(obj):
"""Users should override this in their sub-classes to implement custom logic. """Users should override this in their sub-classes to implement custom logic.
Used in Trainer and Policy to tag methods that need overriding, e.g. Used in Algorithm and Policy to tag methods that need overriding, e.g.
`Policy.loss()`. `Policy.loss()`.
Examples: Examples:
@ -132,9 +132,9 @@ def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
When doing so, it is recommended (but not required) to call the super-class' When doing so, it is recommended (but not required) to call the super-class'
corresponding method. corresponding method.
Used in Trainer and Policy to tag methods that need overriding, but the Used in Algorithm and Policy to tag methods that need overriding, but the
super class' method should still be called, e.g. super class' method should still be called, e.g.
`Trainer.setup()`. `Algorithm.setup()`.
Examples: Examples:
>>> from ray import tune >>> from ray import tune

View file

@ -36,7 +36,7 @@ Suspect = DeveloperAPI(
@DeveloperAPI @DeveloperAPI
def check_memory_leaks( def check_memory_leaks(
trainer, algorithm,
to_check: Optional[Set[str]] = None, to_check: Optional[Set[str]] = None,
repeats: Optional[int] = None, repeats: Optional[int] = None,
max_num_trials: int = 3, max_num_trials: int = 3,
@ -49,7 +49,7 @@ def check_memory_leaks(
un-GC'd items to memory. un-GC'd items to memory.
Args: Args:
trainer: The Algorithm instance to test. algorithm: The Algorithm instance to test.
to_check: Set of strings to identify components to test. Allowed strings to_check: Set of strings to identify components to test. Allowed strings
are: "env", "policy", "model", "rollout_worker". By default, check all are: "env", "policy", "model", "rollout_worker". By default, check all
of these. of these.
@ -62,7 +62,7 @@ def check_memory_leaks(
A defaultdict(list) with keys being the `to_check` strings and values being A defaultdict(list) with keys being the `to_check` strings and values being
lists of Suspect instances that were found. lists of Suspect instances that were found.
""" """
local_worker = trainer.workers.local_worker() local_worker = algorithm.workers.local_worker()
# Which components should we test? # Which components should we test?
to_check = to_check or {"env", "model", "policy", "rollout_worker"} to_check = to_check or {"env", "model", "policy", "rollout_worker"}

View file

@ -12,7 +12,7 @@ NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
LAST_TARGET_UPDATE_TS = "last_target_update_ts" LAST_TARGET_UPDATE_TS = "last_target_update_ts"
NUM_TARGET_UPDATES = "num_target_updates" NUM_TARGET_UPDATES = "num_target_updates"
# Performance timers (keys for Trainer._timers or metrics.timers). # Performance timers (keys for Algorithm._timers or metrics.timers).
TRAINING_ITERATION_TIMER = "training_iteration" TRAINING_ITERATION_TIMER = "training_iteration"
APPLY_GRADS_TIMER = "apply_grad" APPLY_GRADS_TIMER = "apply_grad"
COMPUTE_GRADS_TIMER = "compute_grads" COMPUTE_GRADS_TIMER = "compute_grads"