[RLlib] Remove execution_plan API code no longer needed. (#24501)

Sven Mika 2022-05-06 12:29:53 +02:00 committed by GitHub
parent f891a2b6f1
commit f54557073e
16 changed files with 8 additions and 807 deletions

View file

@@ -1706,13 +1706,6 @@ py_test(
srcs = ["tests/test_env_with_subprocess.py"]
)
py_test(
name = "tests/test_exec_api",
tags = ["team:ml", "tests_dir", "tests_dir_E"],
size = "medium",
srcs = ["tests/test_exec_api.py"]
)
py_test(
name = "tests/test_execution",
tags = ["team:ml", "tests_dir", "tests_dir_E"],
@@ -2855,7 +2848,7 @@ py_test(
py_test(
name = "examples/rollout_worker_custom_workflow",
tags = ["team:ml", "examples", "examples_R"],
size = "small",
size = "medium",
srcs = ["examples/rollout_worker_custom_workflow.py"],
args = ["--num-cpus=4"]
)

View file

@@ -4,24 +4,13 @@ from typing import Optional
from ray.rllib.agents.a3c.a3c import A3CConfig, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
ConcatBatches,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
ComputeGradients,
AverageGradients,
ApplyGradients,
MultiGPUTrainOneStep,
TrainOneStep,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
@@ -37,7 +26,6 @@ from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
)
from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
@@ -228,60 +216,6 @@ class A2CTrainer(A3CTrainer):
return train_results
@staticmethod
@override(Trainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "A2C execution_plan does NOT take any additional parameters"
rollouts = ParallelRollouts(workers, mode="bulk_sync")
if config["microbatch_size"]:
num_microbatches = math.ceil(
config["train_batch_size"] / config["microbatch_size"]
)
# In microbatch mode, we want to compute gradients on experience
# microbatches, average a number of these microbatches, and then
# apply the averaged gradient in one SGD step. This conserves GPU
# memory, allowing for extremely large experience batches to be
# used.
train_op = (
rollouts.combine(
ConcatBatches(
min_batch_size=config["microbatch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)
)
.for_each(ComputeGradients(workers)) # (grads, info)
.batch(num_microbatches) # List[(grads, info)]
.for_each(AverageGradients()) # (avg_grads, info)
.for_each(ApplyGradients(workers))
)
else:
# In normal mode, we execute one SGD step per each train batch.
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
train_op = rollouts.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)
).for_each(train_step_op)
return StandardMetricsReporting(train_op, workers, config)
# Deprecated: Use ray.rllib.agents.a3c.A2CConfig instead!
class _deprecated_default_config(dict):
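
The comment in the removed plan above ("compute gradients on experience microbatches, average a number of these microbatches, and then apply the averaged gradient in one SGD step") describes logic that now has to live in `training_iteration()`. Below is a minimal sketch of that pattern under the new API; the subclass name and the `_accum_grads` / `_num_microbatches` attributes are illustrative only, and `A2CTrainer` already ships the real implementation of this behavior.

import math

from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample


class MicrobatchA2CSketch(A2CTrainer):
    """Illustrative sketch only; class and attribute names are placeholders."""

    _accum_grads = None       # Running sum of microbatch gradients.
    _num_microbatches = 0     # Microbatches accumulated so far.

    def training_iteration(self):
        cfg = self.config
        # Sample one microbatch worth of env steps from the rollout workers.
        batch = synchronous_parallel_sample(
            worker_set=self.workers, max_env_steps=cfg["microbatch_size"]
        )
        # Compute, but do not yet apply, gradients on the local worker.
        grads, info = self.workers.local_worker().compute_gradients(batch)
        # Accumulate gradients across microbatches.
        if self._accum_grads is None:
            self._accum_grads = list(grads)
        else:
            self._accum_grads = [a + g for a, g in zip(self._accum_grads, grads)]
        self._num_microbatches += 1
        # Once `train_batch_size` worth of microbatches has been collected,
        # apply the averaged gradient in a single step and reset.
        n = math.ceil(cfg["train_batch_size"] / cfg["microbatch_size"])
        if self._num_microbatches >= n:
            self.workers.local_worker().apply_gradients(
                [g / n for g in self._accum_grads]
            )
            self._accum_grads = None
            self._num_microbatches = 0
        return info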

View file

@@ -6,11 +6,7 @@ from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
@@ -25,7 +21,6 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
@@ -248,25 +243,6 @@ class A3CTrainer(Trainer):
return learner_info_builder.finalize()
@staticmethod
@override(Trainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "A3C execution_plan does NOT take any additional parameters"
# For A3C, compute policy gradients remotely on the rollout workers.
grads = AsyncGradients(workers)
# Apply the gradients as they arrive. We set update_all to False so
# that only the worker sending the gradient is updated with new
# weights.
train_op = grads.for_each(ApplyGradients(workers, update_all=False))
return StandardMetricsReporting(train_op, workers, config)
# Deprecated: Use ray.rllib.agents.a3c.A3CConfig instead!
class _deprecated_default_config(dict):

View file

@@ -5,14 +5,9 @@ from typing import Type
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
MultiGPUTrainOneStep,
train_one_step,
TrainOneStep,
UpdateTargetNetwork,
)
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import Policy
@@ -29,7 +24,6 @@ from ray.rllib.utils.metrics import (
TARGET_NET_UPDATE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
@@ -223,63 +217,3 @@ class CQLTrainer(SACTrainer):
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(SACTrainer)
def execution_plan(workers, config, **kwargs):
assert (
"local_replay_buffer" in kwargs
), "CQL execution plan requires a local replay buffer."
local_replay_buffer = kwargs["local_replay_buffer"]
def update_prio(item):
samples, info_dict = item
if config.get("prioritized_replay"):
prio_dict = {}
for policy_id, info in info_dict.items():
# TODO(sven): This is currently structured differently for
# torch/tf. Clean up these results/info dicts across
# policies (note: fixing this in torch_policy.py will
# break e.g. DDPPO!).
td_error = info.get(
"td_error", info[LEARNER_STATS_KEY].get("td_error")
)
samples.policy_batches[policy_id].set_get_interceptor(None)
prio_dict[policy_id] = (
samples.policy_batches[policy_id].get("batch_indexes"),
td_error,
)
local_replay_buffer.update_priorities(prio_dict)
return info_dict
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step, and then we decide whether to update the target
# network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
train_op = (
Replay(local_buffer=local_replay_buffer)
.for_each(lambda x: post_fn(x, workers, config))
.for_each(train_step_op)
.for_each(update_prio)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
)
return StandardMetricsReporting(
train_op, workers, config, by_steps_trained=True
)

View file

@@ -172,9 +172,6 @@ DEFAULT_CONFIG = with_common_config({
# timestep count has not been reached, will perform n more `step_attempt()` calls
# until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
"min_sample_timesteps_per_reporting": 1000,
# Experimental flag.
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on
@@ -197,6 +194,7 @@ class DDPGTrainer(SimpleQTrainer):
@override(SimpleQTrainer)
def validate_config(self, config: TrainerConfigDict) -> None:
# Call super's validation method.
super().validate_config(config)

View file

@@ -61,7 +61,6 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
"type": "MultiAgentReplayBuffer",
"capacity": 1000000,
},
"_disable_execution_plan_api": True,
},
)

View file

@@ -19,25 +19,16 @@ from ray.rllib.agents.dqn.simple_q import (
DEFAULT_CONFIG as SIMPLEQ_DEFAULT_CONFIG,
)
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
TrainOneStep,
UpdateTargetNetwork,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
@@ -46,11 +37,6 @@ from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_SAMPLED,
NUM_AGENT_STEPS_SAMPLED,
)
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
@@ -280,103 +266,6 @@ class DQNTrainer(SimpleQTrainer):
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(SimpleQTrainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
"local_replay_buffer" in kwargs
), "DQN's execution plan requires a local replay buffer."
# Assign to Trainer, so we can store the MultiAgentReplayBuffer's
# data when we save checkpoints.
local_replay_buffer = kwargs["local_replay_buffer"]
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our local replay buffer.
# Calling next() on store_op drives this.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer)
)
def update_prio(item):
samples, info_dict = item
prio_dict = {}
for policy_id, info in info_dict.items():
# TODO(sven): This is currently structured differently for
# torch/tf. Clean up these results/info dicts across
# policies (note: fixing this in torch_policy.py will
# break e.g. DDPPO!).
td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
samples.policy_batches[policy_id].set_get_interceptor(None)
batch_indices = samples.policy_batches[policy_id].get("batch_indexes")
# In case the buffer stores sequences, TD-error could
# already be calculated per sequence chunk.
if len(batch_indices) != len(td_error):
T = local_replay_buffer.replay_sequence_length
assert (
len(batch_indices) > len(td_error)
and len(batch_indices) % T == 0
)
batch_indices = batch_indices.reshape([-1, T])[:, 0]
assert len(batch_indices) == len(td_error)
prio_dict[policy_id] = (batch_indices, td_error)
local_replay_buffer.update_priorities(prio_dict)
return info_dict
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step, and then we decide whether to update the target
# network.
post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
if (
type(local_replay_buffer) is LegacyMultiAgentReplayBuffer
and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0)
> 0.0
) or isinstance(local_replay_buffer, MultiAgentPrioritizedReplayBuffer):
update_prio_fn = update_prio
else:
def update_prio_fn(x):
return x
replay_op = (
Replay(local_buffer=local_replay_buffer)
.for_each(lambda x: post_fn(x, workers, config))
.for_each(train_step_op)
.for_each(update_prio_fn)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
)
# Alternate deterministically between (1) and (2).
# Only return the output of (2) since training metrics are not
# available until (2) runs.
train_op = Concurrently(
[store_op, replay_op],
mode="round_robin",
output_indexes=[1],
round_robin_weights=calculate_rr_weights(config),
)
return StandardMetricsReporting(train_op, workers, config)
@Deprecated(
new="Sub-class directly from `DQNTrainer` and override its methods", error=False

View file

@@ -16,23 +16,14 @@ from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.execution.train_ops import (
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override
@@ -267,46 +258,3 @@ class SimpleQTrainer(Trainer):
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(Trainer)
def execution_plan(workers, config, **kwargs):
assert (
"local_replay_buffer" in kwargs
), "GenericOffPolicy execution plan requires a local replay buffer."
local_replay_buffer = kwargs["local_replay_buffer"]
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# (1) Generate rollouts and store them in our local replay buffer.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=local_replay_buffer)
)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
# (2) Read and train on experiences from the replay buffer.
replay_op = (
Replay(local_buffer=local_replay_buffer)
.for_each(train_step_op)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
)
# Alternate deterministically between (1) and (2).
train_op = Concurrently(
[store_op, replay_op], mode="round_robin", output_indexes=[1]
)
return StandardMetricsReporting(train_op, workers, config)

View file

@@ -2,19 +2,12 @@ from typing import Type
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import (
ConcatBatches,
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
TrainOneStep,
train_one_step,
)
from ray.rllib.policy.policy import Policy
@@ -31,7 +24,6 @@ from ray.rllib.utils.typing import (
ResultDict,
TrainerConfigDict,
)
from ray.util.iter import LocalIterator
# fmt: off
# __sphinx_doc_begin__
@@ -171,39 +163,3 @@ class MARWILTrainer(Trainer):
self.workers.local_worker().set_global_vars(global_vars)
return train_results
@staticmethod
@override(Trainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "Marwill execution_plan does NOT take any additional parameters"
rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = MultiAgentReplayBuffer(
learning_starts=config["learning_starts"],
capacity=config["replay_buffer_size"],
replay_batch_size=config["train_batch_size"],
replay_sequence_length=1,
)
store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
replay_op = (
Replay(local_buffer=replay_buffer)
.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)
)
.for_each(TrainOneStep(workers))
)
train_op = Concurrently(
[store_op, replay_op], mode="round_robin", output_indexes=[1]
)
return StandardMetricsReporting(train_op, workers, config)

View file

@@ -25,18 +25,10 @@ from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, PPOTr
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
STEPS_SAMPLED_COUNTER,
STEPS_TRAINED_COUNTER,
STEPS_TRAINED_THIS_ITER_COUNTER,
_get_shared_metrics,
_get_global_vars,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
LEARN_ON_BATCH_TIMER,
@@ -46,7 +38,7 @@ from ray.rllib.utils.metrics import (
NUM_ENV_STEPS_TRAINED,
SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LearnerInfoBuilder
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import (
EnvType,
@@ -55,7 +47,6 @@ from ray.rllib.utils.typing import (
TrainerConfigDict,
)
from ray.tune.logger import Logger
from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
@@ -321,141 +312,3 @@ class DDPPOTrainer(PPOTrainer):
"sample_time": sample_time,
"learn_on_batch_time": learn_on_batch_time,
}
@staticmethod
@override(PPOTrainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
"""Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.
Args:
workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
of the Trainer.
config (TrainerConfigDict): The trainer's configuration dict.
Returns:
LocalIterator[dict]: A local iterator over the training metrics produced by this dataflow.
"""
assert (
len(kwargs) == 0
), "DDPPO execution_plan does NOT take any additional parameters"
rollouts = ParallelRollouts(workers, mode="raw")
# Setup the distributed processes.
ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
port = ray.get(workers.remote_workers()[0].find_free_port.remote())
address = "tcp://{ip}:{port}".format(ip=ip, port=port)
logger.info("Creating torch process group with leader {}".format(address))
# Get setup tasks in order to throw errors on failure.
ray.get(
[
worker.setup_torch_data_parallel.remote(
url=address,
world_rank=i,
world_size=len(workers.remote_workers()),
backend=config["torch_distributed_backend"],
)
for i, worker in enumerate(workers.remote_workers())
]
)
logger.info("Torch process group init completed")
# This function is applied remotely on each rollout worker.
def train_torch_distributed_allreduce(batch):
expected_batch_size = (
config["rollout_fragment_length"] * config["num_envs_per_worker"]
)
this_worker = get_global_worker()
assert batch.count == expected_batch_size, (
"Batch size possibly out of sync between workers, expected:",
expected_batch_size,
"got:",
batch.count,
)
logger.info(
"Executing distributed minibatch SGD "
"with epoch size {}, minibatch size {}".format(
batch.count, config["sgd_minibatch_size"]
)
)
info = do_minibatch_sgd(
batch,
this_worker.policy_map,
this_worker,
config["num_sgd_iter"],
config["sgd_minibatch_size"],
["advantages"],
)
return info, batch.count
# Broadcast the local set of global vars.
def update_worker_global_vars(item):
global_vars = _get_global_vars()
for w in workers.remote_workers():
w.set_global_vars.remote(global_vars)
return item
# Have to manually record stats since we are using "raw" rollouts mode.
class RecordStats:
def _on_fetch_start(self):
self.fetch_start_time = time.perf_counter()
def __call__(self, items):
assert len(items) == config["num_workers"]
for item in items:
info, count = item
metrics = _get_shared_metrics()
metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
metrics.counters[STEPS_SAMPLED_COUNTER] += count
metrics.counters[STEPS_TRAINED_COUNTER] += count
metrics.info[LEARNER_INFO] = info
# Since SGD happens remotely, the time delay between fetch and
# completion is approximately the SGD step time.
metrics.timers[LEARN_ON_BATCH_TIMER].push(
time.perf_counter() - self.fetch_start_time
)
train_op = (
rollouts.for_each(train_torch_distributed_allreduce) # allreduce
.batch_across_shards() # List[(grad_info, count)]
.for_each(RecordStats())
)
train_op = train_op.for_each(update_worker_global_vars)
# Sync down the weights. As with the sync up, this is not really
# needed unless the user is reading the local weights.
if config["keep_local_weights_in_sync"]:
def download_weights(item):
workers.local_worker().set_weights(
ray.get(workers.remote_workers()[0].get_weights.remote())
)
return item
train_op = train_op.for_each(download_weights)
# In debug mode, check the allreduce successfully synced the weights.
if logger.isEnabledFor(logging.DEBUG):
def check_sync(item):
weights = ray.get(
[w.get_weights.remote() for w in workers.remote_workers()]
)
sums = []
for w in weights:
acc = 0
for p in w.values():
for k, v in p.items():
acc += v.sum()
sums.append(float(acc))
logger.debug("The worker weight sums are {}".format(sums))
assert len(set(sums)) == 1, sums
train_op = train_op.for_each(check_sync)
return StandardMetricsReporting(train_op, workers, config)

View file

@@ -16,29 +16,20 @@ from ray.util.debug import log_once
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import (
ParallelRollouts,
ConcatBatches,
StandardizeFields,
standardize_fields,
SelectExperiences,
)
from ray.rllib.execution.train_ops import (
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
from ray.util.iter import LocalIterator
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (
NUM_AGENT_STEPS_SAMPLED,
@@ -456,59 +447,6 @@ class PPOTrainer(Trainer):
return train_results
@staticmethod
@override(Trainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "PPO execution_plan does NOT take any additional parameters"
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# Collect batches for the trainable policies.
rollouts = rollouts.for_each(
SelectExperiences(local_worker=workers.local_worker())
)
# Concatenate the SampleBatches into one.
rollouts = rollouts.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)
)
# Standardize advantages.
rollouts = rollouts.for_each(StandardizeFields(["advantages"]))
# Perform one training step on the combined + standardized batch.
if config["simple_optimizer"]:
train_op = rollouts.for_each(
TrainOneStep(
workers,
num_sgd_iter=config["num_sgd_iter"],
sgd_minibatch_size=config["sgd_minibatch_size"],
)
)
else:
train_op = rollouts.for_each(
MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["sgd_minibatch_size"],
num_sgd_iter=config["num_sgd_iter"],
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
)
# Update KL after each round of training.
train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))
# Warn about bad reward scales and return training metrics.
return StandardMetricsReporting(train_op, workers, config).for_each(
lambda result: warn_about_bad_reward_scales(config, result)
)
# Deprecated: Use ray.rllib.agents.ppo.PPOConfig instead!
class _deprecated_default_config(dict):
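
A rough sketch of how the removed PPO plan's stages (collect rollouts, standardize advantages, one simple- or multi-GPU train step) map onto `training_iteration()`, using only the helpers that remain imported above. The subclass name is a placeholder; `PPOTrainer` already implements the real version, including the KL update and reward-scale warnings that the old plan appended.

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.execution.rollout_ops import (
    standardize_fields,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step


class PPOSketch(PPOTrainer):
    """Illustrative sketch only."""

    def training_iteration(self):
        # Collect on-policy rollouts until `train_batch_size` env steps.
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers, max_env_steps=self.config["train_batch_size"]
        )
        # Standardize advantages (what `StandardizeFields(["advantages"])` did
        # in the removed plan).
        train_batch = standardize_fields(train_batch, ["advantages"])
        # One SGD phase over the whole batch, simple or multi-GPU optimizer.
        if self.config.get("simple_optimizer"):
            return train_one_step(self, train_batch)
        return multi_gpu_train_one_step(self, train_batch)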

View file

@@ -3,24 +3,12 @@ from typing import Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import (
SimpleReplayBuffer,
Replay,
StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import (
ConcatBatches,
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
multi_gpu_train_one_step,
train_one_step,
TrainOneStep,
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
@@ -34,7 +22,6 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator
# fmt: off
# __sphinx_doc_begin__
@@ -232,37 +219,3 @@ class QMixTrainer(SimpleQTrainer):
# Return all collected metrics for the iteration.
return train_results
@staticmethod
@override(SimpleQTrainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
len(kwargs) == 0
), "QMIX execution_plan does NOT take any additional parameters"
rollouts = ParallelRollouts(workers, mode="bulk_sync")
replay_buffer = SimpleReplayBuffer(config["buffer_size"])
store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))
train_op = (
Replay(local_buffer=replay_buffer)
.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
)
)
.for_each(TrainOneStep(workers))
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
)
merged_op = Concurrently(
[store_op, train_op], mode="round_robin", output_indexes=[1]
)
return StandardMetricsReporting(merged_op, workers, config)

View file

@@ -19,21 +19,10 @@ from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.slateq.slateq_tf_policy import SlateQTFPolicy
from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import (
MultiGPUTrainOneStep,
TrainOneStep,
UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
logger = logging.getLogger(__name__)
@@ -182,55 +171,3 @@ class SlateQTrainer(DQNTrainer):
return SlateQTorchPolicy
else:
return SlateQTFPolicy
@staticmethod
@override(DQNTrainer)
def execution_plan(
workers: WorkerSet, config: TrainerConfigDict, **kwargs
) -> LocalIterator[dict]:
assert (
"local_replay_buffer" in kwargs
), "SlateQ execution plan requires a local replay buffer."
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# We execute the following steps concurrently:
# (1) Generate rollouts and store them in our local replay buffer.
# Calling next() on store_op drives this.
store_op = rollouts.for_each(
StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"])
)
if config["simple_optimizer"]:
train_step_op = TrainOneStep(workers)
else:
train_step_op = MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config["train_batch_size"],
num_sgd_iter=1,
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
# (2) Read and train on experiences from the replay buffer. Every batch
# returned from the LocalReplay() iterator is passed to TrainOneStep to
# take a SGD step.
replay_op = (
Replay(local_buffer=kwargs["local_replay_buffer"])
.for_each(train_step_op)
.for_each(
UpdateTargetNetwork(workers, config["target_network_update_freq"])
)
)
# Alternate deterministically between (1) and (2). Only return the
# output of (2) since training metrics are not available until (2)
# runs.
train_op = Concurrently(
[store_op, replay_op],
mode="round_robin",
output_indexes=[1],
round_robin_weights=calculate_round_robin_weights(config),
)
return StandardMetricsReporting(train_op, workers, config)

View file

@@ -40,20 +40,15 @@ from ray.rllib.evaluation.metrics import (
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import (
ConcatBatches,
ParallelRollouts,
synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
TrainOneStep,
MultiGPUTrainOneStep,
train_one_step,
multi_gpu_train_one_step,
)
@@ -909,41 +904,14 @@ class Trainer(Trainable):
return train_results
@DeveloperAPI
@staticmethod
def execution_plan(workers, config, **kwargs):
# Collects experiences in parallel from multiple RolloutWorker actors.
rollouts = ParallelRollouts(workers, mode="bulk_sync")
# Combine experiences batches until we hit `train_batch_size` in size.
# Then, train the policy on those experiences and update the workers.
train_op = rollouts.combine(
ConcatBatches(
min_batch_size=config["train_batch_size"],
count_steps_by=config["multiagent"]["count_steps_by"],
raise NotImplementedError(
"It is not longer recommended to use Trainer's `execution_plan` method/API."
" Set `_disable_execution_plan_api=True` in your config and override the "
"`Trainer.training_iteration()` method with your algo's custom "
"execution logic."
)
)
if config.get("simple_optimizer") is True:
train_op = train_op.for_each(TrainOneStep(workers))
else:
train_op = train_op.for_each(
MultiGPUTrainOneStep(
workers=workers,
sgd_minibatch_size=config.get(
"sgd_minibatch_size", config["train_batch_size"]
),
num_sgd_iter=config.get("num_sgd_iter", 1),
num_gpus=config["num_gpus"],
_fake_gpus=config["_fake_gpus"],
)
)
# Add on the standard episode reward, etc. metrics reporting. This
# returns a LocalIterator[metrics_dict] representing metrics for each
# train step.
return StandardMetricsReporting(train_op, workers, config)
@PublicAPI
def compute_single_action(
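
The replacement error above points users at `_disable_execution_plan_api=True` plus a `training_iteration()` override. A minimal, hedged example of what such an override can look like for a custom Trainer subclass follows; the class name and config are placeholders, and a real trainer would also supply a default policy class and its own metrics handling.

from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step


class MyCustomTrainer(Trainer):
    """Illustrative sketch only; names here are placeholders."""

    @classmethod
    def get_default_config(cls):
        # Route execution through `training_iteration()` instead of the
        # removed execution-plan path.
        return with_common_config({"_disable_execution_plan_api": True})

    def training_iteration(self):
        # Sample in parallel from all rollout workers ...
        train_batch = synchronous_parallel_sample(
            worker_set=self.workers, max_env_steps=self.config["train_batch_size"]
        )
        # ... then run one training update on the collected batch and return
        # the resulting metrics dict.
        return train_one_step(self, train_batch)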

View file

@@ -119,11 +119,6 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 1,
# Prevent iterations from going lower than this time span
"min_time_s_per_reporting": 0,
# Experimental flag.
# If True, the execution plan API will not be used. Instead,
# a Trainer's `training_iteration` method will be called as-is each
# training iteration.
"_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on

View file

@@ -1,70 +0,0 @@
import unittest
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import framework_iterator
class TestDistributedExecution(unittest.TestCase):
"""General tests for the distributed execution API."""
@classmethod
def setUpClass(cls):
ray.init(num_cpus=4)
@classmethod
def tearDownClass(cls):
ray.shutdown()
def test_exec_plan_stats(ray_start_regular):
for fw in framework_iterator(frameworks=("torch", "tf")):
trainer = A2CTrainer(
env="CartPole-v0",
config={
"min_time_s_per_reporting": 0,
"framework": fw,
"_disable_execution_plan_api": False,
},
)
result = trainer.train()
assert isinstance(result, dict)
assert "info" in result
assert LEARNER_INFO in result["info"]
assert STEPS_SAMPLED_COUNTER in result["info"]
assert STEPS_TRAINED_COUNTER in result["info"]
assert "timers" in result
assert "learn_time_ms" in result["timers"]
assert "learn_throughput" in result["timers"]
assert "sample_time_ms" in result["timers"]
assert "sample_throughput" in result["timers"]
assert "update_time_ms" in result["timers"]
def test_exec_plan_save_restore(ray_start_regular):
for fw in framework_iterator(frameworks=("torch", "tf")):
trainer = A2CTrainer(
env="CartPole-v0",
config={
"min_time_s_per_reporting": 0,
"framework": fw,
"_disable_execution_plan_api": False,
},
)
res1 = trainer.train()
checkpoint = trainer.save()
for _ in range(2):
res2 = trainer.train()
assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2)
trainer.restore(checkpoint)
# Should restore the timesteps counter to the same as res2.
res3 = trainer.train()
assert res3["timesteps_total"] < res2["timesteps_total"], (res2, res3)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))