mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

[RLlib] Remove execution_plan API code no longer needed. (#24501)

parent f891a2b6f1
commit f54557073e

16 changed files with 8 additions and 807 deletions
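Context (not part of the commit): the removed `execution_plan` methods are superseded by the `training_iteration()` API that the new `Trainer.execution_plan` error message below points to (set `_disable_execution_plan_api=True` and override `Trainer.training_iteration()`). The snippet here is only a minimal sketch of that pattern, assuming the `synchronous_parallel_sample` and `train_one_step` helpers that remain imported in the touched files; exact signatures vary across Ray versions, and `MyTrainer` is a hypothetical subclass, not code from this commit.

from ray.rllib.agents.a3c.a3c import A3CTrainer
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step


class MyTrainer(A3CTrainer):
    # Hypothetical subclass; the real RLlib trainers in this codebase do more
    # here (worker weight syncing, metrics counters, target updates, etc.).
    def training_iteration(self):
        # Sample a train batch synchronously from all rollout workers
        # (keyword `worker_set` as used by RLlib at the time; verify against
        # the installed version).
        train_batch = synchronous_parallel_sample(worker_set=self.workers)
        # One training pass on the local worker; returns the results dict.
        return train_one_step(self, train_batch)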
@@ -1706,13 +1706,6 @@ py_test(
    srcs = ["tests/test_env_with_subprocess.py"]
)

py_test(
    name = "tests/test_exec_api",
    tags = ["team:ml", "tests_dir", "tests_dir_E"],
    size = "medium",
    srcs = ["tests/test_exec_api.py"]
)

py_test(
    name = "tests/test_execution",
    tags = ["team:ml", "tests_dir", "tests_dir_E"],

@@ -2855,7 +2848,7 @@ py_test(
py_test(
    name = "examples/rollout_worker_custom_workflow",
    tags = ["team:ml", "examples", "examples_R"],
    size = "small",
    size = "medium",
    srcs = ["examples/rollout_worker_custom_workflow.py"],
    args = ["--num-cpus=4"]
)
@@ -4,24 +4,13 @@ from typing import Optional

from ray.rllib.agents.a3c.a3c import A3CConfig, A3CTrainer
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    ComputeGradients,
    AverageGradients,
    ApplyGradients,
    MultiGPUTrainOneStep,
    TrainOneStep,
)
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated

@@ -37,7 +26,6 @@ from ray.rllib.utils.typing import (
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

@@ -228,60 +216,6 @@ class A2CTrainer(A3CTrainer):

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "A2C execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        if config["microbatch_size"]:
            num_microbatches = math.ceil(
                config["train_batch_size"] / config["microbatch_size"]
            )
            # In microbatch mode, we want to compute gradients on experience
            # microbatches, average a number of these microbatches, and then
            # apply the averaged gradient in one SGD step. This conserves GPU
            # memory, allowing for extremely large experience batches to be
            # used.
            train_op = (
                rollouts.combine(
                    ConcatBatches(
                        min_batch_size=config["microbatch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    )
                )
                .for_each(ComputeGradients(workers))  # (grads, info)
                .batch(num_microbatches)  # List[(grads, info)]
                .for_each(AverageGradients())  # (avg_grads, info)
                .for_each(ApplyGradients(workers))
            )
        else:
            # In normal mode, we execute one SGD step per each train batch.
            if config["simple_optimizer"]:
                train_step_op = TrainOneStep(workers)
            else:
                train_step_op = MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["train_batch_size"],
                    num_sgd_iter=1,
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )

            train_op = rollouts.combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            ).for_each(train_step_op)

        return StandardMetricsReporting(train_op, workers, config)


# Deprecated: Use ray.rllib.agents.a3c.A2CConfig instead!
class _deprecated_default_config(dict):
@@ -6,11 +6,7 @@ from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import AsyncGradients
from ray.rllib.execution.train_ops import ApplyGradients
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated

@@ -25,7 +21,6 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

@@ -248,25 +243,6 @@ class A3CTrainer(Trainer):

        return learner_info_builder.finalize()

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "A3C execution_plan does NOT take any additional parameters"

        # For A3C, compute policy gradients remotely on the rollout workers.
        grads = AsyncGradients(workers)

        # Apply the gradients as they arrive. We set update_all to False so
        # that only the worker sending the gradient is updated with new
        # weights.
        train_op = grads.for_each(ApplyGradients(workers, update_all=False))

        return StandardMetricsReporting(train_op, workers, config)


# Deprecated: Use ray.rllib.agents.a3c.A3CConfig instead!
class _deprecated_default_config(dict):
@@ -5,14 +5,9 @@ from typing import Type
from ray.rllib.agents.cql.cql_tf_policy import CQLTFPolicy
from ray.rllib.agents.cql.cql_torch_policy import CQLTorchPolicy
from ray.rllib.agents.sac.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    MultiGPUTrainOneStep,
    train_one_step,
    TrainOneStep,
    UpdateTargetNetwork,
)
from ray.rllib.offline.shuffled_input import ShuffledInput
from ray.rllib.policy.policy import Policy

@@ -29,7 +24,6 @@ from ray.rllib.utils.metrics import (
    TARGET_NET_UPDATE_TIMER,
    SYNCH_WORKER_WEIGHTS_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict

@@ -223,63 +217,3 @@ class CQLTrainer(SACTrainer):

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(SACTrainer)
    def execution_plan(workers, config, **kwargs):
        assert (
            "local_replay_buffer" in kwargs
        ), "CQL execution plan requires a local replay buffer."

        local_replay_buffer = kwargs["local_replay_buffer"]

        def update_prio(item):
            samples, info_dict = item
            if config.get("prioritized_replay"):
                prio_dict = {}
                for policy_id, info in info_dict.items():
                    # TODO(sven): This is currently structured differently for
                    # torch/tf. Clean up these results/info dicts across
                    # policies (note: fixing this in torch_policy.py will
                    # break e.g. DDPPO!).
                    td_error = info.get(
                        "td_error", info[LEARNER_STATS_KEY].get("td_error")
                    )
                    samples.policy_batches[policy_id].set_get_interceptor(None)
                    prio_dict[policy_id] = (
                        samples.policy_batches[policy_id].get("batch_indexes"),
                        td_error,
                    )
                local_replay_buffer.update_priorities(prio_dict)
            return info_dict

        # (2) Read and train on experiences from the replay buffer. Every batch
        # returned from the LocalReplay() iterator is passed to TrainOneStep to
        # take a SGD step, and then we decide whether to update the target
        # network.
        post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            )

        train_op = (
            Replay(local_buffer=local_replay_buffer)
            .for_each(lambda x: post_fn(x, workers, config))
            .for_each(train_step_op)
            .for_each(update_prio)
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        return StandardMetricsReporting(
            train_op, workers, config, by_steps_trained=True
        )
@@ -172,9 +172,6 @@ DEFAULT_CONFIG = with_common_config({
    # timestep count has not been reached, will perform n more `step_attempt()` calls
    # until the minimum timesteps have been executed. Set to 0 for no minimum timesteps.
    "min_sample_timesteps_per_reporting": 1000,

    # Experimental flag.
    "_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# fmt: on

@@ -197,6 +194,7 @@ class DDPGTrainer(SimpleQTrainer):

    @override(SimpleQTrainer)
    def validate_config(self, config: TrainerConfigDict) -> None:

        # Call super's validation method.
        super().validate_config(config)
@@ -61,7 +61,6 @@ TD3_DEFAULT_CONFIG = DDPGTrainer.merge_trainer_configs(
            "type": "MultiAgentReplayBuffer",
            "capacity": 1000000,
        },
        "_disable_execution_plan_api": True,
    },
)
@@ -19,25 +19,16 @@ from ray.rllib.agents.dqn.simple_q import (
    DEFAULT_CONFIG as SIMPLEQ_DEFAULT_CONFIG,
)
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    UpdateTargetNetwork,
    MultiGPUTrainOneStep,
    train_one_step,
    multi_gpu_train_one_step,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers.utils import update_priorities_in_replay_buffer
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
    ResultDict,
    TrainerConfigDict,

@@ -46,11 +37,6 @@ from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED,
)
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers import MultiAgentPrioritizedReplayBuffer
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as LegacyMultiAgentReplayBuffer,
)
from ray.rllib.utils.deprecation import (
    Deprecated,
    DEPRECATED_VALUE,

@@ -280,103 +266,6 @@ class DQNTrainer(SimpleQTrainer):
        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            "local_replay_buffer" in kwargs
        ), "DQN's execution plan requires a local replay buffer."

        # Assign to Trainer, so we can store the MultiAgentReplayBuffer's
        # data when we save checkpoints.
        local_replay_buffer = kwargs["local_replay_buffer"]

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in our local replay buffer.
        # Calling next() on store_op drives this.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer)
        )

        def update_prio(item):
            samples, info_dict = item
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                # torch/tf. Clean up these results/info dicts across
                # policies (note: fixing this in torch_policy.py will
                # break e.g. DDPPO!).
                td_error = info.get("td_error", info[LEARNER_STATS_KEY].get("td_error"))
                samples.policy_batches[policy_id].set_get_interceptor(None)
                batch_indices = samples.policy_batches[policy_id].get("batch_indexes")
                # In case the buffer stores sequences, TD-error could
                # already be calculated per sequence chunk.
                if len(batch_indices) != len(td_error):
                    T = local_replay_buffer.replay_sequence_length
                    assert (
                        len(batch_indices) > len(td_error)
                        and len(batch_indices) % T == 0
                    )
                    batch_indices = batch_indices.reshape([-1, T])[:, 0]
                    assert len(batch_indices) == len(td_error)
                prio_dict[policy_id] = (batch_indices, td_error)
            local_replay_buffer.update_priorities(prio_dict)
            return info_dict

        # (2) Read and train on experiences from the replay buffer. Every batch
        # returned from the LocalReplay() iterator is passed to TrainOneStep to
        # take a SGD step, and then we decide whether to update the target
        # network.
        post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            )

        if (
            type(local_replay_buffer) is LegacyMultiAgentReplayBuffer
            and config["replay_buffer_config"].get("prioritized_replay_alpha", 0.0)
            > 0.0
        ) or isinstance(local_replay_buffer, MultiAgentPrioritizedReplayBuffer):
            update_prio_fn = update_prio
        else:

            def update_prio_fn(x):
                return x

        replay_op = (
            Replay(local_buffer=local_replay_buffer)
            .for_each(lambda x: post_fn(x, workers, config))
            .for_each(train_step_op)
            .for_each(update_prio_fn)
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        # Alternate deterministically between (1) and (2).
        # Only return the output of (2) since training metrics are not
        # available until (2) runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_rr_weights(config),
        )

        return StandardMetricsReporting(train_op, workers, config)


@Deprecated(
    new="Sub-class directly from `DQNTrainer` and override its methods", error=False
@@ -16,23 +16,14 @@ from ray.rllib.agents.dqn.simple_q_tf_policy import SimpleQTFPolicy
from ray.rllib.agents.dqn.simple_q_torch_policy import SimpleQTorchPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.metrics import SYNCH_WORKER_WEIGHTS_TIMER
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    MultiGPUTrainOneStep,
    train_one_step,
    multi_gpu_train_one_step,
)
from ray.rllib.execution.train_ops import (
    UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override

@@ -267,46 +258,3 @@ class SimpleQTrainer(Trainer):

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(workers, config, **kwargs):
        assert (
            "local_replay_buffer" in kwargs
        ), "GenericOffPolicy execution plan requires a local replay buffer."

        local_replay_buffer = kwargs["local_replay_buffer"]

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # (1) Generate rollouts and store them in our local replay buffer.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer)
        )

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            )

        # (2) Read and train on experiences from the replay buffer.
        replay_op = (
            Replay(local_buffer=local_replay_buffer)
            .for_each(train_step_op)
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        # Alternate deterministically between (1) and (2).
        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1]
        )

        return StandardMetricsReporting(train_op, workers, config)
@@ -2,19 +2,12 @@ from typing import Type

from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import MultiAgentReplayBuffer
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import (
    ConcatBatches,
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    TrainOneStep,
    train_one_step,
)
from ray.rllib.policy.policy import Policy

@@ -31,7 +24,6 @@ from ray.rllib.utils.typing import (
    ResultDict,
    TrainerConfigDict,
)
from ray.util.iter import LocalIterator

# fmt: off
# __sphinx_doc_begin__

@@ -171,39 +163,3 @@ class MARWILTrainer(Trainer):
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "Marwill execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = MultiAgentReplayBuffer(
            learning_starts=config["learning_starts"],
            capacity=config["replay_buffer_size"],
            replay_batch_size=config["train_batch_size"],
            replay_sequence_length=1,
        )

        store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        replay_op = (
            Replay(local_buffer=replay_buffer)
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            )
            .for_each(TrainOneStep(workers))
        )

        train_op = Concurrently(
            [store_op, replay_op], mode="round_robin", output_indexes=[1]
        )

        return StandardMetricsReporting(train_op, workers, config)
@@ -25,18 +25,10 @@ from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG, PPOTr
from ray.rllib.agents.trainer import Trainer
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import (
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
    STEPS_TRAINED_THIS_ITER_COUNTER,
    _get_shared_metrics,
    _get_global_vars,
)
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.parallel_requests import asynchronous_parallel_requests
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
    LEARN_ON_BATCH_TIMER,

@@ -46,7 +38,7 @@ from ray.rllib.utils.metrics import (
    NUM_ENV_STEPS_TRAINED,
    SAMPLE_TIMER,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LearnerInfoBuilder
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.typing import (
    EnvType,

@@ -55,7 +47,6 @@ from ray.rllib.utils.typing import (
    TrainerConfigDict,
)
from ray.tune.logger import Logger
from ray.util.iter import LocalIterator

logger = logging.getLogger(__name__)

@@ -321,141 +312,3 @@ class DDPPOTrainer(PPOTrainer):
            "sample_time": sample_time,
            "learn_on_batch_time": learn_on_batch_time,
        }

    @staticmethod
    @override(PPOTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        """Execution plan of the DD-PPO algorithm. Defines the distributed dataflow.

        Args:
            workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
                of the Trainer.
            config (TrainerConfigDict): The trainer's configuration dict.

        Returns:
            LocalIterator[dict]: The Policy class to use with PGTrainer.
                If None, use `get_default_policy_class()` provided by Trainer.
        """
        assert (
            len(kwargs) == 0
        ), "DDPPO execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="raw")

        # Setup the distributed processes.
        ip = ray.get(workers.remote_workers()[0].get_node_ip.remote())
        port = ray.get(workers.remote_workers()[0].find_free_port.remote())
        address = "tcp://{ip}:{port}".format(ip=ip, port=port)
        logger.info("Creating torch process group with leader {}".format(address))

        # Get setup tasks in order to throw errors on failure.
        ray.get(
            [
                worker.setup_torch_data_parallel.remote(
                    url=address,
                    world_rank=i,
                    world_size=len(workers.remote_workers()),
                    backend=config["torch_distributed_backend"],
                )
                for i, worker in enumerate(workers.remote_workers())
            ]
        )
        logger.info("Torch process group init completed")

        # This function is applied remotely on each rollout worker.
        def train_torch_distributed_allreduce(batch):
            expected_batch_size = (
                config["rollout_fragment_length"] * config["num_envs_per_worker"]
            )
            this_worker = get_global_worker()
            assert batch.count == expected_batch_size, (
                "Batch size possibly out of sync between workers, expected:",
                expected_batch_size,
                "got:",
                batch.count,
            )
            logger.info(
                "Executing distributed minibatch SGD "
                "with epoch size {}, minibatch size {}".format(
                    batch.count, config["sgd_minibatch_size"]
                )
            )
            info = do_minibatch_sgd(
                batch,
                this_worker.policy_map,
                this_worker,
                config["num_sgd_iter"],
                config["sgd_minibatch_size"],
                ["advantages"],
            )
            return info, batch.count

        # Broadcast the local set of global vars.
        def update_worker_global_vars(item):
            global_vars = _get_global_vars()
            for w in workers.remote_workers():
                w.set_global_vars.remote(global_vars)
            return item

        # Have to manually record stats since we are using "raw" rollouts mode.
        class RecordStats:
            def _on_fetch_start(self):
                self.fetch_start_time = time.perf_counter()

            def __call__(self, items):
                assert len(items) == config["num_workers"]
                for item in items:
                    info, count = item
                    metrics = _get_shared_metrics()
                    metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = count
                    metrics.counters[STEPS_SAMPLED_COUNTER] += count
                    metrics.counters[STEPS_TRAINED_COUNTER] += count
                    metrics.info[LEARNER_INFO] = info
                    # Since SGD happens remotely, the time delay between fetch and
                    # completion is approximately the SGD step time.
                    metrics.timers[LEARN_ON_BATCH_TIMER].push(
                        time.perf_counter() - self.fetch_start_time
                    )

        train_op = (
            rollouts.for_each(train_torch_distributed_allreduce)  # allreduce
            .batch_across_shards()  # List[(grad_info, count)]
            .for_each(RecordStats())
        )

        train_op = train_op.for_each(update_worker_global_vars)

        # Sync down the weights. As with the sync up, this is not really
        # needed unless the user is reading the local weights.
        if config["keep_local_weights_in_sync"]:

            def download_weights(item):
                workers.local_worker().set_weights(
                    ray.get(workers.remote_workers()[0].get_weights.remote())
                )
                return item

            train_op = train_op.for_each(download_weights)

        # In debug mode, check the allreduce successfully synced the weights.
        if logger.isEnabledFor(logging.DEBUG):

            def check_sync(item):
                weights = ray.get(
                    [w.get_weights.remote() for w in workers.remote_workers()]
                )
                sums = []
                for w in weights:
                    acc = 0
                    for p in w.values():
                        for k, v in p.items():
                            acc += v.sum()
                    sums.append(float(acc))
                logger.debug("The worker weight sums are {}".format(sums))
                assert len(set(sums)) == 1, sums

            train_op = train_op.for_each(check_sync)

        return StandardMetricsReporting(train_op, workers, config)
@@ -16,29 +16,20 @@ from ray.util.debug import log_once
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.trainer import Trainer
from ray.rllib.agents.trainer_config import TrainerConfig
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.rollout_ops import (
    ParallelRollouts,
    ConcatBatches,
    StandardizeFields,
    standardize_fields,
    SelectExperiences,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    MultiGPUTrainOneStep,
    train_one_step,
    multi_gpu_train_one_step,
)
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict, ResultDict
from ray.util.iter import LocalIterator
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.utils.metrics import (
    NUM_AGENT_STEPS_SAMPLED,

@@ -456,59 +447,6 @@ class PPOTrainer(Trainer):

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "PPO execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # Collect batches for the trainable policies.
        rollouts = rollouts.for_each(
            SelectExperiences(local_worker=workers.local_worker())
        )
        # Concatenate the SampleBatches into one.
        rollouts = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )
        )
        # Standardize advantages.
        rollouts = rollouts.for_each(StandardizeFields(["advantages"]))

        # Perform one training step on the combined + standardized batch.
        if config["simple_optimizer"]:
            train_op = rollouts.for_each(
                TrainOneStep(
                    workers,
                    num_sgd_iter=config["num_sgd_iter"],
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                )
            )
        else:
            train_op = rollouts.for_each(
                MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config["sgd_minibatch_size"],
                    num_sgd_iter=config["num_sgd_iter"],
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )
            )

        # Update KL after each round of training.
        train_op = train_op.for_each(lambda t: t[1]).for_each(UpdateKL(workers))

        # Warn about bad reward scales and return training metrics.
        return StandardMetricsReporting(train_op, workers, config).for_each(
            lambda result: warn_about_bad_reward_scales(config, result)
        )


# Deprecated: Use ray.rllib.agents.ppo.PPOConfig instead!
class _deprecated_default_config(dict):
@@ -3,24 +3,12 @@ from typing import Type
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.simple_q import SimpleQTrainer
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import (
    SimpleReplayBuffer,
    Replay,
    StoreToReplayBuffer,
)
from ray.rllib.execution.rollout_ops import (
    ConcatBatches,
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    multi_gpu_train_one_step,
    train_one_step,
    TrainOneStep,
    UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override

@@ -34,7 +22,6 @@ from ray.rllib.utils.metrics import (
)
from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer
from ray.rllib.utils.typing import ResultDict, TrainerConfigDict
from ray.util.iter import LocalIterator

# fmt: off
# __sphinx_doc_begin__

@@ -232,37 +219,3 @@ class QMixTrainer(SimpleQTrainer):

        # Return all collected metrics for the iteration.
        return train_results

    @staticmethod
    @override(SimpleQTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "QMIX execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = SimpleReplayBuffer(config["buffer_size"])

        store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        train_op = (
            Replay(local_buffer=replay_buffer)
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )
            )
            .for_each(TrainOneStep(workers))
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        merged_op = Concurrently(
            [store_op, train_op], mode="round_robin", output_indexes=[1]
        )

        return StandardMetricsReporting(merged_op, workers, config)
@@ -19,21 +19,10 @@ from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.slateq.slateq_tf_policy import SlateQTFPolicy
from ray.rllib.agents.slateq.slateq_torch_policy import SlateQTorchPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import (
    MultiGPUTrainOneStep,
    TrainOneStep,
    UpdateTargetNetwork,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config

logger = logging.getLogger(__name__)

@@ -182,55 +171,3 @@ class SlateQTrainer(DQNTrainer):
            return SlateQTorchPolicy
        else:
            return SlateQTFPolicy

    @staticmethod
    @override(DQNTrainer)
    def execution_plan(
        workers: WorkerSet, config: TrainerConfigDict, **kwargs
    ) -> LocalIterator[dict]:
        assert (
            "local_replay_buffer" in kwargs
        ), "SlateQ execution plan requires a local replay buffer."

        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in our local replay buffer.
        # Calling next() on store_op drives this.
        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=kwargs["local_replay_buffer"])
        )

        if config["simple_optimizer"]:
            train_step_op = TrainOneStep(workers)
        else:
            train_step_op = MultiGPUTrainOneStep(
                workers=workers,
                sgd_minibatch_size=config["train_batch_size"],
                num_sgd_iter=1,
                num_gpus=config["num_gpus"],
                _fake_gpus=config["_fake_gpus"],
            )

        # (2) Read and train on experiences from the replay buffer. Every batch
        # returned from the LocalReplay() iterator is passed to TrainOneStep to
        # take a SGD step.
        replay_op = (
            Replay(local_buffer=kwargs["local_replay_buffer"])
            .for_each(train_step_op)
            .for_each(
                UpdateTargetNetwork(workers, config["target_network_update_freq"])
            )
        )

        # Alternate deterministically between (1) and (2). Only return the
        # output of (2) since training metrics are not available until (2)
        # runs.
        train_op = Concurrently(
            [store_op, replay_op],
            mode="round_robin",
            output_indexes=[1],
            round_robin_weights=calculate_round_robin_weights(config),
        )

        return StandardMetricsReporting(train_op, workers, config)
@@ -40,20 +40,15 @@ from ray.rllib.evaluation.metrics import (
)
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import (
    MultiAgentReplayBuffer as Legacy_MultiAgentReplayBuffer,
)
from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import (
    ConcatBatches,
    ParallelRollouts,
    synchronous_parallel_sample,
)
from ray.rllib.execution.train_ops import (
    TrainOneStep,
    MultiGPUTrainOneStep,
    train_one_step,
    multi_gpu_train_one_step,
)

@@ -909,41 +904,14 @@ class Trainer(Trainable):

        return train_results

    @DeveloperAPI
    @staticmethod
    def execution_plan(workers, config, **kwargs):

        # Collects experiences in parallel from multiple RolloutWorker actors.
        rollouts = ParallelRollouts(workers, mode="bulk_sync")

        # Combine experiences batches until we hit `train_batch_size` in size.
        # Then, train the policy on those experiences and update the workers.
        train_op = rollouts.combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
        raise NotImplementedError(
            "It is not longer recommended to use Trainer's `execution_plan` method/API."
            " Set `_disable_execution_plan_api=True` in your config and override the "
            "`Trainer.training_iteration()` method with your algo's custom "
            "execution logic."
        )
            )

        if config.get("simple_optimizer") is True:
            train_op = train_op.for_each(TrainOneStep(workers))
        else:
            train_op = train_op.for_each(
                MultiGPUTrainOneStep(
                    workers=workers,
                    sgd_minibatch_size=config.get(
                        "sgd_minibatch_size", config["train_batch_size"]
                    ),
                    num_sgd_iter=config.get("num_sgd_iter", 1),
                    num_gpus=config["num_gpus"],
                    _fake_gpus=config["_fake_gpus"],
                )
            )

        # Add on the standard episode reward, etc. metrics reporting. This
        # returns a LocalIterator[metrics_dict] representing metrics for each
        # train step.
        return StandardMetricsReporting(train_op, workers, config)

    @PublicAPI
    def compute_single_action(
@@ -119,11 +119,6 @@ DEFAULT_CONFIG = with_common_config({
    "num_workers": 1,
    # Prevent iterations from going lower than this time span
    "min_time_s_per_reporting": 0,
    # Experimental flag.
    # If True, the execution plan API will not be used. Instead,
    # a Trainer's `training_iteration` method will be called as-is each
    # training iteration.
    "_disable_execution_plan_api": False,
})
# __sphinx_doc_end__
# fmt: on
@@ -1,70 +0,0 @@
import unittest

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, STEPS_TRAINED_COUNTER
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO
from ray.rllib.utils.test_utils import framework_iterator


class TestDistributedExecution(unittest.TestCase):
    """General tests for the distributed execution API."""

    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_exec_plan_stats(ray_start_regular):
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_time_s_per_reporting": 0,
                    "framework": fw,
                    "_disable_execution_plan_api": False,
                },
            )
            result = trainer.train()
            assert isinstance(result, dict)
            assert "info" in result
            assert LEARNER_INFO in result["info"]
            assert STEPS_SAMPLED_COUNTER in result["info"]
            assert STEPS_TRAINED_COUNTER in result["info"]
            assert "timers" in result
            assert "learn_time_ms" in result["timers"]
            assert "learn_throughput" in result["timers"]
            assert "sample_time_ms" in result["timers"]
            assert "sample_throughput" in result["timers"]
            assert "update_time_ms" in result["timers"]

    def test_exec_plan_save_restore(ray_start_regular):
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_time_s_per_reporting": 0,
                    "framework": fw,
                    "_disable_execution_plan_api": False,
                },
            )
            res1 = trainer.train()
            checkpoint = trainer.save()
            for _ in range(2):
                res2 = trainer.train()
            assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2)
            trainer.restore(checkpoint)

            # Should restore the timesteps counter to the same as res2.
            res3 = trainer.train()
            assert res3["timesteps_total"] < res2["timesteps_total"], (res2, res3)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))