[RLlib + Tune] Add placement group support to RLlib. (#14289)
This commit is contained in:
parent 8000258333
commit 6cd0cd3bd9
10 changed files with 302 additions and 55 deletions
@@ -331,6 +331,8 @@ Recurrent Replay Distributed DQN (R2D2)
---------------------------------------
|pytorch| |tensorflow|
`[paper] <https://openreview.net/pdf?id=r1lyTjAqYX>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/r2d2.py>`__
R2D2 allows running the DQN algorithm with an RNN model (e.g. an LSTM). Sequences of a fixed length
are stored in the replay buffer and taken from there for RNN-based learning updates.
R2D2 can be scaled by increasing the number of workers. All of the DQN improvements evaluated in `Rainbow <https://arxiv.org/abs/1710.02298>`__ are available, though not all are enabled by default.

Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/cartpole-r2d2.yaml>`__
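For illustration only, a minimal way to launch the R2D2 trainer referenced above through Tune might look like the following sketch (the config values are assumptions, not the tuned example's exact settings):

.. code-block:: python

    import ray
    from ray import tune

    ray.init()

    # "R2D2" is the registered RLlib trainable; an LSTM model is used
    # because R2D2 learns from stored sequences.
    tune.run(
        "R2D2",
        stop={"episode_reward_mean": 150.0},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
            "model": {"use_lstm": True},
        },
    )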
@@ -314,7 +314,7 @@ Now let's look at each PPO policy definition:

        return stats_fetches

``extra_actions_fetches_fn``: This function defines extra outputs that will be recorded when generating actions with the policy. For example, this enables saving the raw policy logits in the experience batch, which e.g. means it can be referenced in the PPO loss function via ``batch[BEHAVIOUR_LOGITS]``. Other values such as the current value prediction can also be emitted for debugging or optimization purposes:
``extra_action_out_fn``: This function defines extra outputs that will be recorded when generating actions with the policy. For example, this enables saving the raw policy logits in the experience batch, which e.g. means it can be referenced in the PPO loss function via ``batch[BEHAVIOUR_LOGITS]``. Other values such as the current value prediction can also be emitted for debugging or optimization purposes:

.. code-block:: python
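The documented snippet is cut off by the hunk above; purely as an illustrative sketch (not the snippet from the docs), such a function could record the value-function prediction for each sampled action:

.. code-block:: python

    from ray.rllib.policy.sample_batch import SampleBatch


    def vf_preds_fetches(policy):
        # Record the model's value estimate alongside each action so it is
        # available later as batch[SampleBatch.VF_PREDS]. Assumes the model
        # implements a value head via `value_function()`.
        return {
            SampleBatch.VF_PREDS: policy.model.value_function(),
        }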
@@ -1,28 +1,29 @@
import sys
from contextlib import redirect_stdout, redirect_stderr
from datetime import datetime

import copy
from datetime import datetime
import logging
import os
import ray.cloudpickle as pickle
import platform

from ray.tune.utils.trainable import TrainableUtil
from ray.tune.utils.util import Tee
import shutil
import sys
import tempfile
import time
from typing import Any, Dict, Union
import uuid

import ray
from ray.util.debug import log_once
import ray.cloudpickle as pickle
from ray.tune.resources import Resources
from ray.tune.result import (
    DEFAULT_RESULTS_DIR, SHOULD_CHECKPOINT, TIME_THIS_ITER_S,
    TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER,
    EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE, TRIAL_INFO,
    STDOUT_FILE, STDERR_FILE)
from ray.tune.utils import UtilMonitor
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.tune.utils.trainable import TrainableUtil
from ray.tune.utils.util import Tee
from ray.util.debug import log_once

logger = logging.getLogger(__name__)
@@ -106,7 +107,8 @@ class Trainable:
        self._monitor = UtilMonitor(start=log_sys_usage)

    @classmethod
    def default_resource_request(cls, config):
    def default_resource_request(cls, config: Dict[str, Any]) -> \
            Union[Resources, PlacementGroupFactory]:
        """Provides a static resource requirement for the given configuration.

        This can be overridden by sub-classes to set the correct trial resource

@@ -122,8 +124,12 @@ class Trainable:
                extra_cpu=config["workers"],
                extra_gpu=int(config["use_gpu"]) * config["workers"])

        Args:
            config (Dict[str, Any]): The Trainable's config dict.

        Returns:
            Resources: A Resources object consumed by Tune for queueing.
            Union[Resources, PlacementGroupFactory]: A Resources object or
                PlacementGroupFactory consumed by Tune for queueing.
        """
        return None
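As a hedged sketch of the override described in the docstring above (the class name and bundle sizes are illustrative, not from the source): a custom Trainable can return a PlacementGroupFactory that asks Tune for one head bundle plus two worker bundles.

.. code-block:: python

    from ray import tune
    from ray.tune.utils.placement_groups import PlacementGroupFactory


    class MyTrainable(tune.Trainable):
        @classmethod
        def default_resource_request(cls, config):
            # One bundle for the trainable itself, two for (hypothetical)
            # remote workers it would launch.
            return PlacementGroupFactory(
                [{"CPU": 1, "GPU": 0}, {"CPU": 1}, {"CPU": 1}],
                strategy="PACK")

        def step(self):
            return {"done": True}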
@@ -224,14 +224,26 @@ class Trial:
        if trainable_cls:
            default_resources = trainable_cls.default_resource_request(
                self.config)

            # If Trainable returns resources, do not allow manual override via
            # `resources_per_trial` by the user.
            if default_resources:
                if resources:
                if resources or placement_group_factory:
                    raise ValueError(
                        "Resources for {} have been automatically set to {} "
                        "by its `default_resource_request()` method. Please "
                        "clear the `resources_per_trial` option.".format(
                            trainable_cls, default_resources))
                resources = default_resources

                # New way: Trainable returns a PlacementGroupFactory object.
                if isinstance(default_resources, PlacementGroupFactory):
                    placement_group_factory = default_resources
                    resources = None
                # Set placement group factory to None for backwards
                # compatibility.
                else:
                    placement_group_factory = None
                    resources = default_resources
        self.location = Location()

        self.resources = resources or Resources(cpu=1, gpu=0)
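As a hedged sketch (trainable and bundle sizes are illustrative): this is how a user would hand Tune a PlacementGroupFactory via `resources_per_trial` instead of the legacy `Resources`/dict form, which is exactly what ends up in `placement_group_factory` in the Trial code above.

.. code-block:: python

    from ray import tune
    from ray.tune.utils.placement_groups import PlacementGroupFactory


    def my_trainable(config):
        # Hypothetical function trainable; reports a single dummy metric.
        tune.report(score=1)


    tune.run(
        my_trainable,
        resources_per_trial=PlacementGroupFactory(
            [{"CPU": 1}, {"CPU": 1}],  # head bundle + one child bundle
            strategy="PACK"),
    )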
@@ -1429,6 +1429,13 @@ py_test(
    srcs = ["tests/test_pettingzoo_env.py"]
)

py_test(
    name = "tests/test_placement_groups",
    tags = ["tests_dir", "tests_dir_P"],
    size = "medium",
    srcs = ["tests/test_placement_groups.py"]
)

py_test(
    name = "tests/test_reproducibility",
    tags = ["tests_dir", "tests_dir_R"],
@@ -15,7 +15,7 @@ from ray.rllib.execution.concurrency_ops import Concurrently, Enqueue, Dequeue
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.utils.annotations import override
from ray.tune.trainable import Trainable
from ray.tune.resources import Resources
from ray.tune.utils.placement_groups import PlacementGroupFactory

logger = logging.getLogger(__name__)
@@ -68,8 +68,11 @@ DEFAULT_CONFIG = with_common_config({
    "max_sample_requests_in_flight_per_worker": 2,
    # max number of workers to broadcast one set of weights to
    "broadcast_interval": 1,
    # use intermediate actors for multi-level aggregation. This can make sense
    # if ingesting >2GB/s of samples, or if the data requires decompression.
    # Use n (`num_aggregation_workers`) extra Actors for multi-level
    # aggregation of the data produced by the m RolloutWorkers
    # (`num_workers`). Note that n should be much smaller than m.
    # This can make sense if ingesting >2GB/s of samples, or if
    # the data requires decompression.
    "num_aggregation_workers": 0,

    # Learning params.
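Illustration only (the worker counts are assumptions; the point is n << m): enabling multi-level aggregation in an IMPALA config amounts to setting the two keys described above.

.. code-block:: python

    from ray.rllib.agents.impala import DEFAULT_CONFIG

    config = DEFAULT_CONFIG.copy()
    config["num_workers"] = 32             # m RolloutWorkers
    config["num_aggregation_workers"] = 4  # n aggregation workers, n << m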
@@ -101,17 +104,40 @@ class OverrideDefaultResourceRequest:
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        Trainer._validate_config(cf)
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            memory=cf["memory"],
            object_store_memory=cf["object_store_memory"],
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"] +
            cf["num_aggregation_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"],
            extra_memory=cf["memory_per_worker"] * cf["num_workers"],
            extra_object_store_memory=cf["object_store_memory_per_worker"] *
            cf["num_workers"])

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Driver + Aggregation Workers:
                # Force to be on same node to maximize data bandwidth
                # between aggregation workers and the learner (driver).
                # Aggregation workers tree-aggregate experiences collected
                # from RolloutWorkers (n rollout workers map to m
                # aggregation workers, where m < n) and always use 1 CPU
                # each.
                "CPU": cf["num_cpus_for_driver"] +
                cf["num_aggregation_workers"],
                "GPU": cf["num_gpus"]
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"],
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation workers (+1 b/c of the additional local
                    # worker)
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"] + 1)
            ] if cf["evaluation_interval"] else []),
            strategy=config.get("placement_strategy", "PACK"))


def make_learner_thread(local_worker, config):
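Worked illustration with assumed config values: for num_cpus_for_driver=1, num_aggregation_workers=1, num_gpus=0, num_workers=2, num_cpus_per_worker=1, num_gpus_per_worker=0, and evaluation disabled, the code above builds a factory equivalent to this sketch:

.. code-block:: python

    from ray.tune.utils.placement_groups import PlacementGroupFactory

    pgf = PlacementGroupFactory(
        bundles=[
            {"CPU": 2, "GPU": 0},  # driver (1 CPU) + 1 aggregation worker (1 CPU)
            {"CPU": 1, "GPU": 0},  # RolloutWorker 1
            {"CPU": 1, "GPU": 0},  # RolloutWorker 2
        ],
        strategy="PACK")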
@@ -172,6 +198,17 @@ def validate_config(config):
        raise ValueError(
            "Must use `batch_mode`=truncate_episodes if `vtrace` is True.")

    # Check whether the worker to aggregation-worker ratio makes sense.
    if config["num_aggregation_workers"] > config["num_workers"]:
        raise ValueError(
            "`num_aggregation_workers` must be smaller than or equal to "
            "`num_workers`! Aggregation makes no sense otherwise.")
    elif config["num_aggregation_workers"] > \
            config["num_workers"] / 2:
        logger.warning(
            "`num_aggregation_workers` should be significantly smaller "
            "than `num_workers`! Try setting it to 0.5*`num_workers` or "
            "less.")


# Update worker weights as they finish generating experiences.
class BroadcastUpdateLearnerWeights:
@@ -43,7 +43,7 @@ class TestIMPALA(unittest.TestCase):
        local_cfg["model"]["use_lstm"] = True
        local_cfg["model"]["lstm_use_prev_action"] = True
        local_cfg["model"]["lstm_use_prev_reward"] = True
        local_cfg["num_aggregation_workers"] = 2
        local_cfg["num_aggregation_workers"] = 1
        trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
        for i in range(num_iterations):
            print(trainer.train())
@@ -56,7 +56,7 @@ class SimpleEnv(Env):
class TestSAC(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(local_mode=True)
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
@@ -30,12 +30,13 @@ from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.typing import TrainerConfigDict, \
    PartialTrainerConfigDict, EnvInfoDict, ResultDict, EnvType, PolicyID
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.resources import Resources
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.trainable import Trainable
from ray.tune.trial import ExportFormat
from ray.tune.resources import Resources
from ray.tune.logger import Logger, UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.utils.placement_groups import PlacementGroupFactory

tf1, tf, tfv = try_import_tf()
@@ -92,10 +93,6 @@ COMMON_CONFIG: TrainerConfigDict = {
    "batch_mode": "truncate_episodes",

    # === Settings for the Trainer process ===
    # Number of GPUs to allocate to the trainer process. Note that not all
    # algorithms can take advantage of trainer GPUs. This can be fractional
    # (e.g., 0.3 GPUs).
    "num_gpus": 0,
    # Training batch size, if applicable. Should be >= rollout_fragment_length.
    # Sample batches will be concatenated together to a batch of this size,
    # which is then passed to SGD.
@@ -304,7 +301,11 @@ COMMON_CONFIG: TrainerConfigDict = {
    # The extra Python environment variables to set for worker processes.
    "extra_python_environs_for_worker": {},

    # === Advanced Resource Settings ===
    # === Resource Settings ===
    # Number of GPUs to allocate to the trainer process. Note that not all
    # algorithms can take advantage of trainer GPUs. This can be fractional
    # (e.g., 0.3 GPUs).
    "num_gpus": 0,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional. This is
@@ -316,11 +317,21 @@ COMMON_CONFIG: TrainerConfigDict = {
    # Number of CPUs to allocate for the trainer. Note: this only takes effect
    # when running in Tune. Otherwise, the trainer runs in the main program.
    "num_cpus_for_driver": 1,
    # Deprecated.
    "memory": 0,
    "object_store_memory": 0,
    "memory_per_worker": 0,
    "object_store_memory_per_worker": 0,
    # The strategy for the placement group factory returned by
    # `Trainer.default_resource_request()`. A PlacementGroup defines which
    # devices (resources) should always be co-located on the same node.
    # For example, a Trainer with 2 rollout workers, running with
    # num_gpus=1, will request a placement group with the bundles:
    # [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle
    # is for the driver and the other 2 bundles are for the two workers.
    # These bundles can now be "placed" on the same or different
    # nodes depending on the value of `placement_strategy`:
    # "PACK": Packs bundles into as few nodes as possible.
    # "SPREAD": Places bundles across distinct nodes as evenly as possible.
    # "STRICT_PACK": Packs bundles into one node. The group is not allowed
    # to span multiple nodes.
    # "STRICT_SPREAD": Packs bundles across distinct nodes.
    "placement_strategy": "PACK",

    # === Offline Datasets ===
    # Specify how to generate experiences:
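Hedged sketch of the bundle layout described in the comment above, expressed directly with Ray's placement group API (the Trainer itself goes through Tune's PlacementGroupFactory; this is only meant to show how the strategies behave, and the resource counts are assumptions):

.. code-block:: python

    import ray
    from ray.util.placement_group import placement_group

    ray.init(num_cpus=3, num_gpus=1)

    pg = placement_group(
        bundles=[{"GPU": 1, "CPU": 1}, {"CPU": 1}, {"CPU": 1}],
        strategy="PACK",  # or "SPREAD", "STRICT_PACK", "STRICT_SPREAD"
    )
    ray.get(pg.ready())  # block until all bundles are reserved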
@@ -393,6 +404,12 @@ COMMON_CONFIG: TrainerConfigDict = {
    # Define logger-specific configuration to be used inside Logger
    # Default value None allows overwriting with nested dicts
    "logger_config": None,

    # Deprecated keys.
    "memory": 0,
    "object_store_memory": 0,
    "memory_per_worker": 0,
    "object_store_memory_per_worker": 0,
}
# __sphinx_doc_end__
# yapf: enable
@@ -492,21 +509,37 @@ class Trainer(Trainable):
    @classmethod
    @override(Trainable)
    def default_resource_request(
            cls, config: PartialTrainerConfigDict) -> Resources:
            cls, config: PartialTrainerConfigDict) -> \
            Union[Resources, PlacementGroupFactory]:
        cf = dict(cls._default_config, **config)
        Trainer._validate_config(cf)
        num_workers = cf["num_workers"] + cf["evaluation_num_workers"]
        # TODO(ekl): add custom resources here once tune supports them
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            memory=cf["memory"],
            object_store_memory=cf["object_store_memory"],
            extra_cpu=cf["num_cpus_per_worker"] * num_workers,
            extra_gpu=cf["num_gpus_per_worker"] * num_workers,
            extra_memory=cf["memory_per_worker"] * num_workers,
            extra_object_store_memory=cf["object_store_memory_per_worker"] *
            num_workers)

        eval_config = cf["evaluation_config"]

        # Return PlacementGroupFactory containing all needed resources
        # (already properly defined as device bundles).
        return PlacementGroupFactory(
            bundles=[{
                # Driver.
                "CPU": cf["num_cpus_for_driver"],
                "GPU": cf["num_gpus"]
            }] + [
                {
                    # RolloutWorkers.
                    "CPU": cf["num_cpus_per_worker"],
                    "GPU": cf["num_gpus_per_worker"]
                } for _ in range(cf["num_workers"])
            ] + ([
                {
                    # Evaluation workers (+1 b/c of the additional local
                    # worker)
                    "CPU": eval_config.get("num_cpus_per_worker",
                                           cf["num_cpus_per_worker"]),
                    "GPU": eval_config.get("num_gpus_per_worker",
                                           cf["num_gpus_per_worker"]),
                } for _ in range(cf["evaluation_num_workers"] + 1)
            ] if cf["evaluation_interval"] else []),
            strategy=config.get("placement_strategy", "PACK"))

    @override(Trainable)
    @PublicAPI
|
|||
if model_config is None:
|
||||
config["model"] = model_config = {}
|
||||
|
||||
if config.get("memory", 0) != 0:
|
||||
deprecation_warning(old="memory")
|
||||
if config.get("object_store_memory", 0) != 0:
|
||||
deprecation_warning(old="object_store_memory")
|
||||
if config.get("memory_per_worker", 0) != 0:
|
||||
deprecation_warning(old="memory_per_worker")
|
||||
if config.get("object_store_memory_per_worker", 0) != 0:
|
||||
deprecation_warning(old="object_store_memory_per_worker")
|
||||
|
||||
if not config.get("_use_trajectory_view_api"):
|
||||
traj_view_framestacks = model_config.get("num_framestacks", "auto")
|
||||
if model_config.get("_time_major"):
|
||||
|
|
141  rllib/tests/test_placement_groups.py  (new file)
@@ -0,0 +1,141 @@
import unittest

import ray
from ray import tune
from ray.tune import Callback
from ray.rllib.agents.pg import PGTrainer, DEFAULT_CONFIG
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.trial import Trial
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util import placement_group_table

trial_executor = None


class _TestCallback(Callback):
    def on_step_end(self, iteration, trials, **info):
        num_finished = len([
            t for t in trials
            if t.status == Trial.TERMINATED or t.status == Trial.ERROR
        ])
        num_running = len([t for t in trials if t.status == Trial.RUNNING])

        num_staging = sum(
            len(s) for s in trial_executor._pg_manager._staging.values())
        num_ready = sum(
            len(s) for s in trial_executor._pg_manager._ready.values())
        num_in_use = len(trial_executor._pg_manager._in_use_pgs)
        num_cached = len(trial_executor._pg_manager._cached_pgs)

        total_num_tracked = num_staging + num_ready + \
            num_in_use + num_cached

        num_non_removed_pgs = len([
            p for pid, p in placement_group_table().items()
            if p["state"] != "REMOVED"
        ])
        num_removal_scheduled_pgs = len(
            trial_executor._pg_manager._pgs_for_removal)

        # All 3 trials (3 different learning rates) should be scheduled.
        assert 3 == min(3, len(trials))
        # Cannot run more than 2 at a time
        # (due to different resource restrictions in the test cases).
        assert num_running <= 2
        # The number of placement groups should decrease
        # when trials finish.
        assert max(3, len(trials)) - num_finished == total_num_tracked
        # The number of actual placement groups should match this.
        assert max(3, len(trials)) - num_finished == \
            num_non_removed_pgs - num_removal_scheduled_pgs


class TestPlacementGroups(unittest.TestCase):
    def setUp(self) -> None:
        ray.init(num_cpus=6)

    def tearDown(self) -> None:
        ray.shutdown()

    def test_overriding_default_resource_request(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 2
        # 3 Trials: Can only run 2 at a time (num_cpus=6; needed: 3).
        config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
        config["env"] = "CartPole-v0"
        config["framework"] = "tf"

        class DefaultResourceRequest:
            @classmethod
            def default_resource_request(cls, config):
                head_bundle = {"CPU": 1, "GPU": 0}
                child_bundle = {"CPU": 1}
                return PlacementGroupFactory(
                    [head_bundle, child_bundle, child_bundle],
                    strategy=config["placement_strategy"])

        # Create a trainer with an overridden default_resource_request
        # method that returns a PlacementGroupFactory.
        MyTrainer = PGTrainer.with_updates(mixins=[DefaultResourceRequest])
        tune.register_trainable("my_trainable", MyTrainer)

        global trial_executor
        trial_executor = RayTrialExecutor(reuse_actors=False)

        tune.run(
            "my_trainable",
            config=config,
            stop={"training_iteration": 2},
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            verbose=2,
        )

    def test_default_resource_request(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 2
        config["num_cpus_per_worker"] = 2
        # 3 Trials: Can only run 1 at a time (num_cpus=6; needed: 5).
        config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
        config["env"] = "CartPole-v0"
        config["framework"] = "torch"
        config["placement_strategy"] = "SPREAD"

        global trial_executor
        trial_executor = RayTrialExecutor(reuse_actors=False)

        tune.run(
            "PG",
            config=config,
            stop={"training_iteration": 2},
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            verbose=2,
        )

    def test_default_resource_request_plus_manual_leads_to_error(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 0
        config["env"] = "CartPole-v0"

        try:
            tune.run(
                "PG",
                config=config,
                stop={"training_iteration": 2},
                resources_per_trial=PlacementGroupFactory([{
                    "CPU": 1
                }]),
                verbose=2,
            )
        except ValueError as e:
            assert "have been automatically set to" in e.args[0]


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))