Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00

[RLlib] Trainer.training_iteration -> Trainer.training_step; Iterations vs reportings: Clarification of terms. (#25076)

parent: 94d6c212df
commit: 7c39aa5fac
83 changed files with 383 additions and 311 deletions
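For orientation before the diff: this commit renames the per-`train()`-call override hook from `Trainer.training_iteration()` to `Trainer.training_step()` and the `*_per_reporting` config keys to `*_per_iteration`. The sketch below is illustrative only and is not part of the diff; the import paths and concrete values are assumptions based on the RLlib layout around the time of this change.

# Illustrative sketch, not code from this commit. Import paths and values are
# assumptions based on the RLlib layout around the time of this change.
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.utils.typing import ResultDict


class MyAlgo(Trainer):
    # The per-iteration override hook is now `training_step()` (formerly
    # `training_iteration()`). `Trainer.step()` may call it several times per
    # reported iteration, until the `min_*_per_iteration` thresholds are met.
    def training_step(self) -> ResultDict:
        return {}  # sampling/training stats would go here


config = (
    TrainerConfig()
    .reporting(
        min_time_s_per_iteration=1,              # was: min_time_s_per_reporting
        min_sample_timesteps_per_iteration=100,  # was: min_sample_timesteps_per_reporting
        min_train_timesteps_per_iteration=0,     # was: min_train_timesteps_per_reporting
    )
)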
@@ -489,7 +489,7 @@ The following is a list of the common algorithm hyper-parameters:
# If - after one `step_attempt()`, the time limit has not been reached,
# will perform n more `step_attempt()` calls until this minimum time has been
# consumed. Set to 0 for no minimum time.
"min_time_s_per_reporting": 0,
"min_time_s_per_iteration": 0,
# Minimum train/sample timesteps to accumulate within a single `train()` call.
# This value does not affect learning, only the number of times
# `self.step_attempt()` is called by `self.train()`.

@@ -497,8 +497,8 @@ The following is a list of the common algorithm hyper-parameters:
# training) have not been reached, will perform n more `step_attempt()`
# calls until the minimum timesteps have been executed.
# Set to 0 for no minimum timesteps.
"min_train_timesteps_per_reporting": 0,
"min_sample_timesteps_per_reporting": 0,
"min_train_timesteps_per_iteration": 0,
"min_sample_timesteps_per_iteration": 0,
# This argument, in conjunction with worker_index, sets the random seed of
# each worker, so that identically configured trials will have identical

@@ -655,7 +655,7 @@ The following is a list of the common algorithm hyper-parameters:
# === API deprecations/simplifications/changes ===
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration()` method will be called on each
# a Trainer's `training_step()` method will be called on each
# training iteration.
"_disable_execution_plan_api": True,
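The semantics of the minimums documented above can be summarized with a small conceptual sketch (not repository code): one reported `train()` result keeps invoking sub-steps until every configured minimum is satisfied. The result-key names used here are placeholders, not RLlib's actual metric keys.

# Conceptual sketch only, not repository code: how the `min_*_per_iteration`
# settings gate a single reported `train()` result.
import time


def run_one_iteration(trainer, config) -> dict:
    result: dict = {}
    start = time.time()
    sampled = trained = 0
    while True:
        out = trainer.training_step()  # may run several times per iteration
        result.update(out)
        sampled += out.get("env_steps_sampled", 0)   # placeholder key name
        trained += out.get("env_steps_trained", 0)   # placeholder key name
        if (
            time.time() - start >= (config.get("min_time_s_per_iteration") or 0)
            and sampled >= (config.get("min_sample_timesteps_per_iteration") or 0)
            and trained >= (config.get("min_train_timesteps_per_iteration") or 0)
        ):
            return result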
@@ -50,7 +50,7 @@ run_experiments(
"rollout_fragment_length": 1,
"train_batch_size": 1,
"min_iter_time_s": 10,
"min_sample_timesteps_per_reporting": 10,
"min_sample_timesteps_per_iteration": 10,
},
}
},
@@ -28,4 +28,4 @@ apex-breakoutnoframeskip-v4:
rollout_fragment_length: 20
train_batch_size: 512
target_network_update_freq: 50000
min_sample_timesteps_per_reporting: 25000
min_sample_timesteps_per_iteration: 25000
@@ -29,7 +29,7 @@ cql-halfcheetahbulletenv-v0:
learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
min_train_timesteps_per_reporting: 1000
min_train_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003
@@ -21,7 +21,7 @@ ddpg-hopperbulletenv-v0:
ou_base_scale: 0.1
ou_theta: 0.15
ou_sigma: 0.2
min_sample_timesteps_per_reporting: 1000
min_sample_timesteps_per_iteration: 1000
target_network_update_freq: 0
tau: 0.001
replay_buffer_config:
@@ -28,4 +28,4 @@ dqn-breakoutnoframeskip-v4:
epsilon_timesteps: 200000
final_epsilon: 0.01
num_gpus: 0.5
min_sample_timesteps_per_reporting: 10000
min_sample_timesteps_per_iteration: 10000
@@ -23,7 +23,7 @@ sac-halfcheetahbulletenv-v0:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
min_sample_timesteps_per_reporting: 1000
min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
learning_starts: 10000
type: MultiAgentPrioritizedReplayBuffer
@@ -1122,7 +1122,7 @@ py_test(
"--env", "Pendulum-v1",
"--run", "APEX_DDPG",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"replay_buffer_config\": {\"learning_starts\": 100}, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"replay_buffer_config\": {\"learning_starts\": 100}, \"min_time_s_per_iteration\": 1, \"batch_mode\": \"complete_episodes\"}'",
"--ray-num-cpus", "4",
]
)

@@ -1205,7 +1205,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "IMPALA",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_reporting\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_iteration\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
"--ray-num-cpus", "4",
]
)

@@ -1219,7 +1219,7 @@ py_test(
"--env", "CartPole-v0",
"--run", "IMPALA",
"--stop", "'{\"training_iteration\": 1}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_reporting\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_iteration\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
"--ray-num-cpus", "4",
]
)
@@ -152,9 +152,9 @@ class IgnoresWorkerFailure(unittest.TestCase):
self.do_test(
"APEX",
{
"min_sample_timesteps_per_reporting": 1000,
"min_sample_timesteps_per_iteration": 1000,
"num_gpus": 0,
"min_time_s_per_reporting": 1,
"min_time_s_per_iteration": 1,
"explore": False,
"learning_starts": 1000,
"target_network_update_freq": 100,

@@ -168,7 +168,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
self.do_test("IMPALA", {"num_gpus": 0})

def test_sync_replay(self):
self.do_test("DQN", {"min_sample_timesteps_per_reporting": 1})
self.do_test("DQN", {"min_sample_timesteps_per_iteration": 1})

def test_multi_g_p_u(self):
self.do_test(
@@ -72,11 +72,13 @@ from ray.rllib.utils.error import EnvError, ERR_MSG_INVALID_ENV_DESCRIPTOR
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.metrics import (
TRAINING_ITERATION_TIMER,
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_TRAINED,
NUM_AGENT_STEPS_SAMPLED_THIS_ITER,
NUM_AGENT_STEPS_TRAINED,
NUM_ENV_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED_THIS_ITER,
NUM_ENV_STEPS_TRAINED,
TRAINING_ITERATION_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent
@@ -212,7 +214,6 @@ class Trainer(Trainable):
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
"""

# User provided (partial) config (this may be w/o the default
# Trainer's Config object). Will get merged with TrainerConfig()
# in self.setup().
@@ -292,6 +293,18 @@ class Trainer(Trainable):
config, logger_creator, remote_checkpoint_dir, sync_function_tpl
)

# Check, whether `training_iteration` is still a tune.Trainable property
# and has not been overridden by the user in the attempt to implement the
# algos logic (this should be done now inside `training_step`).
try:
    assert isinstance(self.training_iteration, int)
except AssertionError:
    raise AssertionError(
        "Your Trainer's `training_iteration` seems to be overridden by your "
        "custom training logic! To solve this problem, simply rename your "
        "`self.training_iteration()` method into `self.training_step`."
    )

@OverrideToImplementCustomLogic
@classmethod
def get_default_config(cls) -> TrainerConfigDict:
@@ -535,7 +548,7 @@ class Trainer(Trainable):

@override(Trainable)
def step(self) -> ResultDict:
"""Implements the main `Trainer.train()` logic.
"""Implements the main `Trainer.train()` logic, defining one "iteration".

Takes n attempts to perform a single training step. Thereby
catches RayErrors resulting from worker failures. After n attempts,
@@ -549,18 +562,76 @@ class Trainer(Trainable):
The results dict with stats/infos on sampling, training,
and - if required - evaluation.
"""
step_attempt_results = None
# Do we have to run `self.evaluate()` this iteration?
# `self.iteration` gets incremented after this function returns,
# meaning that e. g. the first time this function is called,
# self.iteration will be 0.
evaluate_this_iter = (
self.config["evaluation_interval"]
and (self.iteration + 1) % self.config["evaluation_interval"] == 0
)

# Results dict for training (and if applicable: evaluation).
result: ResultDict = {}

first_step_attempt = True

self._rollout_worker_metrics = []
local_worker = (
self.workers.local_worker()
if hasattr(self.workers, "local_worker")
else None
)

# Create a step context ...
with self._step_context() as step_ctx:
while not step_ctx.should_stop(step_attempt_results):
# so we can query it whether we should stop the iteration loop (e.g. when
# we have reached `min_time_s_per_iteration`).
while not step_ctx.should_stop(result):
# Try to train one step.
try:
step_attempt_results = self.step_attempt()
# No evaluation necessary, just run the next training iteration.
if not evaluate_this_iter:
result = self._exec_plan_or_training_step_fn()
# We have to evaluate in this training iteration.
else:
# No parallelism.
if not self.config["evaluation_parallel_to_training"]:
result = self._exec_plan_or_training_step_fn()
# Kick off evaluation-loop (and parallel train() call,
# if requested).
# Parallel eval + training.
else:
with concurrent.futures.ThreadPoolExecutor() as executor:
train_future = executor.submit(
lambda: self._exec_plan_or_training_step_fn()
)
# Automatically determine duration of the evaluation
# (as long as training takes).
if self.config["evaluation_duration"] == "auto":
unit = self.config["evaluation_duration_unit"]
result.update(
self.evaluate(
duration_fn=functools.partial(
self._auto_duration_fn,
unit,
self.config["evaluation_num_workers"],
self.config["evaluation_config"],
train_future,
)
)
)
# Run `self.evaluate()` only once per iteration.
elif first_step_attempt:
first_step_attempt = False
result.update(self.evaluate())
# Collect the training results from the future.
result.update(train_future.result())

# Sequential: train (already done above), then eval.
if not self.config["evaluation_parallel_to_training"]:
result.update(self.evaluate())

# Collect rollout worker metrics.
episodes, self._episodes_to_be_collected = collect_episodes(
local_worker,
@ -571,6 +642,7 @@ class Trainer(Trainable):
|
|||
],
|
||||
)
|
||||
self._rollout_worker_metrics.extend(episodes)
|
||||
|
||||
# @ray.remote RolloutWorker failure.
|
||||
except RayError as e:
|
||||
# Try to recover w/o the failed worker.
|
||||
|
@ -596,7 +668,13 @@ class Trainer(Trainable):
|
|||
time.sleep(0.5)
|
||||
raise e
|
||||
|
||||
result = step_attempt_results
|
||||
# Attach latest available evaluation results to train results,
|
||||
# if necessary.
|
||||
if not evaluate_this_iter and self.config["always_attach_evaluation_results"]:
|
||||
assert isinstance(
|
||||
self.evaluation_metrics, dict
|
||||
), "Trainer.evaluate() needs to return a dict."
|
||||
result.update(self.evaluation_metrics)
|
||||
|
||||
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
|
||||
# Sync filters on workers.
|
||||
|
@ -604,9 +682,9 @@ class Trainer(Trainable):
|
|||
|
||||
# Collect worker metrics.
|
||||
if self.config["_disable_execution_plan_api"]:
|
||||
result = self._compile_step_results(
|
||||
result = self._compile_iteration_results(
|
||||
step_ctx=step_ctx,
|
||||
step_attempt_results=step_attempt_results,
|
||||
iteration_results=result,
|
||||
)
|
||||
|
||||
# Check `env_task_fn` for possible update of the env's task.
|
||||
|
@ -628,96 +706,6 @@ class Trainer(Trainable):
|
|||
|
||||
return result
|
||||
|
||||
def step_attempt(self) -> ResultDict:
|
||||
"""Attempts a single training step, including evaluation, if required.
|
||||
|
||||
Override this method in your Trainer sub-classes if you would like to
|
||||
keep the n step-attempts logic (catch worker failures) in place or
|
||||
override `step()` directly if you would like to handle worker
|
||||
failures yourself.
|
||||
|
||||
Returns:
|
||||
The results dict with stats/infos on sampling, training,
|
||||
and - if required - evaluation.
|
||||
"""
|
||||
|
||||
def auto_duration_fn(unit, num_eval_workers, eval_cfg, num_units_done):
|
||||
# Training is done and we already ran at least one
|
||||
# evaluation -> Nothing left to run.
|
||||
if num_units_done > 0 and train_future.done():
|
||||
return 0
|
||||
# Count by episodes. -> Run n more
|
||||
# (n=num eval workers).
|
||||
elif unit == "episodes":
|
||||
return num_eval_workers
|
||||
# Count by timesteps. -> Run n*m*p more
|
||||
# (n=num eval workers; m=rollout fragment length;
|
||||
# p=num-envs-per-worker).
|
||||
else:
|
||||
return (
|
||||
num_eval_workers
|
||||
* eval_cfg["rollout_fragment_length"]
|
||||
* eval_cfg["num_envs_per_worker"]
|
||||
)
|
||||
|
||||
# self._iteration gets incremented after this function returns,
|
||||
# meaning that e. g. the first time this function is called,
|
||||
# self._iteration will be 0.
|
||||
evaluate_this_iter = (
|
||||
self.config["evaluation_interval"]
|
||||
and (self._iteration + 1) % self.config["evaluation_interval"] == 0
|
||||
)
|
||||
|
||||
step_results = {}
|
||||
|
||||
# No evaluation necessary, just run the next training iteration.
|
||||
if not evaluate_this_iter:
|
||||
step_results = self._exec_plan_or_training_iteration_fn()
|
||||
# We have to evaluate in this training iteration.
|
||||
else:
|
||||
# No parallelism.
|
||||
if not self.config["evaluation_parallel_to_training"]:
|
||||
step_results = self._exec_plan_or_training_iteration_fn()
|
||||
|
||||
# Kick off evaluation-loop (and parallel train() call,
|
||||
# if requested).
|
||||
# Parallel eval + training.
|
||||
if self.config["evaluation_parallel_to_training"]:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
train_future = executor.submit(
|
||||
lambda: self._exec_plan_or_training_iteration_fn()
|
||||
)
|
||||
# Automatically determine duration of the evaluation.
|
||||
if self.config["evaluation_duration"] == "auto":
|
||||
unit = self.config["evaluation_duration_unit"]
|
||||
step_results.update(
|
||||
self.evaluate(
|
||||
duration_fn=functools.partial(
|
||||
auto_duration_fn,
|
||||
unit,
|
||||
self.config["evaluation_num_workers"],
|
||||
self.config["evaluation_config"],
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
step_results.update(self.evaluate())
|
||||
# Collect the training results from the future.
|
||||
step_results.update(train_future.result())
|
||||
# Sequential: train (already done above), then eval.
|
||||
else:
|
||||
step_results.update(self.evaluate())
|
||||
|
||||
# Attach latest available evaluation results to train results,
|
||||
# if necessary.
|
||||
if not evaluate_this_iter and self.config["always_attach_evaluation_results"]:
|
||||
assert isinstance(
|
||||
self.evaluation_metrics, dict
|
||||
), "Trainer.evaluate() needs to return a dict."
|
||||
step_results.update(self.evaluation_metrics)
|
||||
|
||||
return step_results
|
||||
|
||||
@PublicAPI
|
||||
def evaluate(
|
||||
self,
|
||||
|
@ -803,7 +791,8 @@ class Trainer(Trainable):
|
|||
else (self.config["evaluation_num_workers"] or 1)
|
||||
* (1 if unit == "episodes" else rollout)
|
||||
)
|
||||
num_ts_run = 0
|
||||
agent_steps_this_iter = 0
|
||||
env_steps_this_iter = 0
|
||||
|
||||
# Default done-function returns True, whenever num episodes
|
||||
# have been completed.
|
||||
|
@ -825,7 +814,9 @@ class Trainer(Trainable):
|
|||
# `rollout_fragment_length` is exactly the desired ts.
|
||||
iters = duration if unit == "episodes" else 1
|
||||
for _ in range(iters):
|
||||
num_ts_run += len(self.workers.local_worker().sample())
|
||||
batch = self.workers.local_worker().sample()
|
||||
agent_steps_this_iter += batch.agent_steps()
|
||||
env_steps_this_iter += batch.env_steps()
|
||||
metrics = collect_metrics(
|
||||
self.workers.local_worker(),
|
||||
keep_custom_metrics=self.config["keep_per_episode_custom_metrics"],
|
||||
|
@ -839,7 +830,9 @@ class Trainer(Trainable):
|
|||
# `rollout_fragment_length` is exactly the desired ts.
|
||||
iters = duration if unit == "episodes" else 1
|
||||
for _ in range(iters):
|
||||
num_ts_run += len(self.evaluation_workers.local_worker().sample())
|
||||
batch = self.evaluation_workers.local_worker().sample()
|
||||
agent_steps_this_iter += batch.agent_steps()
|
||||
env_steps_this_iter += batch.env_steps()
|
||||
|
||||
# Evaluation worker set has n remote workers.
|
||||
else:
|
||||
|
@ -862,14 +855,18 @@ class Trainer(Trainable):
|
|||
< units_left_to_do
|
||||
]
|
||||
)
|
||||
agent_steps_this_iter = sum(b.agent_steps() for b in batches)
|
||||
env_steps_this_iter = sum(b.env_steps() for b in batches)
|
||||
# 1 episode per returned batch.
|
||||
if unit == "episodes":
|
||||
num_units_done += len(batches)
|
||||
# n timesteps per returned batch.
|
||||
else:
|
||||
ts = sum(len(b) for b in batches)
|
||||
num_ts_run += ts
|
||||
num_units_done += ts
|
||||
num_units_done += (
|
||||
agent_steps_this_iter
|
||||
if self._by_agent_steps
|
||||
else env_steps_this_iter
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Ran round {round_} of parallel evaluation "
|
||||
|
@@ -882,7 +879,10 @@ class Trainer(Trainable):
self.evaluation_workers.remote_workers(),
keep_custom_metrics=self.config["keep_per_episode_custom_metrics"],
)
metrics["timesteps_this_iter"] = num_ts_run
metrics[NUM_AGENT_STEPS_SAMPLED_THIS_ITER] = agent_steps_this_iter
metrics[NUM_ENV_STEPS_SAMPLED_THIS_ITER] = env_steps_this_iter
# TODO: Remove this key at some point. Here for backward compatibility.
metrics["timesteps_this_iter"] = env_steps_this_iter

# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
@@ -894,7 +894,7 @@ class Trainer(Trainable):

@OverrideToImplementCustomLogic
@DeveloperAPI
def training_iteration(self) -> ResultDict:
def training_step(self) -> ResultDict:
"""Default single iteration logic of an algorithm.

- Collect on-policy samples (SampleBatches) in parallel using the
@@ -946,7 +946,7 @@ class Trainer(Trainable):
raise NotImplementedError(
"It is not longer recommended to use Trainer's `execution_plan` method/API."
" Set `_disable_execution_plan_api=True` in your config and override the "
"`Trainer.training_iteration()` method with your algo's custom "
"`Trainer.training_step()` method with your algo's custom "
"execution logic."
)
@@ -1692,10 +1692,10 @@ class Trainer(Trainable):
weights = ray.put(self.workers.local_worker().save())
worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))

def _exec_plan_or_training_iteration_fn(self):
def _exec_plan_or_training_step_fn(self):
with self._timers[TRAINING_ITERATION_TIMER]:
if self.config["_disable_execution_plan_api"]:
results = self.training_iteration()
results = self.training_step()
else:
results = next(self.train_exec_impl)
return results
@@ -1994,10 +1994,44 @@ class Trainer(Trainable):
if config.get("min_iter_time_s", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="min_iter_time_s",
new="min_time_s_per_reporting",
new="min_time_s_per_iteration",
error=False,
)
config["min_time_s_per_reporting"] = config["min_iter_time_s"] or 0
config["min_time_s_per_iteration"] = config["min_iter_time_s"] or 0

if config.get("min_time_s_per_reporting", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="min_time_s_per_reporting",
new="min_time_s_per_iteration",
error=False,
)
config["min_time_s_per_iteration"] = config["min_time_s_per_reporting"] or 0

if (
config.get("min_sample_timesteps_per_reporting", DEPRECATED_VALUE)
!= DEPRECATED_VALUE
):
deprecation_warning(
old="min_sample_timesteps_per_reporting",
new="min_sample_timesteps_per_iteration",
error=False,
)
config["min_sample_timesteps_per_iteration"] = (
config["min_sample_timesteps_per_reporting"] or 0
)

if (
config.get("min_train_timesteps_per_reporting", DEPRECATED_VALUE)
!= DEPRECATED_VALUE
):
deprecation_warning(
old="min_train_timesteps_per_reporting",
new="min_train_timesteps_per_iteration",
error=False,
)
config["min_train_timesteps_per_iteration"] = (
config["min_train_timesteps_per_reporting"] or 0
)

if config.get("collect_metrics_timeout", DEPRECATED_VALUE) != DEPRECATED_VALUE:
# TODO: Warn once all algos use the `training_iteration` method.
@@ -2013,11 +2047,11 @@ class Trainer(Trainable):
if config.get("timesteps_per_iteration", DEPRECATED_VALUE) != DEPRECATED_VALUE:
deprecation_warning(
old="timesteps_per_iteration",
new="`min_sample_timesteps_per_reporting` OR "
"`min_train_timesteps_per_reporting`",
new="`min_sample_timesteps_per_iteration` OR "
"`min_train_timesteps_per_iteration`",
error=False,
)
config["min_sample_timesteps_per_reporting"] = (
config["min_sample_timesteps_per_iteration"] = (
config["timesteps_per_iteration"] or 0
)
config["timesteps_per_iteration"] = DEPRECATED_VALUE
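The two hunks above keep backward compatibility for the renamed keys. The sketch below is illustrative only (not part of the diff) and shows the effective translation that config validation performs; the key names are taken from the hunks above, the numbers are arbitrary.

# Illustrative only: old-style keys are still accepted, but validation emits a
# deprecation warning and rewrites them to the new names shown above.
old_style = {
    "min_iter_time_s": 10,                       # oldest name
    "min_sample_timesteps_per_reporting": 1000,  # pre-rename name
}

# Effective config after Trainer config validation:
new_style = {
    "min_time_s_per_iteration": 10,
    "min_sample_timesteps_per_iteration": 1000,
}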
@ -2098,15 +2132,22 @@ class Trainer(Trainable):
|
|||
pass
|
||||
|
||||
def try_recover_from_step_attempt(self) -> None:
|
||||
"""Try to identify and remove any unhealthy workers.
|
||||
"""Try to identify and remove any unhealthy workers (incl. eval workers).
|
||||
|
||||
This method is called after an unexpected remote error is encountered
|
||||
from a worker during the call to `self.step_attempt()` (within
|
||||
`self.step()`). It issues check requests to all current workers and
|
||||
removes any that respond with error. If no healthy workers remain,
|
||||
an error is raised. Otherwise, tries to re-build the execution plan
|
||||
with the remaining (healthy) workers.
|
||||
from a worker during the call to `self.step()`. It issues check requests to
|
||||
all current workers and removes any that respond with error. If no healthy
|
||||
workers remain, an error is raised.
|
||||
"""
|
||||
# Try to get our "eval" WorkerSet (used for evaluating policies).
|
||||
eval_workers = getattr(self, "evaluation_workers", None)
|
||||
if isinstance(eval_workers, WorkerSet):
|
||||
# Search for failed workers and try to recover (restart) them.
|
||||
if self.config["evaluation_config"].get("recreate_failed_workers") is True:
|
||||
eval_workers.recreate_failed_workers()
|
||||
elif self.config["evaluation_config"].get("ignore_worker_failures") is True:
|
||||
eval_workers.remove_failed_workers()
|
||||
|
||||
# Try to get our "main" WorkerSet (used for training sample collection).
|
||||
workers = getattr(self, "workers", None)
|
||||
if not isinstance(workers, WorkerSet):
|
||||
|
@ -2264,9 +2305,32 @@ class Trainer(Trainable):
|
|||
kwargs["local_replay_buffer"] = self.local_replay_buffer
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
def _auto_duration_fn(
|
||||
unit, num_eval_workers, eval_cfg, train_future, num_units_done
|
||||
):
|
||||
# Training is done and we already ran at least one
|
||||
# evaluation -> Nothing left to run.
|
||||
if num_units_done > 0 and train_future.done():
|
||||
return 0
|
||||
# Count by episodes. -> Run n more
|
||||
# (n=num eval workers).
|
||||
elif unit == "episodes":
|
||||
return num_eval_workers
|
||||
# Count by timesteps. -> Run n*m*p more
|
||||
# (n=num eval workers; m=rollout fragment length;
|
||||
# p=num-envs-per-worker).
|
||||
else:
|
||||
return (
|
||||
num_eval_workers
|
||||
* eval_cfg["rollout_fragment_length"]
|
||||
* eval_cfg["num_envs_per_worker"]
|
||||
)
|
||||
|
||||
def _step_context(trainer):
|
||||
class StepCtx:
|
||||
def __enter__(self):
|
||||
self.started = False
|
||||
# Before first call to `step()`, `result` is expected to be None ->
|
||||
# Start with self.failures=-1 -> set to 0 before the very first call
|
||||
# to `self.step()`.
|
||||
|
@ -2293,9 +2357,9 @@ class Trainer(Trainable):
|
|||
|
||||
def should_stop(self, result):
|
||||
|
||||
# Before first call to `step()`, `result` is expected to be None ->
|
||||
# self.failures=0.
|
||||
if result is None:
|
||||
# Before first call to `step()`.
|
||||
if self.started is False:
|
||||
self.started = True
|
||||
# Fail after n retries.
|
||||
self.failures += 1
|
||||
if self.failures > self.failure_tolerance:
|
||||
|
@ -2332,9 +2396,9 @@ class Trainer(Trainable):
|
|||
- self.init_env_steps_trained
|
||||
)
|
||||
|
||||
min_t = trainer.config["min_time_s_per_reporting"]
|
||||
min_sample_ts = trainer.config["min_sample_timesteps_per_reporting"]
|
||||
min_train_ts = trainer.config["min_train_timesteps_per_reporting"]
|
||||
min_t = trainer.config["min_time_s_per_iteration"]
|
||||
min_sample_ts = trainer.config["min_sample_timesteps_per_iteration"]
|
||||
min_train_ts = trainer.config["min_train_timesteps_per_iteration"]
|
||||
# Repeat if not enough time has passed or if not enough
|
||||
# env|train timesteps have been processed (or these min
|
||||
# values are not provided by the user).
|
||||
|
@ -2354,21 +2418,21 @@ class Trainer(Trainable):
|
|||
|
||||
return StepCtx()
|
||||
|
||||
def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
|
||||
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
|
||||
# Return dict.
|
||||
results: ResultDict = {}
|
||||
step_attempt_results = step_attempt_results or {}
|
||||
iteration_results = iteration_results or {}
|
||||
|
||||
# Evaluation results.
|
||||
if "evaluation" in step_attempt_results:
|
||||
results["evaluation"] = step_attempt_results.pop("evaluation")
|
||||
if "evaluation" in iteration_results:
|
||||
results["evaluation"] = iteration_results.pop("evaluation")
|
||||
|
||||
# Custom metrics and episode media.
|
||||
results["custom_metrics"] = step_attempt_results.pop("custom_metrics", {})
|
||||
results["episode_media"] = step_attempt_results.pop("episode_media", {})
|
||||
results["custom_metrics"] = iteration_results.pop("custom_metrics", {})
|
||||
results["episode_media"] = iteration_results.pop("episode_media", {})
|
||||
|
||||
# Learner info.
|
||||
results["info"] = {LEARNER_INFO: step_attempt_results}
|
||||
results["info"] = {LEARNER_INFO: iteration_results}
|
||||
|
||||
episodes = self._rollout_worker_metrics
|
||||
orig_episodes = list(episodes)
|
||||
|
@ -2434,6 +2498,10 @@ class Trainer(Trainable):
|
|||
def compute_action(self, *args, **kwargs):
|
||||
return self.compute_single_action(*args, **kwargs)
|
||||
|
||||
@Deprecated(new="logic moved into `self.step()`", error=True)
|
||||
def step_attempt(self):
|
||||
pass
|
||||
|
||||
@Deprecated(new="construct WorkerSet(...) instance directly", error=False)
|
||||
def _make_workers(
|
||||
self,
|
||||
|
|
|
@ -194,9 +194,9 @@ class TrainerConfig:
|
|||
self.keep_per_episode_custom_metrics = False
|
||||
self.metrics_episode_collection_timeout_s = 180
|
||||
self.metrics_num_episodes_for_smoothing = 100
|
||||
self.min_time_s_per_reporting = None
|
||||
self.min_train_timesteps_per_reporting = 0
|
||||
self.min_sample_timesteps_per_reporting = 0
|
||||
self.min_time_s_per_iteration = None
|
||||
self.min_train_timesteps_per_iteration = 0
|
||||
self.min_sample_timesteps_per_iteration = 0
|
||||
|
||||
# `self.debugging()`
|
||||
self.logger_creator = None
|
||||
|
@ -232,6 +232,9 @@ class TrainerConfig:
|
|||
self.prioritized_replay_alpha = DEPRECATED_VALUE
|
||||
self.prioritized_replay_beta = DEPRECATED_VALUE
|
||||
self.prioritized_replay_eps = DEPRECATED_VALUE
|
||||
self.min_time_s_per_reporting = DEPRECATED_VALUE
|
||||
self.min_train_timesteps_per_reporting = DEPRECATED_VALUE
|
||||
self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE
|
||||
self.input_evaluation = DEPRECATED_VALUE
|
||||
|
||||
def to_dict(self) -> TrainerConfigDict:
|
||||
|
@ -1095,9 +1098,9 @@ class TrainerConfig:
|
|||
keep_per_episode_custom_metrics: Optional[bool] = None,
|
||||
metrics_episode_collection_timeout_s: Optional[int] = None,
|
||||
metrics_num_episodes_for_smoothing: Optional[int] = None,
|
||||
min_time_s_per_reporting: Optional[int] = None,
|
||||
min_train_timesteps_per_reporting: Optional[int] = None,
|
||||
min_sample_timesteps_per_reporting: Optional[int] = None,
|
||||
min_time_s_per_iteration: Optional[int] = None,
|
||||
min_train_timesteps_per_iteration: Optional[int] = None,
|
||||
min_sample_timesteps_per_iteration: Optional[int] = None,
|
||||
) -> "TrainerConfig":
|
||||
"""Sets the config's reporting settings.
|
||||
|
||||
|
@ -1108,24 +1111,27 @@ class TrainerConfig:
|
|||
this many seconds. Those that have not returned in time will be
|
||||
collected in the next train iteration.
|
||||
metrics_num_episodes_for_smoothing: Smooth metrics over this many episodes.
|
||||
min_time_s_per_reporting: Minimum time interval to run one `train()` call
|
||||
for: If - after one `step_attempt()`, this time limit has not been
|
||||
reached, will perform n more `step_attempt()` calls until this minimum
|
||||
time has been consumed. Set to None or 0 for no minimum time.
|
||||
min_train_timesteps_per_reporting: Minimum training timesteps to accumulate
|
||||
min_time_s_per_iteration: Minimum time to accumulate within a single
|
||||
`train()` call. This value does not affect learning,
|
||||
only the number of times `Trainer.training_step()` is called by
|
||||
`Trainer.train()`. If - after one such step attempt, the time taken
|
||||
has not reached `min_time_s_per_iteration`, will perform n more
|
||||
`training_step()` calls until the minimum time has been
|
||||
consumed. Set to 0 or None for no minimum time.
|
||||
min_train_timesteps_per_iteration: Minimum training timesteps to accumulate
|
||||
within a single `train()` call. This value does not affect learning,
|
||||
only the number of times `Trainer.step_attempt()` is called by
|
||||
`Trauber.train()`. If - after one `step_attempt()`, the training
|
||||
only the number of times `Trainer.training_step()` is called by
|
||||
`Trainer.train()`. If - after one such step attempt, the training
|
||||
timestep count has not been reached, will perform n more
|
||||
`step_attempt()` calls until the minimum timesteps have been executed.
|
||||
Set to 0 for no minimum timesteps.
|
||||
min_sample_timesteps_per_reporting: Minimum env sampling timesteps to
|
||||
`training_step()` calls until the minimum timesteps have been
|
||||
executed. Set to 0 or None for no minimum timesteps.
|
||||
min_sample_timesteps_per_iteration: Minimum env sampling timesteps to
|
||||
accumulate within a single `train()` call. This value does not affect
|
||||
learning, only the number of times `Trainer.step_attempt()` is called by
|
||||
`Trauber.train()`. If - after one `step_attempt()`, the env sampling
|
||||
timestep count has not been reached, will perform n more
|
||||
`step_attempt()` calls until the minimum timesteps have been executed.
|
||||
Set to 0 for no minimum timesteps.
|
||||
learning, only the number of times `Trainer.training_step()` is
|
||||
called by `Trainer.train()`. If - after one such step attempt, the env
|
||||
sampling timestep count has not been reached, will perform n more
|
||||
`training_step()` calls until the minimum timesteps have been
|
||||
executed. Set to 0 or None for no minimum timesteps.
|
||||
|
||||
Returns:
|
||||
This updated TrainerConfig object.
|
||||
|
@ -1138,12 +1144,12 @@ class TrainerConfig:
|
|||
)
|
||||
if metrics_num_episodes_for_smoothing is not None:
|
||||
self.metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing
|
||||
if min_time_s_per_reporting is not None:
|
||||
self.min_time_s_per_reporting = min_time_s_per_reporting
|
||||
if min_train_timesteps_per_reporting is not None:
|
||||
self.min_train_timesteps_per_reporting = min_train_timesteps_per_reporting
|
||||
if min_sample_timesteps_per_reporting is not None:
|
||||
self.min_sample_timesteps_per_reporting = min_sample_timesteps_per_reporting
|
||||
if min_time_s_per_iteration is not None:
|
||||
self.min_time_s_per_iteration = min_time_s_per_iteration
|
||||
if min_train_timesteps_per_iteration is not None:
|
||||
self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration
|
||||
if min_sample_timesteps_per_iteration is not None:
|
||||
self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ class A2CConfig(A3CConfig):
|
|||
# Override some of A3CConfig's default values with A2C-specific values.
|
||||
self.rollout_fragment_length = 20
|
||||
self.sample_async = False
|
||||
self.min_time_s_per_reporting = 10
|
||||
self.min_time_s_per_iteration = 10
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
@ -145,12 +145,12 @@ class A2C(A3C):
|
|||
self._microbatches_counts = self._num_microbatches = 0
|
||||
|
||||
@override(A3C)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# W/o microbatching: Identical to Trainer's default implementation.
|
||||
# Only difference to a default Trainer being the value function loss term
|
||||
# and its value computations alongside each action.
|
||||
if self.config["microbatch_size"] is None:
|
||||
return Trainer.training_iteration(self)
|
||||
return Trainer.training_step(self)
|
||||
|
||||
# In microbatch mode, we want to compute gradients on experience
|
||||
# microbatches, average a number of these microbatches, and then
|
||||
|
|
|
@ -39,7 +39,7 @@ class TestA2C(unittest.TestCase):
|
|||
config = (
|
||||
a2c.A2CConfig()
|
||||
.environment(env="CartPole-v0")
|
||||
.reporting(min_time_s_per_reporting=0)
|
||||
.reporting(min_time_s_per_iteration=0)
|
||||
)
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
|
@ -54,7 +54,7 @@ class TestA2C(unittest.TestCase):
|
|||
config = (
|
||||
a2c.A2CConfig()
|
||||
.environment(env="CartPole-v0")
|
||||
.reporting(min_time_s_per_reporting=0)
|
||||
.reporting(min_time_s_per_iteration=0)
|
||||
.training(microbatch_size=10)
|
||||
)
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ class A3CConfig(TrainerConfig):
|
|||
# This causes not every call to `training_iteration` to be reported,
|
||||
# but to wait until n seconds have passed and then to summarize the
|
||||
# thus far collected results.
|
||||
self.min_time_s_per_reporting = 5
|
||||
self.min_time_s_per_iteration = 5
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
@ -190,7 +190,7 @@ class A3C(Trainer):
|
|||
|
||||
return A3CEagerTFPolicy
|
||||
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Shortcut.
|
||||
local_worker = self.workers.local_worker()
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ class TestA3C(unittest.TestCase):
|
|||
# 0 metrics reporting delay, this makes sure timestep,
|
||||
# which entropy coeff depends on, is updated after each worker rollout.
|
||||
config.reporting(
|
||||
min_time_s_per_reporting=0, min_sample_timesteps_per_reporting=20
|
||||
min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20
|
||||
)
|
||||
|
||||
def _step_n_times(trainer, n: int):
|
||||
|
|
|
@ -140,7 +140,7 @@ class AlphaStarConfig(appo.APPOConfig):
|
|||
# Override some of APPOConfig's default values with AlphaStar-specific
|
||||
# values.
|
||||
self.vtrace_drop_last_ts = False
|
||||
self.min_time_s_per_reporting = 2
|
||||
self.min_time_s_per_iteration = 2
|
||||
self._disable_execution_plan_api = True
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
@ -415,7 +415,7 @@ class AlphaStar(appo.APPO):
|
|||
return result
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Trigger asynchronous rollouts on all RolloutWorkers.
|
||||
# - Rollout results are sent directly to correct replay buffer
|
||||
# shards, instead of here (to the driver).
|
||||
|
|
|
@ -319,7 +319,7 @@ class AlphaZero(Trainer):
|
|||
return AlphaZeroPolicyWrapperClass
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
"""TODO:
|
||||
|
||||
Returns:
|
||||
|
|
|
@ -66,8 +66,8 @@ class ApexDDPGConfig(DDPGConfig):
|
|||
self.exploration_config = {"type": "PerWorkerOrnsteinUhlenbeckNoise"}
|
||||
self.num_gpus = 0
|
||||
self.num_workers = 32
|
||||
self.min_sample_timesteps_per_reporting = 25000
|
||||
self.min_time_s_per_reporting = 30
|
||||
self.min_sample_timesteps_per_iteration = 25000
|
||||
self.min_time_s_per_iteration = 30
|
||||
self.train_batch_size = 512
|
||||
self.rollout_fragment_length = 50
|
||||
self.replay_buffer_config = {
|
||||
|
@ -182,9 +182,9 @@ class ApexDDPG(DDPG, ApexDQN):
|
|||
return ApexDQN.setup(self, config)
|
||||
|
||||
@override(DDPG)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
"""Use APEX-DQN's training iteration function."""
|
||||
return ApexDQN.training_iteration(self)
|
||||
return ApexDQN.training_step(self)
|
||||
|
||||
@override(Trainer)
|
||||
def on_worker_failures(
|
||||
|
|
|
@ -23,7 +23,7 @@ class TestApexDDPG(unittest.TestCase):
|
|||
config = (
|
||||
apex_ddpg.ApexDDPGConfig()
|
||||
.rollouts(num_rollout_workers=2)
|
||||
.reporting(min_sample_timesteps_per_reporting=100)
|
||||
.reporting(min_sample_timesteps_per_iteration=100)
|
||||
.training(
|
||||
replay_buffer_config={"learning_starts": 0},
|
||||
optimizer={"num_replay_buffer_shards": 1},
|
||||
|
|
|
@ -52,7 +52,6 @@ from ray.rllib.utils.typing import (
|
|||
ResultDict,
|
||||
PartialTrainerConfigDict,
|
||||
)
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.utils.placement_groups import PlacementGroupFactory
|
||||
from ray.util.ml_utils.dict import merge_dicts
|
||||
|
||||
|
@ -194,8 +193,8 @@ class ApexDQNConfig(DQNConfig):
|
|||
self.num_gpus = 1
|
||||
|
||||
# .reporting()
|
||||
self.min_time_s_per_reporting = 30
|
||||
self.min_sample_timesteps_per_reporting = 25000
|
||||
self.min_time_s_per_iteration = 30
|
||||
self.min_sample_timesteps_per_iteration = 25000
|
||||
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
@ -351,7 +350,7 @@ class ApexDQNConfig(DQNConfig):
|
|||
|
||||
|
||||
class ApexDQN(DQN):
|
||||
@override(Trainable)
|
||||
@override(Trainer)
|
||||
def setup(self, config: PartialTrainerConfigDict):
|
||||
super().setup(config)
|
||||
|
||||
|
@ -434,8 +433,8 @@ class ApexDQN(DQN):
|
|||
# Call DQN's validation method.
|
||||
super().validate_config(config)
|
||||
|
||||
@override(Trainable)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
@override(DQN)
|
||||
def training_step(self) -> ResultDict:
|
||||
num_samples_ready_dict = self.get_samples_and_store_to_replay_buffers()
|
||||
worker_samples_collected = defaultdict(int)
|
||||
|
||||
|
@ -656,9 +655,9 @@ class ApexDQN(DQN):
|
|||
self._sampling_actor_manager.add_workers(new_workers)
|
||||
|
||||
@override(Trainer)
|
||||
def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
|
||||
result = super()._compile_step_results(
|
||||
step_ctx=step_ctx, step_attempt_results=step_attempt_results
|
||||
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
|
||||
result = super()._compile_iteration_results(
|
||||
step_ctx=step_ctx, iteration_results=iteration_results
|
||||
)
|
||||
replay_stats = ray.get(
|
||||
self._replay_actors[0].stats.remote(self.config["optimizer"].get("debug"))
|
||||
|
@ -680,7 +679,7 @@ class ApexDQN(DQN):
|
|||
return result
|
||||
|
||||
@classmethod
|
||||
@override(Trainable)
|
||||
@override(Trainer)
|
||||
def default_resource_request(cls, config):
|
||||
cf = dict(cls.get_default_config(), **config)
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ class TestApexDQN(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
.reporting(
|
||||
min_sample_timesteps_per_reporting=100,
|
||||
min_time_s_per_reporting=1,
|
||||
min_sample_timesteps_per_iteration=100,
|
||||
min_time_s_per_iteration=1,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -61,8 +61,8 @@ class TestApexDQN(unittest.TestCase):
|
|||
},
|
||||
)
|
||||
.reporting(
|
||||
min_sample_timesteps_per_reporting=100,
|
||||
min_time_s_per_reporting=1,
|
||||
min_sample_timesteps_per_iteration=100,
|
||||
min_time_s_per_iteration=1,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -123,10 +123,10 @@ class TestApexDQN(unittest.TestCase):
|
|||
lr_schedule=[[0, 0.2], [100, 0.001]],
|
||||
)
|
||||
.reporting(
|
||||
min_sample_timesteps_per_reporting=10,
|
||||
min_sample_timesteps_per_iteration=10,
|
||||
# 0 metrics reporting delay, this makes sure timestep,
|
||||
# which lr depends on, is updated after each worker rollout.
|
||||
min_time_s_per_reporting=0,
|
||||
min_time_s_per_iteration=0,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ class APPOConfig(ImpalaConfig):
|
|||
# Override some of ImpalaConfig's default values with APPO-specific values.
|
||||
self.rollout_fragment_length = 50
|
||||
self.train_batch_size = 500
|
||||
self.min_time_s_per_reporting = 10
|
||||
self.min_time_s_per_iteration = 10
|
||||
self.num_workers = 2
|
||||
self.num_gpus = 0
|
||||
self.num_multi_gpu_tower_stacks = 1
|
||||
|
@ -228,8 +228,8 @@ class APPO(Impala):
|
|||
self.workers.local_worker().foreach_policy_to_train(update)
|
||||
|
||||
@override(Impala)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
train_results = super().training_iteration()
|
||||
def training_step(self) -> ResultDict:
|
||||
train_results = super().training_step()
|
||||
|
||||
# Update KL, target network periodically.
|
||||
self.after_train_step(train_results)
|
||||
|
|
|
@ -108,10 +108,10 @@ class TestAPPO(unittest.TestCase):
|
|||
)
|
||||
)
|
||||
|
||||
config.min_sample_timesteps_per_reporting = 20
|
||||
config.min_sample_timesteps_per_iteration = 20
|
||||
# 0 metrics reporting delay, this makes sure timestep,
|
||||
# which entropy coeff depends on, is updated after each worker rollout.
|
||||
config.min_time_s_per_reporting = 0
|
||||
config.min_time_s_per_iteration = 0
|
||||
|
||||
def _step_n_times(trainer, n: int):
|
||||
"""Step trainer n times.
|
||||
|
|
|
@ -397,7 +397,7 @@ class ARS(Trainer):
|
|||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
def step_attempt(self):
|
||||
def step(self):
|
||||
config = self.config
|
||||
|
||||
theta = self.policy.get_flat_weights()
|
||||
|
|
|
@ -33,7 +33,7 @@ class BanditConfig(TrainerConfig):
|
|||
# Make sure, a `train()` call performs at least 100 env sampling
|
||||
# timesteps, before reporting results. Not setting this (default is 0)
|
||||
# would significantly slow down the Bandit Trainer.
|
||||
self.min_sample_timesteps_per_reporting = 100
|
||||
self.min_sample_timesteps_per_iteration = 100
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
|
|
@ -70,8 +70,8 @@ class CQLConfig(SACConfig):
|
|||
self.off_policy_estimation_methods = {}
|
||||
|
||||
# .reporting()
|
||||
self.min_sample_timesteps_per_reporting = 0
|
||||
self.min_train_timesteps_per_reporting = 100
|
||||
self.min_sample_timesteps_per_iteration = 0
|
||||
self.min_train_timesteps_per_iteration = 100
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
@ -176,10 +176,10 @@ class CQL(SAC):
|
|||
if config.get("timesteps_per_iteration", DEPRECATED_VALUE) != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
old="timesteps_per_iteration",
|
||||
new="min_train_timesteps_per_reporting",
|
||||
new="min_train_timesteps_per_iteration",
|
||||
error=False,
|
||||
)
|
||||
config["min_train_timesteps_per_reporting"] = config[
|
||||
config["min_train_timesteps_per_iteration"] = config[
|
||||
"timesteps_per_iteration"
|
||||
]
|
||||
config["timesteps_per_iteration"] = DEPRECATED_VALUE
|
||||
|
@ -213,7 +213,7 @@ class CQL(SAC):
|
|||
return CQLTFPolicy
|
||||
|
||||
@override(SAC)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
|
||||
# Sample training batch from replay buffer.
|
||||
train_batch = sample_min_n_steps_from_buffer(
|
||||
|
|
|
@ -208,7 +208,7 @@ class CRR(Trainer):
|
|||
raise ValueError("Non-torch frameworks are not supported yet!")
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
|
||||
total_transitions = len(self.local_replay_buffer)
|
||||
bsize = self.config["train_batch_size"]
|
||||
|
|
|
@ -148,8 +148,8 @@ class TestDDPG(unittest.TestCase):
|
|||
config.actor_hiddens = [10]
|
||||
config.critic_hiddens = [10]
|
||||
# Make sure, timing differences do not affect trainer.train().
|
||||
config.min_time_s_per_reporting = 0
|
||||
config.min_sample_timesteps_per_reporting = 100
|
||||
config.min_time_s_per_iteration = 0
|
||||
config.min_sample_timesteps_per_iteration = 100
|
||||
|
||||
map_ = {
|
||||
# Normal net.
|
||||
|
|
|
@ -282,7 +282,7 @@ class DDPPO(PPO):
|
|||
)
|
||||
|
||||
@override(PPO)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Shortcut.
|
||||
first_worker = self.workers.remote_workers()[0]
|
||||
|
||||
|
|
|
@ -41,7 +41,6 @@ from ray.rllib.utils.metrics import (
|
|||
from ray.rllib.utils.deprecation import (
|
||||
Deprecated,
|
||||
)
|
||||
from ray.rllib.utils.annotations import ExperimentalAPI
|
||||
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
|
||||
from ray.rllib.execution.common import (
|
||||
LAST_TARGET_UPDATE_TS,
|
||||
|
@ -333,8 +332,8 @@ class DQN(SimpleQ):
|
|||
else:
|
||||
return DQNTFPolicy
|
||||
|
||||
@ExperimentalAPI
|
||||
def training_iteration(self) -> ResultDict:
|
||||
@override(SimpleQ)
|
||||
def training_step(self) -> ResultDict:
|
||||
"""DQN training iteration function.
|
||||
|
||||
Each training iteration, we:
|
||||
|
|
|
@ -377,7 +377,7 @@ class Dreamer(Trainer):
|
|||
return rollouts
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
local_worker = self.workers.local_worker()
|
||||
|
||||
# Number of sub-iterations for Dreamer
|
||||
|
@ -412,8 +412,8 @@ class Dreamer(Trainer):
|
|||
|
||||
return fetches
|
||||
|
||||
def _compile_step_results(self, *args, **kwargs):
|
||||
results = super()._compile_step_results(*args, **kwargs)
|
||||
def _compile_iteration_results(self, *args, **kwargs):
|
||||
results = super()._compile_iteration_results(*args, **kwargs)
|
||||
results["timesteps_total"] = self._counters[STEPS_SAMPLED_COUNTER]
|
||||
return results
|
||||
|
||||
|
|
|
@ -403,7 +403,7 @@ class ES(Trainer):
|
|||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
def step_attempt(self):
|
||||
def step(self):
|
||||
config = self.config
|
||||
|
||||
theta = self.policy.get_flat_weights()
|
||||
|
|
|
@ -133,7 +133,7 @@ class ImpalaConfig(TrainerConfig):
|
|||
self.num_workers = 2
|
||||
self.num_gpus = 1
|
||||
self.lr = 0.0005
|
||||
self.min_time_s_per_reporting = 10
|
||||
self.min_time_s_per_iteration = 10
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
@ -608,7 +608,7 @@ class Impala(Trainer):
|
|||
self.workers_that_need_updates = set()
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
unprocessed_sample_batches = self.get_samples_from_workers()
|
||||
|
||||
self.workers_that_need_updates |= unprocessed_sample_batches.keys()
|
||||
|
@ -901,9 +901,9 @@ class Impala(Trainer):
|
|||
self._sampling_actor_manager.add_workers(new_workers)
|
||||
|
||||
@override(Trainer)
|
||||
def _compile_step_results(self, *, step_ctx, step_attempt_results=None):
|
||||
result = super()._compile_step_results(
|
||||
step_ctx=step_ctx, step_attempt_results=step_attempt_results
|
||||
def _compile_iteration_results(self, *, step_ctx, iteration_results=None):
|
||||
result = super()._compile_iteration_results(
|
||||
step_ctx=step_ctx, iteration_results=iteration_results
|
||||
)
|
||||
result = self._learner_thread.add_learner_metrics(
|
||||
result, overwrite_learner_info=False
|
||||
|
|
|
@ -101,7 +101,7 @@ class MADDPGConfig(TrainerConfig):
|
|||
self.rollout_fragment_length = 100
|
||||
self.train_batch_size = 1024
|
||||
self.num_workers = 1
|
||||
self.min_time_s_per_reporting = 0
|
||||
self.min_time_s_per_iteration = 0
|
||||
# fmt: on
|
||||
# __sphinx_doc_end__
|
||||
|
||||
|
|
|
@ -251,7 +251,7 @@ class MARWIL(Trainer):
|
|||
return MARWILTF2Policy
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Collect SampleBatches from sample workers.
|
||||
batch = synchronous_parallel_sample(worker_set=self.workers)
|
||||
batch = batch.as_multi_agent()
|
||||
|
|
|
@ -380,7 +380,7 @@ class PPO(Trainer):
|
|||
return PPOTF2Policy
|
||||
|
||||
@ExperimentalAPI
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Collect SampleBatches from sample workers until we have a full batch.
|
||||
if self._by_agent_steps:
|
||||
train_batch = synchronous_parallel_sample(
|
||||
|
|
|
@ -103,8 +103,8 @@ class QMixConfig(SimpleQConfig):
|
|||
self.batch_mode = "complete_episodes"
|
||||
|
||||
# .reporting()
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.min_sample_timesteps_per_reporting = 1000
|
||||
self.min_time_s_per_iteration = 1
|
||||
self.min_sample_timesteps_per_iteration = 1000
|
||||
|
||||
# .exploration()
|
||||
self.exploration_config = {
|
||||
|
@ -226,7 +226,7 @@ class QMix(SimpleQ):
|
|||
return QMixTorchPolicy
|
||||
|
||||
@override(SimpleQ)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
"""QMIX training iteration function.
|
||||
|
||||
- Sample n MultiAgentBatches from n workers synchronously.
|
||||
|
|
|
@ -92,8 +92,8 @@ class SACConfig(TrainerConfig):
|
|||
self.train_batch_size = 256
|
||||
|
||||
# .reporting()
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.min_sample_timesteps_per_reporting = 100
|
||||
self.min_time_s_per_iteration = 1
|
||||
self.min_sample_timesteps_per_iteration = 100
|
||||
# __sphinx_doc_end__
|
||||
# fmt: on
|
||||
|
||||
|
|
|
@ -176,7 +176,7 @@ class TestSAC(unittest.TestCase):
|
|||
)
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.reporting(
|
||||
min_time_s_per_reporting=0,
|
||||
min_time_s_per_iteration=0,
|
||||
)
|
||||
.environment(
|
||||
env_config={"simplex_actions": True},
|
||||
|
|
|
@ -147,8 +147,8 @@ class SimpleQConfig(TrainerConfig):
|
|||
self.evaluation_config = {"explore": False}
|
||||
|
||||
# `reporting()`
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.min_sample_timesteps_per_reporting = 1000
|
||||
self.min_time_s_per_iteration = 1
|
||||
self.min_sample_timesteps_per_iteration = 1000
|
||||
|
||||
# Deprecated.
|
||||
self.buffer_size = DEPRECATED_VALUE
|
||||
|
@ -312,7 +312,7 @@ class SimpleQ(Trainer):
|
|||
|
||||
@ExperimentalAPI
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
"""Simple Q training iteration function.
|
||||
|
||||
Simple Q consists of the following steps:
|
||||
|
|
|
@ -113,8 +113,8 @@ class SlateQConfig(TrainerConfig):
|
|||
self.rollout_fragment_length = 4
|
||||
self.train_batch_size = 32
|
||||
self.lr = 0.00025
|
||||
self.min_sample_timesteps_per_reporting = 1000
|
||||
self.min_time_s_per_reporting = 1
|
||||
self.min_sample_timesteps_per_iteration = 1000
|
||||
self.min_time_s_per_iteration = 1
|
||||
self.compress_observations = False
|
||||
self._disable_preprocessor_api = True
|
||||
# __sphinx_doc_end__
|
||||
|
|
|
@ -51,7 +51,7 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
# Actual env steps per `train()` call will be
|
||||
# 10 * `min_sample_timesteps_per_reporting` (100 by default) = 1,000
|
||||
# 10 * `min_sample_timesteps_per_iteration` (100 by default) = 1,000
|
||||
training_iterations = 10
|
||||
|
||||
print("Running training for %s time steps" % training_iterations)
|
||||
|
|
|
@ -63,7 +63,7 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
# Actual env timesteps per `train()` call will be
|
||||
# 10 * min_sample_timesteps_per_reporting (100 by default) = 1,000.
|
||||
# 10 * min_sample_timesteps_per_iteration (100 by default) = 1,000.
|
||||
training_iterations = 10
|
||||
|
||||
print("Running training for %s time steps" % training_iterations)
|
||||
|
|
|
@ -42,7 +42,7 @@ if __name__ == "__main__":
|
|||
}
|
||||
|
||||
# Actual env timesteps per `train()` call will be
|
||||
# 10 * min_sample_timesteps_per_reporting (100 by default) = 1,000
|
||||
# 10 * min_sample_timesteps_per_iteration (100 by default) = 1,000
|
||||
training_iterations = 5000
|
||||
|
||||
print("Running training for %s time steps" % training_iterations)
|
||||
|
|
|
@ -34,7 +34,7 @@ if __name__ == "__main__":
|
|||
"n_step": 3,
|
||||
"lr": 0.0001,
|
||||
"target_network_update_freq": 50000,
|
||||
"min_sample_timesteps_per_reporting": 25000,
|
||||
"min_sample_timesteps_per_iteration": 25000,
|
||||
# Method specific.
|
||||
"multiagent": {
|
||||
# We only have one policy (calling it "shared").
|
||||
|
|
|
@ -60,7 +60,7 @@ if __name__ == "__main__":
|
|||
}
|
||||
config["train_batch_size"] = 256
|
||||
config["target_network_update_freq"] = 1
|
||||
config["min_train_timesteps_per_reporting"] = 1000
|
||||
config["min_train_timesteps_per_iteration"] = 1000
|
||||
data_file = "/path/to/my/json_file.json"
|
||||
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
|
||||
config["input"] = [data_file]
|
||||
|
|
|
@ -69,7 +69,7 @@ class RandomParametricTrainer(Trainer):
|
|||
return RandomParametricPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def training_iteration(self) -> ResultDict:
|
||||
def training_step(self) -> ResultDict:
|
||||
# Perform rollouts (only for collecting metrics later).
|
||||
synchronous_parallel_sample(worker_set=self.workers)
|
||||
|
||||
|
|
|
@@ -181,7 +181,7 @@ if __name__ == "__main__":
config.update(
{
"replay_buffer_config": {"learning_starts": 100},
-"min_sample_timesteps_per_reporting": 200,
+"min_sample_timesteps_per_iteration": 200,
"n_step": 3,
"rollout_fragment_length": 4,
"train_batch_size": 8,

@@ -85,7 +85,7 @@ if __name__ == "__main__":
config["lambda"] = 0.95
config["log_level"] = "WARN"
config["lr"] = 0.001
-config["min_time_s_per_reporting"] = 5
+config["min_time_s_per_iteration"] = 5
config["num_gpus"] = int(os.environ.get("RLLIB_NUM_GPUS", "0"))
config["num_workers"] = args.num_workers
config["rollout_fragment_length"] = 200
@@ -88,7 +88,7 @@ class MyTrainer(Trainer):
)

@override(Trainer)
-def training_iteration(self) -> ResultDict:
+def training_step(self) -> ResultDict:
# Generate common experiences, collect batch for PPO, store every (DQN) batch
# into replay buffer.
ppo_batches = []
@@ -50,13 +50,13 @@ def StandardMetricsReporting(
output_op = (
train_op.filter(
OncePerTimestepsElapsed(
-config["min_train_timesteps_per_reporting"] or 0
+config["min_train_timesteps_per_iteration"] or 0
if by_steps_trained
-else config["min_sample_timesteps_per_reporting"] or 0,
+else config["min_sample_timesteps_per_iteration"] or 0,
by_steps_trained=by_steps_trained,
)
)
-.filter(OncePerTimeInterval(config["min_time_s_per_reporting"]))
+.filter(OncePerTimeInterval(config["min_time_s_per_iteration"]))
.for_each(
CollectMetrics(
workers,
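For context, `StandardMetricsReporting` is the tail op of the execution-plan API, which this commit deprecates in favor of `training_step()`. A typical plan ends roughly as in the following hedged sketch (import paths and op choice are assumptions about this RLlib version, not code from the diff):

# Hedged sketch of a metrics-reporting execution plan; not from this commit.
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import ParallelRollouts


def execution_plan(workers, config, **kwargs):
    # Sample in parallel from all rollout workers (no learning in this sketch).
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # StandardMetricsReporting reads the renamed min_*_per_iteration keys from
    # `config` to decide when to emit the next result dict.
    return StandardMetricsReporting(rollouts, workers, config)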
@@ -24,7 +24,7 @@ CONFIGS = {
"explore": False,
"observation_filter": "MeanStdFilter",
"num_workers": 2,
-"min_time_s_per_reporting": 1,
+"min_time_s_per_iteration": 1,
"optimizer": {
"num_replay_buffer_shards": 1,
},

@@ -38,7 +38,7 @@ CONFIGS = {
},
"DDPG": {
"explore": False,
-"min_sample_timesteps_per_reporting": 100,
+"min_sample_timesteps_per_iteration": 100,
},
"DQN": {
"explore": False,
@@ -115,8 +115,8 @@ class TestEagerSupportOffPolicy(unittest.TestCase):
"num_workers": 2,
"replay_buffer_config": {"learning_starts": 0},
"num_gpus": 0,
-"min_time_s_per_reporting": 1,
-"min_sample_timesteps_per_reporting": 100,
+"min_time_s_per_iteration": 1,
+"min_sample_timesteps_per_iteration": 100,
"optimizer": {
"num_replay_buffer_shards": 1,
},

@@ -112,8 +112,8 @@ class TestExecution(unittest.TestCase):
a,
workers,
{
-"min_time_s_per_reporting": 2.5,
-"min_sample_timesteps_per_reporting": 0,
+"min_time_s_per_iteration": 2.5,
+"min_sample_timesteps_per_iteration": 0,
"metrics_num_episodes_for_smoothing": 10,
"metrics_episode_collection_timeout_s": 10,
"keep_per_episode_custom_metrics": False,
@@ -20,7 +20,7 @@ CONFIGS = {
"explore": False,
"observation_filter": "MeanStdFilter",
"num_workers": 2,
-"min_time_s_per_reporting": 1,
+"min_time_s_per_iteration": 1,
"optimizer": {
"num_replay_buffer_shards": 1,
},

@@ -34,7 +34,7 @@ CONFIGS = {
},
"DDPG": {
"explore": False,
-"min_sample_timesteps_per_reporting": 100,
+"min_sample_timesteps_per_iteration": 100,
},
"DQN": {
"explore": False,
@@ -35,8 +35,8 @@ class TestReproducibility(unittest.TestCase):
register_env("PickLargest", env_creator)
config = {
"seed": 666 if trial in [0, 1] else 999,
-"min_time_s_per_reporting": 0,
-"min_sample_timesteps_per_reporting": 100,
+"min_time_s_per_iteration": 0,
+"min_sample_timesteps_per_iteration": 100,
"framework": fw,
}
agent = DQN(config=config, env="PickLargest")
@@ -35,8 +35,8 @@ def evaluate_test(algo, env="CartPole-v0", test_episode_rollout=False):
"--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo)
+ "--config='{"
+ '"num_workers": 1, "num_gpus": 0{}{}'.format(fw_, extra_config)
-+ ', "min_sample_timesteps_per_reporting": 5,'
-'"min_time_s_per_reporting": 0.1, '
++ ', "min_sample_timesteps_per_iteration": 5,'
+'"min_time_s_per_iteration": 0.1, '
'"model": {"fcnet_hiddens": [10]}'
"}' --stop='{\"training_iteration\": 1}'" + " --env={}".format(env)
)
@@ -93,13 +93,13 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"APEX",
{
"num_workers": 2,
-"min_sample_timesteps_per_reporting": 100,
+"min_sample_timesteps_per_iteration": 100,
"num_gpus": 0,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
-"min_time_s_per_reporting": 1,
+"min_time_s_per_iteration": 1,
"target_network_update_freq": 100,
"optimizer": {
"num_replay_buffer_shards": 1,

@@ -112,13 +112,13 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
"APEX_DDPG",
{
"num_workers": 2,
-"min_sample_timesteps_per_reporting": 100,
+"min_sample_timesteps_per_iteration": 100,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 10,
},
"num_gpus": 0,
-"min_time_s_per_reporting": 1,
+"min_time_s_per_iteration": 1,
"target_network_update_freq": 100,
"use_state_preprocessor": True,
},
@@ -128,7 +128,7 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
check_support_multiagent(
"DDPG",
{
-"min_sample_timesteps_per_reporting": 1,
+"min_sample_timesteps_per_iteration": 1,
"replay_buffer_config": {
"capacity": 1000,
"learning_starts": 500,

@@ -141,7 +141,7 @@ class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
check_support_multiagent(
"DQN",
{
-"min_sample_timesteps_per_reporting": 1,
+"min_sample_timesteps_per_iteration": 1,
"replay_buffer_config": {
"capacity": 1000,
},
@@ -181,7 +181,7 @@ class TestSupportedSpacesOffPolicy(unittest.TestCase):
"DDPG",
{
"exploration_config": {"ou_base_scale": 100.0},
-"min_sample_timesteps_per_reporting": 1,
+"min_sample_timesteps_per_iteration": 1,
"replay_buffer_config": {
"capacity": 1000,
},

@@ -192,7 +192,7 @@ class TestSupportedSpacesOffPolicy(unittest.TestCase):

def test_dqn(self):
config = {
-"min_sample_timesteps_per_reporting": 1,
+"min_sample_timesteps_per_iteration": 1,
"replay_buffer_config": {
"capacity": 1000,
},
@@ -27,4 +27,4 @@ apex-breakoutnoframeskip-v4:
rollout_fragment_length: 20
train_batch_size: 512
target_network_update_freq: 50000
-min_sample_timesteps_per_reporting: 25000
+min_sample_timesteps_per_iteration: 25000
@@ -25,7 +25,7 @@ cartpole-apex-dqn-training-itr:

num_gpus: 0

-min_time_s_per_reporting: 5
+min_time_s_per_iteration: 5
target_network_update_freq: 500
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
training_intensity: 4
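The same renamed keys run through all of the tuned YAML examples below. As a hedged Python equivalent of a fragment like the one above (values and stopping criterion are illustrative, not taken from the repo):

# Hedged sketch of launching such a tuned config programmatically.
from ray import tune

tune.run(
    "APEX",
    config={
        "env": "CartPole-v0",
        "num_gpus": 0,
        # Report one training iteration only after both minimums are met.
        "min_time_s_per_iteration": 5,
        "min_sample_timesteps_per_iteration": 1000,
        "target_network_update_freq": 500,
    },
    stop={"episode_reward_mean": 150},
)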
@@ -99,7 +99,7 @@ apex:
rollout_fragment_length: 20
train_batch_size: 512
target_network_update_freq: 50000
-min_sample_timesteps_per_reporting: 25000
+min_sample_timesteps_per_iteration: 25000
atari-a2c:
env: BreakoutNoFrameskip-v4
run: A2C

@@ -142,4 +142,4 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
num_gpus: 0.2
-min_sample_timesteps_per_reporting: 10000
+min_sample_timesteps_per_iteration: 10000
@@ -31,7 +31,7 @@ halfcheetah_bc:
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
-min_train_timesteps_per_reporting: 1000
+min_train_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

@@ -33,7 +33,7 @@ halfcheetah_cql:
learning_starts: 256
train_batch_size: 256
target_network_update_freq: 0
-min_train_timesteps_per_reporting: 1000
+min_train_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

@@ -31,7 +31,7 @@ hopper_bc:
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
-min_train_timesteps_per_reporting: 1000
+min_train_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003

@@ -31,7 +31,7 @@ hopper_cql:
learning_starts: 10
train_batch_size: 256
target_network_update_freq: 0
-min_train_timesteps_per_reporting: 1000
+min_train_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003
@@ -25,7 +25,7 @@ halfcheetah-ddpg:
ou_theta: 0.15
ou_sigma: 0.2

-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
target_network_update_freq: 0
tau: 0.001

@@ -19,7 +19,7 @@ ddpg-halfcheetahbulletenv-v0:
ou_base_scale: 0.1
ou_theta: 0.15
ou_sigma: 0.2
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
target_network_update_freq: 0
tau: 0.001
replay_buffer_config:

@@ -22,7 +22,7 @@ ddpg-hopperbulletenv-v0:
ou_base_scale: 0.1
ou_theta: 0.15
ou_sigma: 0.2
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
target_network_update_freq: 0
tau: 0.001
replay_buffer_config:
@@ -25,7 +25,7 @@ mountaincarcontinuous-ddpg:
ou_theta: 0.15
ou_sigma: 0.2

-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000

target_network_update_freq: 0
tau: 0.01

@@ -15,7 +15,7 @@ pendulum-ddpg-fake-gpus:
gamma: 0.99
exploration_config:
final_scale: 0.02
-min_sample_timesteps_per_reporting: 600
+min_sample_timesteps_per_iteration: 600
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
capacity: 10000

@@ -28,7 +28,7 @@ pendulum-ddpg:
ou_theta: 0.15
ou_sigma: 0.2

-min_sample_timesteps_per_reporting: 600
+min_sample_timesteps_per_iteration: 600
target_network_update_freq: 0
tau: 0.001
@@ -26,4 +26,4 @@ atari-dist-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
num_gpus: 0.2
-min_sample_timesteps_per_reporting: 10000
+min_sample_timesteps_per_iteration: 10000

@@ -30,4 +30,4 @@ atari-basic-dqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
num_gpus: 0.2
-min_sample_timesteps_per_reporting: 10000
+min_sample_timesteps_per_iteration: 10000

@@ -30,4 +30,4 @@ dueling-ddqn:
epsilon_timesteps: 200000
final_epsilon: 0.01
num_gpus: 0.2
-min_sample_timesteps_per_reporting: 10000
+min_sample_timesteps_per_iteration: 10000
@@ -40,7 +40,7 @@ atari-sac-tf-and-torch:
prioritized_replay_beta: 0.4
prioritized_replay_eps: 1e-6
train_batch_size: 64
-min_sample_timesteps_per_reporting: 4
+min_sample_timesteps_per_iteration: 4
# Paper uses 20k random timesteps, which is not exactly the same, but
# seems to work nevertheless. We use 100k here for the longer Atari
# runs (DQN style: filling up the buffer a bit before learning).

@@ -17,7 +17,7 @@ cartpole-sac:
learning_starts: 256
initial_alpha: 0.2
clip_actions: false
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
optimization:
actor_learning_rate: 0.005
critic_learning_rate: 0.005

@@ -21,7 +21,7 @@ halfcheetah-pybullet-sac:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000
@@ -22,7 +22,7 @@ halfcheetah_sac:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 10000

@@ -28,7 +28,7 @@ mspacman-sac-tf:
n_step: 1
rollout_fragment_length: 1
train_batch_size: 64
-min_sample_timesteps_per_reporting: 4
+min_sample_timesteps_per_iteration: 4
# Paper uses 20k random timesteps, which is not exactly the same, but
# seems to work nevertheless.
replay_buffer_config:

@@ -23,7 +23,7 @@ pendulum-sac-fake-gpus:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256
@@ -24,7 +24,7 @@ pendulum-sac:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256

@@ -32,7 +32,7 @@ transformed-actions-pendulum-sac-dummy-torch:
rollout_fragment_length: 1
train_batch_size: 256
target_network_update_freq: 1
-min_sample_timesteps_per_reporting: 1000
+min_sample_timesteps_per_iteration: 1000
replay_buffer_config:
type: MultiAgentPrioritizedReplayBuffer
learning_starts: 256