Mirror of https://github.com/vale981/ray (synced 2025-03-04 17:41:43 -05:00)
[RLlib] More Trainer -> Algorithm renaming cleanups. (#25869)
parent e13cc4088a, commit 96693055bd
39 changed files with 166 additions and 166 deletions
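All hunks below follow one pattern: the object built from an RLlib config is now named and documented as an Algorithm (conventionally `algo`) rather than a Trainer (`trainer`), and helper/typing names move accordingly (e.g. TrainerConfigDict -> AlgorithmConfigDict). A minimal before/after sketch of that usage pattern, assuming the PPOConfig import path (illustration only; the config calls mirror the doctest snippets further down in this diff, not code changed by this commit):

    # Illustrative sketch only; not a file changed by this commit.
    from ray.rllib.algorithms.ppo import PPOConfig  # assumed import path

    config = PPOConfig().environment("CartPole-v0").rollouts(num_rollout_workers=2)

    # Old naming (pre-rename):
    #   trainer = config.build()
    #   trainer.train()
    #   trainer.stop()

    # New naming (post-rename): the built object is an Algorithm.
    algo = config.build()
    results = algo.train()
    print(results["episode_reward_mean"])
    algo.stop()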
@@ -123,24 +123,24 @@
       --test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
       rllib/...

-- label: ":brain: RLlib: Trainer Tests (generic)"
+- label: ":brain: RLlib: Algorithm Tests (generic)"
   conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
   commands:
     - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
     - RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
-    # Test all tests in the `agents` (soon to be "trainers") dir:
+    # Test all tests in the `algorithms` dir:
     - bazel test --config=ci $(./ci/run/bazel_export_options)
       --build_tests_only
       --test_tag_filters=algorithms_dir_generic,-multi_gpu
       --test_env=RAY_USE_MULTIPROCESSING_CPU_COUNT=1
       rllib/...

-- label: ":brain: RLlib: Trainer Tests (specific algos)"
+- label: ":brain: RLlib: Algorithm Tests (specific algos)"
   conditions: ["RAY_CI_RLLIB_DIRECTLY_AFFECTED"]
   commands:
     - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
     - RLLIB_TESTING=1 PYTHON=3.7 ./ci/env/install-dependencies.sh
-    # Test all tests in the `agents` (soon to be "trainers") dir:
+    # Test all tests in the `algorithms` dir:
     - bazel test --config=ci $(./ci/run/bazel_export_options)
       --build_tests_only
       --test_tag_filters=algorithms_dir,-algorithms_dir_generic,-multi_gpu
@@ -740,7 +740,7 @@ Here is an example of the basic usage (for a more complete example, see `custom_
     # NOTE: In order for this to work, your (custom) model needs to implement
     # the `import_from_h5` method.
     # See https://github.com/ray-project/ray/blob/master/rllib/tests/test_model_imports.py
-    # for detailed examples for tf- and torch trainers/models.
+    # for detailed examples for tf- and torch policies/models.

 .. note::
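As a rough sketch of the hook named in that note, a custom TF model would expose `import_from_h5` roughly like this (class name, layer sizes, and the Keras sub-model are illustrative assumptions, not code from this commit):

    # Illustrative sketch only: a custom model implementing `import_from_h5`.
    import tensorflow as tf
    from ray.rllib.models.tf.tf_modelv2 import TFModelV2


    class MyImportableModel(TFModelV2):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            super().__init__(obs_space, action_space, num_outputs, model_config, name)
            # Assumed: a small Keras network whose weights we want to pre-load.
            self.base_model = tf.keras.Sequential(
                [tf.keras.layers.Dense(num_outputs, input_shape=obs_space.shape)]
            )
            # forward() / value_function() omitted in this sketch.

        def import_from_h5(self, h5_file: str) -> None:
            # Load pretrained weights from the given .h5 file into the Keras model.
            self.base_model.load_weights(h5_file)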
@@ -1270,7 +1270,7 @@ Below are some examples of how the custom evaluation metrics are reported nested
 Sample output for `python custom_eval.py --custom-eval`
 ------------------------------------------------------------------------

-INFO trainer.py:631 -- Running custom eval function <function ...>
+INFO algorithm.py:631 -- Running custom eval function <function ...>
 Update corridor length to 4
 Update corridor length to 7
 Custom evaluation round 1
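For orientation, the log lines above are produced by the custom evaluation function in `rllib/examples/custom_eval.py`. The general shape of such a function, heavily simplified (the body and the returned keys are placeholders; only the `(algorithm, eval_workers)` signature and the nesting of the result under the `evaluation` key are taken from the documentation around this hunk):

    # Simplified, hypothetical sketch of a custom evaluation function.
    def custom_eval_function(algorithm, eval_workers):
        print("Custom evaluation round 1")
        # Kick off one round of rollouts on the remote evaluation workers.
        for worker in eval_workers.remote_workers():
            worker.sample.remote()
        # Whatever dict is returned here is reported nested under
        # results["evaluation"] of the corresponding train() call.
        return {"foo": 1}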
rllib/BUILD (39 changed lines)
@@ -15,7 +15,7 @@
 # actions vs continuous actions.
 # -- "fake_gpus": Tests that run using 2 fake GPUs.

-# - Quick agent compilation/tune-train tests, tagged "quick_train".
+# - Quick algo compilation/tune-train tests, tagged "quick_train".
 # NOTE: These should be obsoleted in favor of "algorithms_dir" tests as
 # they cover the same functionaliy.
@@ -28,7 +28,7 @@
 # - `policy` directory tests.
 # - `utils` directory tests.

-# - Trainer ("agents") tests, tagged "algorithms_dir".
+# - Algorithm tests, tagged "algorithms_dir".

 # - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
 #   "tests_dir_[A-Z]"
@@ -65,7 +65,7 @@
 load("//bazel:python.bzl", "py_test_module_list")

 # --------------------------------------------------------------------
-# Agents learning regression tests.
+# Algorithms learning regression tests.
 #
 # Tag: learning_tests
 #
@@ -685,40 +685,41 @@ py_test(


 # --------------------------------------------------------------------
-# Agents (Compilation, Losses, simple agent functionality tests)
+# Algorithms (Compilation, Losses, simple functionality tests)
+# rllib/algorithms/
 #
 # Tag: algorithms_dir
 # --------------------------------------------------------------------

-# Generic (all Trainers)
+# Generic (all Algorithms)

 py_test(
+    name = "test_algorithm",
+    tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
+    size = "large",
+    srcs = ["algorithms/tests/test_algorithm.py"]
+)
+
+py_test(
     name = "test_callbacks",
     tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
     size = "medium",
-    srcs = ["agents/tests/test_callbacks.py"]
+    srcs = ["algorithms/tests/test_callbacks.py"]
 )

 py_test(
     name = "test_memory_leaks_generic",
-    main = "agents/tests/test_memory_leaks.py",
+    main = "algorithms/tests/test_memory_leaks.py",
     tags = ["team:rllib", "algorithms_dir"],
     size = "large",
-    srcs = ["agents/tests/test_memory_leaks.py"]
-)
-
-py_test(
-    name = "test_trainer",
-    tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
-    size = "large",
-    srcs = ["agents/tests/test_trainer.py"]
+    srcs = ["algorithms/tests/test_memory_leaks.py"]
 )

 py_test(
     name = "tests/test_worker_failures",
     tags = ["team:rllib", "tests_dir", "algorithms_dir_generic"],
     size = "large",
-    srcs = ["agents/tests/test_worker_failures.py"]
+    srcs = ["algorithms/tests/test_worker_failures.py"]
 )

 # Specific Algorithms
@@ -809,7 +810,7 @@ py_test(
 py_test(
     name = "test_cql",
     tags = ["team:rllib", "algorithms_dir"],
-    size = "medium",
+    size = "large",
     srcs = ["algorithms/cql/tests/test_cql.py"]
 )
@@ -982,7 +983,7 @@ py_test(
 )

 # --------------------------------------------------------------------
-# contrib Agents
+# contrib Algorithms
 # --------------------------------------------------------------------

 py_test(
@@ -1071,7 +1072,7 @@ py_test(
 )

 # --------------------------------------------------------------------
-# Agents (quick training test iterations via `rllib train`)
+# Algorithms (quick training test iterations via `rllib train`)
 #
 # Tag: quick_train
 #
@@ -30,11 +30,12 @@ class TestAlphaZero(unittest.TestCase):

         # Only working for torch right now.
         for _ in framework_iterator(config, frameworks="torch"):
-            trainer = config.build()
+            algo = config.build()
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 print(results)
+            algo.stop()


 if __name__ == "__main__":
@@ -28,23 +28,23 @@ class TestAPPO(unittest.TestCase):
         for _ in framework_iterator(config, with_eager_tracing=True):
             print("w/o v-trace")
             config.vtrace = False
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 print(results)
-            check_compute_single_action(trainer)
-            trainer.stop()
+            check_compute_single_action(algo)
+            algo.stop()

             print("w/ v-trace")
             config.vtrace = True
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 print(results)
-            check_compute_single_action(trainer)
-            trainer.stop()
+            check_compute_single_action(algo)
+            algo.stop()

     def test_appo_compilation_use_kl_loss(self):
         """Test whether APPO can be built with kl_loss enabled."""
@@ -54,13 +54,13 @@ class TestAPPO(unittest.TestCase):
         num_iterations = 2

         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 print(results)
-            check_compute_single_action(trainer)
-            trainer.stop()
+            check_compute_single_action(algo)
+            algo.stop()

     def test_appo_two_tf_optimizers(self):
         # Not explicitly setting this should cause a warning, but not fail.
@@ -78,13 +78,13 @@ class TestAPPO(unittest.TestCase):

         # Only supported for tf so far.
         for _ in framework_iterator(config, frameworks=("tf2", "tf")):
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 print(results)
-            check_compute_single_action(trainer)
-            trainer.stop()
+            check_compute_single_action(algo)
+            algo.stop()

     def test_appo_entropy_coeff_schedule(self):
         # Initial lr, doesn't really matter because of the schedule below.
@@ -113,33 +113,33 @@ class TestAPPO(unittest.TestCase):
         # which entropy coeff depends on, is updated after each worker rollout.
         config.min_time_s_per_iteration = 0

-        def _step_n_times(trainer, n: int):
-            """Step trainer n times.
+        def _step_n_times(algo, n: int):
+            """Step Algorithm n times.

             Returns:
                 learning rate at the end of the execution.
             """
             for _ in range(n):
-                results = trainer.train()
-                print(trainer.workers.local_worker().global_vars)
+                results = algo.train()
+                print(algo.workers.local_worker().global_vars)
                 print(results)
             return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
                 "entropy_coeff"
             ]

         for _ in framework_iterator(config):
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")

-            coeff = _step_n_times(trainer, 10)  # 200 timesteps
+            coeff = _step_n_times(algo, 10)  # 200 timesteps
             # Should be close to the starting coeff of 0.01.
             self.assertLessEqual(coeff, 0.01)
             self.assertGreaterEqual(coeff, 0.001)

-            coeff = _step_n_times(trainer, 20)  # 400 timesteps
+            coeff = _step_n_times(algo, 20)  # 400 timesteps
             # Should have annealed to the final coeff of 0.0001.
             self.assertLessEqual(coeff, 0.001)

-            trainer.stop()
+            algo.stop()


 if __name__ == "__main__":
@@ -33,13 +33,13 @@ class TestARS(unittest.TestCase):
         num_iterations = 2

         for _ in framework_iterator(config):
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 print(results)

-            check_compute_single_action(trainer)
-            trainer.stop()
+            check_compute_single_action(algo)
+            algo.stop()


 if __name__ == "__main__":
@@ -29,13 +29,13 @@ class TestES(unittest.TestCase):

         for _ in framework_iterator(config):
             for env in ["CartPole-v0", "Pendulum-v1"]:
-                trainer = config.build(env=env)
+                algo = config.build(env=env)
                 for i in range(num_iterations):
-                    results = trainer.train()
+                    results = algo.train()
                     print(results)

-                check_compute_single_action(trainer)
-                trainer.stop()
+                check_compute_single_action(algo)
+                algo.stop()
         ray.shutdown()
|
@ -37,8 +37,8 @@ class MARWILConfig(AlgorithmConfig):
|
|||
... .offline_data(input_=["./rllib/tests/data/cartpole/large.json"])
|
||||
>>> print(config.to_dict())
|
||||
>>> # Build a Algorithm object from the config and run 1 training iteration.
|
||||
>>> trainer = config.build()
|
||||
>>> trainer.train()
|
||||
>>> algo = config.build()
|
||||
>>> algo.train()
|
||||
|
||||
Example:
|
||||
>>> from ray.rllib.algorithms.marwil import MARWILConfig
|
||||
|
|
|
@@ -30,9 +30,9 @@ class R2D2Config(DQNConfig):
     >>> .resources(num_gpus=1)\
     >>> .rollouts(num_rollout_workers=30)\
     >>> .environment("CartPole-v1")
-    >>> trainer = R2D2(config=config)
+    >>> algo = R2D2(config=config)
     >>> while True:
-    >>>     trainer.train()
+    >>>     algo.train()

     Example:
     >>> from ray.rllib.algorithms.r2d2.r2d2 import R2D2Config
@@ -170,8 +170,6 @@ class R2D2Config(DQNConfig):
         return self


-# Build an R2D2 trainer, which uses the framework specific Policy
-# determined in `get_policy_class()` above.
 class R2D2(DQN):
     """Recurrent Experience Replay in Distrib. Reinforcement Learning (R2D2).
@@ -78,14 +78,14 @@ class TestR2D2(unittest.TestCase):

         # Test building an R2D2 agent in all frameworks.
         for _ in framework_iterator(config, with_eager_tracing=True):
-            trainer = config.build(env="CartPole-v0")
+            algo = config.build(env="CartPole-v0")
             for i in range(num_iterations):
-                results = trainer.train()
+                results = algo.train()
                 check_train_results(results)
                 check_batch_sizes(results)
                 print(results)

-            check_compute_single_action(trainer, include_state=True)
+            check_compute_single_action(algo, include_state=True)


 if __name__ == "__main__":
@@ -29,8 +29,8 @@ class SACConfig(AlgorithmConfig):
     ...     .rollouts(num_rollout_workers=4)
     >>> print(config.to_dict())
     >>> # Build a Algorithm object from the config and run 1 training iteration.
-    >>> trainer = config.build(env="CartPole-v1")
-    >>> trainer.train()
+    >>> algo = config.build(env="CartPole-v1")
+    >>> algo.train()
     """

     def __init__(self, algo_class=None):
@@ -18,8 +18,8 @@ class TD3Config(DDPGConfig):
     >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
     >>> print(config.to_dict())
     >>> # Build a Algorithm object from the config and run one training iteration.
-    >>> trainer = config.build(env="Pendulum-v1")
-    >>> trainer.train()
+    >>> algo = config.build(env="Pendulum-v1")
+    >>> algo.train()

     Example:
     >>> from ray.rllib.algorithms.ddpg.td3 import TD3Config
@@ -38,10 +38,10 @@ class TestAlgorithm(unittest.TestCase):
         algo = pg.PG(env="CartPole-v0", config=standard_config)

         # When (we validate config 2 times).
-        # Try deprecated `Trainer._validate_config()` method (static).
+        # Try deprecated `Algorithm._validate_config()` method (static).
         algo._validate_config(standard_config, algo)
         config_v1 = copy.deepcopy(standard_config)
-        # Try new method: `Trainer.validate_config()` (non-static).
+        # Try new method: `Algorithm.validate_config()` (non-static).
         algo.validate_config(standard_config)
         config_v2 = copy.deepcopy(standard_config)
@@ -239,7 +239,7 @@ class TestAlgorithm(unittest.TestCase):
         algo_wo_env_on_driver.stop()

         # Try again using `create_env_on_driver=True`.
-        # This force-adds the env on the local-worker, so this Trainer
+        # This force-adds the env on the local-worker, so this Algorithm
         # can `evaluate` even though it doesn't have an evaluation-worker
         # set.
         config.create_env_on_local_worker = True
@@ -47,13 +47,13 @@ class TestCallbacks(unittest.TestCase):
         config = dict(base_config, callbacks=callbacks)

         for _ in framework_iterator(config, frameworks=("tf", "torch")):
-            trainer = dqn.DQN(config=config)
+            algo = dqn.DQN(config=config)
             # Fake the counter on the local worker (doesn't have an env) and
             # set it to -1 so the below `foreach_worker()` won't fail.
-            trainer.workers.local_worker().sum_sub_env_vector_indices = -1
+            algo.workers.local_worker().sum_sub_env_vector_indices = -1

             # Get sub-env vector index sums from the 2 remote workers:
-            sum_sub_env_vector_indices = trainer.workers.foreach_worker(
+            sum_sub_env_vector_indices = algo.workers.foreach_worker(
                 lambda w: w.sum_sub_env_vector_indices
             )
             # Local worker has no environments -> Expect the -1 special
@@ -63,7 +63,7 @@ class TestCallbacks(unittest.TestCase):
             # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
             self.assertTrue(sum_sub_env_vector_indices[1] == 6)
             self.assertTrue(sum_sub_env_vector_indices[2] == 6)
-            trainer.stop()
+            algo.stop()

     def test_on_sub_environment_created_with_remote_envs(self):
         base_config = {
@@ -84,13 +84,13 @@ class TestCallbacks(unittest.TestCase):
         config = dict(base_config, callbacks=callbacks)

         for _ in framework_iterator(config, frameworks=("tf", "torch")):
-            trainer = dqn.DQN(config=config)
+            algo = dqn.DQN(config=config)
             # Fake the counter on the local worker (doesn't have an env) and
             # set it to -1 so the below `foreach_worker()` won't fail.
-            trainer.workers.local_worker().sum_sub_env_vector_indices = -1
+            algo.workers.local_worker().sum_sub_env_vector_indices = -1

             # Get sub-env vector index sums from the 2 remote workers:
-            sum_sub_env_vector_indices = trainer.workers.foreach_worker(
+            sum_sub_env_vector_indices = algo.workers.foreach_worker(
                 lambda w: w.sum_sub_env_vector_indices
             )
             # Local worker has no environments -> Expect the -1 special
@@ -100,7 +100,7 @@ class TestCallbacks(unittest.TestCase):
             # of 6 (sum of vector indices: 0 + 1 + 2 + 3).
             self.assertTrue(sum_sub_env_vector_indices[1] == 6)
             self.assertTrue(sum_sub_env_vector_indices[2] == 6)
-            trainer.stop()
+            algo.stop()


 if __name__ == "__main__":
@@ -30,10 +30,10 @@ class TestMemoryLeaks(unittest.TestCase):
         config["env_config"] = {
             "static_samples": True,
         }
-        trainer = ppo.PPO(config=config)
-        results = check_memory_leaks(trainer, to_check={"env"}, repeats=150)
+        algo = ppo.PPO(config=config)
+        results = check_memory_leaks(algo, to_check={"env"}, repeats=150)
         assert results["env"]
-        trainer.stop()
+        algo.stop()

     def test_leaky_policy(self):
         """Tests, whether our diagnostics tools can detect leaks in a policy."""
@@ -45,10 +45,10 @@ class TestMemoryLeaks(unittest.TestCase):
         config["multiagent"]["policies"] = {
             "default_policy": PolicySpec(policy_class=MemoryLeakingPolicy),
         }
-        trainer = dqn.DQN(config=config)
-        results = check_memory_leaks(trainer, to_check={"policy"}, repeats=300)
+        algo = dqn.DQN(config=config)
+        results = check_memory_leaks(algo, to_check={"policy"}, repeats=300)
         assert results["policy"]
-        trainer.stop()
+        algo.stop()


 if __name__ == "__main__":
@@ -12,7 +12,7 @@ from ray.rllib.connectors.connector import (
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.typing import (
     ActionConnectorDataType,
-    TrainerConfigDict,
+    AlgorithmConfigDict,
 )
@@ -50,8 +50,8 @@ register_connector(ActionConnectorPipeline.__name__, ActionConnectorPipeline)


 @DeveloperAPI
-def get_action_connectors_from_trainer_config(
-    config: TrainerConfigDict, action_space: gym.Space
+def get_action_connectors_from_algorithm_config(
+    config: AlgorithmConfigDict, action_space: gym.Space
 ) -> ActionConnectorPipeline:
     connectors = []
     return ActionConnectorPipeline(connectors)
@@ -15,7 +15,7 @@ from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.typing import (
     ActionConnectorDataType,
     AgentConnectorDataType,
-    TrainerConfigDict,
+    AlgorithmConfigDict,
 )
@@ -67,7 +67,7 @@ register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
 # TODO(jungong) : finish this.
 @DeveloperAPI
 def get_agent_connectors_from_config(
-    config: TrainerConfigDict, obs_space: gym.Space
+    config: AlgorithmConfigDict, obs_space: gym.Space
 ) -> AgentConnectorPipeline:
     connectors = [FlattenDataAgentConnector()]
@@ -13,8 +13,8 @@ from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.utils.typing import (
     ActionConnectorDataType,
     AgentConnectorDataType,
+    AlgorithmConfigDict,
     TensorType,
-    TrainerConfigDict,
 )

 logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ class ConnectorContext:

     def __init__(
         self,
-        config: TrainerConfigDict = None,
+        config: AlgorithmConfigDict = None,
         model_initial_states: List[TensorType] = None,
         observation_space: gym.Space = None,
         action_space: gym.Space = None,
rllib/env/multi_agent_env.py (vendored, 2 changed lines)
@@ -30,7 +30,7 @@ class MultiAgentEnv(gym.Env):
     """An environment that hosts multiple independent agents.

     Agents are identified by (string) agent ids. Note that these "agents" here
-    are not to be confused with RLlib Trainers, which are also sometimes
+    are not to be confused with RLlib Algorithms, which are also sometimes
     referred to as "agents" or "RL agents".
     """
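To make the "agents are identified by (string) agent ids" convention concrete, a minimal hypothetical MultiAgentEnv (not part of this commit; it uses the pre-gymnasium reset/step API that RLlib used at the time of this change):

    # Illustrative sketch of the agent-id keyed dict convention.
    import gym
    from ray.rllib.env.multi_agent_env import MultiAgentEnv


    class TwoAgentEnv(MultiAgentEnv):
        def __init__(self, config=None):
            super().__init__()
            self.observation_space = gym.spaces.Discrete(2)
            self.action_space = gym.spaces.Discrete(2)
            self._agent_ids = {"agent_0", "agent_1"}

        def reset(self):
            # One observation per (string) agent id.
            return {"agent_0": 0, "agent_1": 0}

        def step(self, action_dict):
            obs = {aid: 1 for aid in action_dict}
            rewards = {aid: 1.0 for aid in action_dict}
            dones = {aid: True for aid in action_dict}
            dones["__all__"] = True  # signals the end of the whole episode
            infos = {aid: {} for aid in action_dict}
            return obs, rewards, dones, infos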
@@ -168,16 +168,16 @@ class TestTrajectoryViewAPI(unittest.TestCase):
         config["env_config"] = {"config": {"start_at_t": 1}}  # first obs is [1.0]

         for _ in framework_iterator(config, frameworks="tf2"):
-            trainer = ppo.PPO(
+            algo = ppo.PPO(
                 config,
                 env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
             )
-            rw = trainer.workers.local_worker()
+            rw = algo.workers.local_worker()
             sample = rw.sample()
-            assert sample.count == trainer.config["rollout_fragment_length"]
-            results = trainer.train()
+            assert sample.count == algo.config["rollout_fragment_length"]
+            results = algo.train()
             assert results["timesteps_total"] == config["train_batch_size"]
-            trainer.stop()
+            algo.stop()

     def test_traj_view_next_action(self):
         action_space = Discrete(2)
@@ -341,10 +341,10 @@ class TestTrajectoryViewAPI(unittest.TestCase):
         config["env_config"] = {"num_agents": num_agents}

         num_iterations = 2
-        trainer = ppo.PPO(config=config)
+        algo = ppo.PPO(config=config)
         results = None
         for i in range(num_iterations):
-            results = trainer.train()
+            results = algo.train()
         self.assertEqual(results["agent_timesteps_total"], results["timesteps_total"])
         self.assertEqual(
             results["num_env_steps_trained"] * num_agents,
@@ -358,7 +358,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
             results["agent_timesteps_total"],
             (num_iterations + 1) * config["train_batch_size"],
         )
-        trainer.stop()
+        algo.stop()

     def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
         """Test whether a SampleBatch produces the correct 1-step input dict."""
@@ -81,14 +81,14 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }

-    # To run the Trainer without tune.run, using our LSTM model and
+    # To run the Algorithm without tune.run, using our LSTM model and
     # manual state-in handling, do the following:

     # Example (use `config` from the above code):
     # >> import numpy as np
     # >> from ray.rllib.algorithms.ppo import PPO
     # >>
-    # >> trainer = PPO(config)
+    # >> algo = PPO(config)
     # >> lstm_cell_size = config["model"]["lstm_cell_size"]
     # >> env = StatelessCartPole()
     # >> obs = env.reset()
@@ -101,7 +101,7 @@ if __name__ == "__main__":
     # >> prev_r = 0.0
     # >>
     # >> while True:
-    # >>     a, state_out, _ = trainer.compute_single_action(
+    # >>     a, state_out, _ = algo.compute_single_action(
     # ..         obs, state, prev_a, prev_r)
     # >>     obs, reward, done, _ = env.step(a)
     # >>     if done:
@@ -92,8 +92,8 @@ MyTFPolicy = build_tf_policy(
 )


-# Create a new Trainer using the Policy defined above.
-class MyTrainer(Algorithm):
+# Create a new Algorithm using the Policy defined above.
+class MyAlgo(Algorithm):
     def get_default_policy_class(self, config):
         return MyTFPolicy
@@ -117,7 +117,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }

-    results = tune.run(MyTrainer, stop=stop, config=config, verbose=1)
+    results = tune.run(MyAlgo, stop=stop, config=config, verbose=1)

     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -83,11 +83,11 @@ if __name__ == "__main__":
     min_reward = -300

     # Test for torch framework (tf not implemented yet).
-    trainer = cql.CQL(config=config)
+    algo = cql.CQL(config=config)
     learnt = False
     for i in range(num_iterations):
         print(f"Iter {i}")
-        eval_results = trainer.train().get("evaluation")
+        eval_results = algo.train().get("evaluation")
         if eval_results:
             print("... R={}".format(eval_results["episode_reward_mean"]))
             # Learn until some reward is reached on an actual live env.
@@ -101,7 +101,7 @@ if __name__ == "__main__":
             )

     # Get policy, model, and replay-buffer.
-    pol = trainer.get_policy()
+    pol = algo.get_policy()
     cql_model = pol.model
     from ray.rllib.algorithms.cql.cql import replay_buffer
@@ -116,7 +116,7 @@ if __name__ == "__main__":
     final_q_values = torch.min(q_values, twin_q_values)
     print(final_q_values)

-    # Example on how to do evaluation on the trained Trainer
+    # Example on how to do evaluation on the trained Algorithm.
     # using the data from our buffer.
     # Get a sample (MultiAgentBatch).
     multi_agent_batch = replay_buffer.sample(num_items=config["train_batch_size"])
@@ -128,11 +128,10 @@ if __name__ == "__main__":
     model_out, _ = cql_model({"obs": obs})
     # The estimated Q-values from the (historic) actions in the batch.
     q_values_old = cql_model.get_q_values(model_out, torch.from_numpy(batch["actions"]))
-    # The estimated Q-values for the new actions computed
-    # by our trainer policy.
+    # The estimated Q-values for the new actions computed by our policy.
     actions_new = pol.compute_actions_from_input_dict({"obs": obs})[0]
     q_values_new = cql_model.get_q_values(model_out, torch.from_numpy(actions_new))
     print(f"Q-val batch={q_values_old}")
     print(f"Q-val policy={q_values_new}")

-    trainer.stop()
+    algo.stop()
@@ -58,10 +58,10 @@ class RandomParametricPolicy(Policy, ABC):
         pass


-class RandomParametricTrainer(Algorithm):
-    """Algo with Policy and config defined above and overriding `training_iteration`.
+class RandomParametricAlgorithm(Algorithm):
+    """Algo with Policy and config defined above and overriding `training_step`.

-    Overrides the `training_iteration` method, which only runs a (dummy)
+    Overrides the `training_step` method, which only runs a (dummy)
     rollout and performs no learning.
     """
@@ -79,7 +79,7 @@ class RandomParametricTrainer(Algorithm):

 def main():
     register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
-    algo = RandomParametricTrainer(env="pa_cartpole")
+    algo = RandomParametricAlgorithm(env="pa_cartpole")
     result = algo.train()
     assert result["episode_reward_mean"] > 10, result
     print("Test: OK")
@@ -75,10 +75,10 @@ def get_cli_args():
     return args


-# The modified Trainer class we will use. This is the exact same
-# as a PPO, but with the additional default_resource_request
-# override, telling tune that it's ok (not mandatory) to place our
-# n remote envs on a different node (each env using 1 CPU).
+# The modified Algorithm class we will use:
+# Subclassing from PPO, our algo will only modity `default_resource_request`,
+# telling Ray Tune that it's ok (not mandatory) to place our n remote envs on a
+# different node (each env using 1 CPU).
 class PPORemoteInference(PPO):
     @classmethod
     @override(Algorithm)
@@ -145,7 +145,7 @@ if __name__ == "__main__":
         ):
             break

-    # Run with Tune for auto env and trainer creation and TensorBoard.
+    # Run with Tune for auto env and algorithm creation and TensorBoard.
     else:
         stop = {
             "training_iteration": args.stop_iters,
@@ -64,12 +64,12 @@ parser.add_argument(
 )


-# Define new Trainer with custom execution_plan/workflow.
-class MyTrainer(Algorithm):
+# Define new Algorithm with custom execution_plan/workflow.
+class MyAlgo(Algorithm):
     @classmethod
     @override(Algorithm)
     def get_default_config(cls) -> AlgorithmConfigDict:
-        # Run this Trainer with new `training_iteration` API and set some PPO-specific
+        # Run this Algorithm with new `training_step` API and set some PPO-specific
         # parameters.
         return with_common_config(
             {
@@ -218,7 +218,7 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }

-    results = tune.run(MyTrainer, config=config, stop=stop)
+    results = tune.run(MyAlgo, config=config, stop=stop)

     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@@ -17,7 +17,7 @@ parser.add_argument(
     type=str,
     default=None,
     help="Full path to a checkpoint file for restoring a previously saved "
-    "Trainer state.",
+    "Algorithm state.",
 )
 parser.add_argument("--num-workers", type=int, default=0)
 parser.add_argument(
@@ -27,7 +27,7 @@ def StandardMetricsReporting(
         train_op: Operator for executing training steps.
             We ignore the output values.
         workers: Rollout workers to collect metrics from.
-        config: Trainer configuration, used to determine the frequency
+        config: Algorithm configuration, used to determine the frequency
             of stats reporting.
         selected_workers: Override the list of remote workers
             to collect metrics from.
@@ -51,7 +51,7 @@ class TestOPE(unittest.TestCase):
             .framework("torch")
             .rollouts(batch_mode="complete_episodes")
         )
-        cls.trainer = config.build()
+        cls.algo = config.build()

         # Train DQN for evaluation policy
         tune.run(
@@ -80,7 +80,7 @@ class TestOPE(unittest.TestCase):
         done = False
         rewards = []
         while not done:
-            act = cls.trainer.compute_single_action(obs)
+            act = cls.algo.compute_single_action(obs)
             obs, reward, done, _ = env.step(act)
             rewards.append(reward)
         ret = 0
@@ -105,7 +105,7 @@ class TestOPE(unittest.TestCase):
         name = "is"
         estimator = ImportanceSampling(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
         )
         estimator.process(self.batch)
@@ -118,7 +118,7 @@ class TestOPE(unittest.TestCase):
         name = "wis"
         estimator = WeightedImportanceSampling(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
         )
         estimator.process(self.batch)
@@ -131,7 +131,7 @@ class TestOPE(unittest.TestCase):
         name = "dm_qreg"
         estimator = DirectMethod(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
             q_model_type="qreg",
             **self.model_config,
@@ -146,7 +146,7 @@ class TestOPE(unittest.TestCase):
         name = "dm_fqe"
         estimator = DirectMethod(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
             q_model_type="fqe",
             **self.model_config,
@@ -161,7 +161,7 @@ class TestOPE(unittest.TestCase):
         name = "dr_qreg"
         estimator = DoublyRobust(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
             q_model_type="qreg",
             **self.model_config,
@@ -176,7 +176,7 @@ class TestOPE(unittest.TestCase):
         name = "dr_fqe"
         estimator = DoublyRobust(
             name=name,
-            policy=self.trainer.get_policy(),
+            policy=self.algo.get_policy(),
             gamma=self.gamma,
             q_model_type="fqe",
             **self.model_config,
@@ -187,7 +187,7 @@ class TestOPE(unittest.TestCase):
         self.mean_ret[name] = np.mean([e.metrics["v_new"] for e in estimates])
         self.std_ret[name] = np.std([e.metrics["v_new"] for e in estimates])

-    def test_ope_in_trainer(self):
+    def test_ope_in_algo(self):
         # TODO (rohan): Add performance tests for off_policy_estimation_methods,
         # with fixed seeds and hyperparameters
         pass
@@ -294,7 +294,7 @@ def _build_eager_tf_policy(
     much simpler, but has lower performance.

     You shouldn't need to call this directly. Rather, prefer to build a TF
-    graph policy and use set {"framework": "tfe"} in the trainer config to have
+    graph policy and use set {"framework": "tfe"} in the Algorithm's config to have
     it automatically be converted to an eager policy.

     This has the same signature as build_tf_policy()."""
@@ -78,7 +78,7 @@ class EntropyCoeffSchedule:
 class KLCoeffMixin:
     """Assigns the `update_kl()` method to a TorchPolicy.

-    This is used by Trainers to update the KL coefficient
+    This is used by Algorithms to update the KL coefficient
     after each learning step based on `config.kl_target` and
     the measured KL value (from the train_batch).
     """
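For context, the `update_kl()` rule this mixin provides is the usual adaptive-KL adjustment; the exact thresholds and multipliers below are recalled assumptions, not lines from this diff:

    # Sketch of the adaptive KL-coefficient update applied after each train step.
    def update_kl_coeff(kl_coeff: float, sampled_kl: float, kl_target: float) -> float:
        if sampled_kl > 2.0 * kl_target:
            kl_coeff *= 1.5  # measured KL too large -> strengthen the penalty
        elif sampled_kl < 0.5 * kl_target:
            kl_coeff *= 0.5  # measured KL well below target -> relax the penalty
        return kl_coeff  # otherwise unchanged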
@@ -7,7 +7,7 @@ if __name__ == "__main__":
     # Do not import torch for testing purposes.
     os.environ["RLLIB_TEST_NO_TORCH_IMPORT"] = "1"

-    # Test registering (includes importing) all Trainers.
+    # Test registering (includes importing) all Algorithms.
     from ray.rllib import _register_all

     # This should surface any dependency on torch, e.g. inside function
@@ -19,7 +19,7 @@ if __name__ == "__main__":
     assert "torch" not in sys.modules, "`torch` initially present, when it shouldn't!"

     # Note: No ray.init(), to test it works without Ray
-    trainer = A2C(
+    algo = A2C(
         env="CartPole-v0",
         config={
             "framework": "tf",
@@ -31,7 +31,7 @@ if __name__ == "__main__":
             },
         },
     )
-    trainer.train()
+    algo.train()

     assert (
         "torch" not in sys.modules
@@ -57,10 +57,10 @@ class TestPlacementGroups(unittest.TestCase):
         config["env"] = "CartPole-v0"
         config["framework"] = "tf"

-        # Create a trainer with an overridden default_resource_request
+        # Create an Algorithm with an overridden default_resource_request
         # method that returns a PlacementGroupFactory.

-        class MyTrainer(PG):
+        class MyAlgo(PG):
             @classmethod
             def default_resource_request(cls, config):
                 head_bundle = {"CPU": 1, "GPU": 0}
@@ -70,7 +70,7 @@ class TestPlacementGroups(unittest.TestCase):
                     strategy=config["placement_strategy"],
                 )

-        tune.register_trainable("my_trainable", MyTrainer)
+        tune.register_trainable("my_trainable", MyAlgo)

         global trial_executor
         trial_executor = RayTrialExecutor(reuse_actors=False)
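Assembled from the fragments in the two hunks above, the overridden classmethod looks roughly like this (the `PlacementGroupFactory` import path is an assumption and has moved between Ray releases; the child bundle is a hypothetical stand-in for the part of the test not shown in this hunk):

    # Consolidated sketch, not verbatim test source.
    from ray.rllib.algorithms.pg import PG
    from ray.tune.utils.placement_groups import PlacementGroupFactory  # assumed path


    class MyAlgo(PG):
        @classmethod
        def default_resource_request(cls, config):
            head_bundle = {"CPU": 1, "GPU": 0}
            child_bundle = {"CPU": 1}  # assumed per-worker bundle
            return PlacementGroupFactory(
                [head_bundle] + [child_bundle] * config["num_workers"],
                strategy=config["placement_strategy"],
            )

    # Registered with Tune exactly as in the test above:
    # tune.register_trainable("my_trainable", MyAlgo)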
@@ -27,11 +27,11 @@ class TestTimeSteps(unittest.TestCase):
         obs_batch = np.array([1])

         for _ in framework_iterator(config):
-            trainer = pg.PG(config=config, env=RandomEnv)
-            policy = trainer.get_policy()
+            algo = pg.PG(config=config, env=RandomEnv)
+            policy = algo.get_policy()

             for i in range(1, 21):
-                trainer.compute_single_action(obs)
+                algo.compute_single_action(obs)
                 check(policy.global_timestep, i)
             for i in range(1, 21):
                 policy.compute_actions(obs_batch)
@@ -45,7 +45,8 @@ class TestTimeSteps(unittest.TestCase):
             for i in range(1, 11):
                 policy.compute_actions(obs_batch)
                 check(policy.global_timestep, i + crazy_timesteps)
-            trainer.train()
+            algo.train()
+            algo.stop()


 if __name__ == "__main__":
@@ -36,18 +36,18 @@ def PublicAPI(obj):
     can expect these APIs to remain stable across RLlib releases.

     Subclasses that inherit from a ``@PublicAPI`` base class can be
-    assumed part of the RLlib public API as well (e.g., all trainer classes
-    are in public API because Trainer is ``@PublicAPI``).
+    assumed part of the RLlib public API as well (e.g., all Algorithm classes
+    are in public API because Algorithm is ``@PublicAPI``).

-    In addition, you can assume all trainer configurations are part of their
+    In addition, you can assume all algo configurations are part of their
     public API as well.

     Examples:
-        >>> # Indicates that the `Trainer` class is exposed to end users
+        >>> # Indicates that the `Algorithm` class is exposed to end users
         >>> # of RLlib and will remain stable across RLlib releases.
         >>> from ray import tune
         >>> @PublicAPI # doctest: +SKIP
-        >>> class Trainer(tune.Trainable): # doctest: +SKIP
+        >>> class Algorithm(tune.Trainable): # doctest: +SKIP
         ...     ... # doctest: +SKIP
     """
@@ -110,7 +110,7 @@ def ExperimentalAPI(obj):
 def OverrideToImplementCustomLogic(obj):
     """Users should override this in their sub-classes to implement custom logic.

-    Used in Trainer and Policy to tag methods that need overriding, e.g.
+    Used in Algorithm and Policy to tag methods that need overriding, e.g.
     `Policy.loss()`.

     Examples:
@@ -132,9 +132,9 @@ def OverrideToImplementCustomLogic_CallToSuperRecommended(obj):
     Thereby, it is recommended (but not required) to call the super-class'
     corresponding method.

-    Used in Trainer and Policy to tag methods that need overriding, but the
+    Used in Algorithm and Policy to tag methods that need overriding, but the
     super class' method should still be called, e.g.
-    `Trainer.setup()`.
+    `Algorithm.setup()`.

     Examples:
         >>> from ray import tune
@@ -36,7 +36,7 @@ Suspect = DeveloperAPI(

 @DeveloperAPI
 def check_memory_leaks(
-    trainer,
+    algorithm,
     to_check: Optional[Set[str]] = None,
     repeats: Optional[int] = None,
     max_num_trials: int = 3,
@@ -49,7 +49,7 @@ def check_memory_leaks(
     un-GC'd items to memory.

     Args:
-        trainer: The Algorithm instance to test.
+        algorithm: The Algorithm instance to test.
         to_check: Set of strings to indentify components to test. Allowed strings
             are: "env", "policy", "model", "rollout_worker". By default, check all
             of these.
@@ -62,7 +62,7 @@ def check_memory_leaks(
         A defaultdict(list) with keys being the `to_check` strings and values being
         lists of Suspect instances that were found.
     """
-    local_worker = trainer.workers.local_worker()
+    local_worker = algorithm.workers.local_worker()

     # Which components should we test?
     to_check = to_check or {"env", "model", "policy", "rollout_worker"}
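A minimal usage sketch matching the renamed `algorithm` argument, mirroring the calls in test_memory_leaks.py earlier in this diff (both import paths below are assumptions):

    # Usage sketch for the renamed first argument.
    from ray.rllib.algorithms.ppo import PPOConfig  # assumed import path
    from ray.rllib.utils.debug.memory import check_memory_leaks  # assumed module path

    algo = PPOConfig().environment("CartPole-v0").build()
    results = check_memory_leaks(algo, to_check={"env"}, repeats=150)
    assert not results["env"], "Unexpected leaks found in the environment!"
    algo.stop()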
@@ -12,7 +12,7 @@ NUM_AGENT_STEPS_TRAINED_THIS_ITER = "num_agent_steps_trained_this_iter"
 LAST_TARGET_UPDATE_TS = "last_target_update_ts"
 NUM_TARGET_UPDATES = "num_target_updates"

-# Performance timers (keys for Trainer._timers or metrics.timers).
+# Performance timers (keys for Algorithm._timers or metrics.timers).
 TRAINING_ITERATION_TIMER = "training_iteration"
 APPLY_GRADS_TIMER = "apply_grad"
 COMPUTE_GRADS_TIMER = "compute_grads"