mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Request CPU resources in Trainer.default_resource_request()
if using dataset input. (#21948)
This commit is contained in:
parent
a55258eb9c
commit
87fe033f7b
5 changed files with 91 additions and 42 deletions
|
@ -55,6 +55,7 @@ from ray.rllib.execution.train_ops import (
|
||||||
multi_gpu_train_one_step,
|
multi_gpu_train_one_step,
|
||||||
)
|
)
|
||||||
from ray.rllib.models import MODEL_DEFAULTS
|
from ray.rllib.models import MODEL_DEFAULTS
|
||||||
|
from ray.rllib.offline import get_offline_io_resource_bundles
|
||||||
from ray.rllib.policy.policy import Policy
|
from ray.rllib.policy.policy import Policy
|
||||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||||
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
|
from ray.rllib.utils import deep_update, FilterManager, merge_dicts
|
||||||
|
@ -2071,7 +2072,10 @@ class Trainer(Trainable):
|
||||||
]
|
]
|
||||||
if cf["evaluation_interval"]
|
if cf["evaluation_interval"]
|
||||||
else []
|
else []
|
||||||
),
|
)
|
||||||
|
+
|
||||||
|
# In case our I/O reader/writer requires compute resources.
|
||||||
|
get_offline_io_resource_bundles(cf),
|
||||||
strategy=config.get("placement_strategy", "PACK"),
|
strategy=config.get("placement_strategy", "PACK"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ from types import FunctionType
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import data
|
|
||||||
from ray.actor import ActorHandle
|
from ray.actor import ActorHandle
|
||||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||||
from ray.rllib.env.base_env import BaseEnv
|
from ray.rllib.env.base_env import BaseEnv
|
||||||
|
@ -19,6 +18,7 @@ from ray.rllib.offline import (
|
||||||
D4RLReader,
|
D4RLReader,
|
||||||
DatasetReader,
|
DatasetReader,
|
||||||
DatasetWriter,
|
DatasetWriter,
|
||||||
|
get_dataset_and_shards,
|
||||||
)
|
)
|
||||||
from ray.rllib.policy.policy import Policy, PolicySpec
|
from ray.rllib.policy.policy import Policy, PolicySpec
|
||||||
from ray.rllib.utils import merge_dicts
|
from ray.rllib.utils import merge_dicts
|
||||||
|
@ -106,7 +106,7 @@ class WorkerSet:
|
||||||
if trainer_config["input"] == "dataset":
|
if trainer_config["input"] == "dataset":
|
||||||
# Create the set of dataset readers to be shared by all the
|
# Create the set of dataset readers to be shared by all the
|
||||||
# rollout workers.
|
# rollout workers.
|
||||||
self._ds, self._ds_shards = self._get_dataset_and_shards(
|
self._ds, self._ds_shards = get_dataset_and_shards(
|
||||||
trainer_config, num_workers, local_worker
|
trainer_config, num_workers, local_worker
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -438,43 +438,6 @@ class WorkerSet:
|
||||||
workers._remote_workers = remote_workers or []
|
workers._remote_workers = remote_workers or []
|
||||||
return workers
|
return workers
|
||||||
|
|
||||||
def _get_dataset_and_shards(
|
|
||||||
self, config: TrainerConfigDict, num_workers: int, local_worker: bool
|
|
||||||
) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
|
|
||||||
assert config["input"] == "dataset"
|
|
||||||
assert (
|
|
||||||
"input_config" in config
|
|
||||||
), "Must specify input_config dict if using Dataset input."
|
|
||||||
|
|
||||||
input_config = config["input_config"]
|
|
||||||
if not input_config.get("format", None) or not input_config.get("path", None):
|
|
||||||
raise ValueError(
|
|
||||||
"Must specify format and path via input_config key"
|
|
||||||
" when using Ray dataset input."
|
|
||||||
)
|
|
||||||
|
|
||||||
format = input_config["format"]
|
|
||||||
path = input_config["path"]
|
|
||||||
if format == "json":
|
|
||||||
dataset = data.read_json(path)
|
|
||||||
elif format == "parquet":
|
|
||||||
dataset = data.read_parquet(path)
|
|
||||||
else:
|
|
||||||
raise ValueError("Un-supported Ray dataset format: ", format)
|
|
||||||
|
|
||||||
# Local worker will be responsible for sampling.
|
|
||||||
if local_worker and num_workers == 0:
|
|
||||||
# Dataset is the only shard we need.
|
|
||||||
return dataset, [dataset]
|
|
||||||
# Remote workers are responsible for sampling:
|
|
||||||
else:
|
|
||||||
# Each remote worker gets 1 shard.
|
|
||||||
# The first None shard is for the local worker, which
|
|
||||||
# shouldn't be doing rollout work anyways.
|
|
||||||
return dataset, [None] + dataset.repartition(
|
|
||||||
num_blocks=num_workers, shuffle=False
|
|
||||||
).split(num_workers)
|
|
||||||
|
|
||||||
def _make_worker(
|
def _make_worker(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from ray.rllib.offline.d4rl_reader import D4RLReader
|
from ray.rllib.offline.d4rl_reader import D4RLReader
|
||||||
from ray.rllib.offline.dataset_reader import DatasetReader
|
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
|
||||||
from ray.rllib.offline.dataset_writer import DatasetWriter
|
from ray.rllib.offline.dataset_writer import DatasetWriter
|
||||||
from ray.rllib.offline.io_context import IOContext
|
from ray.rllib.offline.io_context import IOContext
|
||||||
from ray.rllib.offline.input_reader import InputReader
|
from ray.rllib.offline.input_reader import InputReader
|
||||||
|
@ -7,6 +7,7 @@ from ray.rllib.offline.mixed_input import MixedInput
|
||||||
from ray.rllib.offline.json_reader import JsonReader
|
from ray.rllib.offline.json_reader import JsonReader
|
||||||
from ray.rllib.offline.json_writer import JsonWriter
|
from ray.rllib.offline.json_writer import JsonWriter
|
||||||
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
|
||||||
|
from ray.rllib.offline.resource import get_offline_io_resource_bundles
|
||||||
from ray.rllib.offline.shuffled_input import ShuffledInput
|
from ray.rllib.offline.shuffled_input import ShuffledInput
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -21,4 +22,6 @@ __all__ = [
|
||||||
"D4RLReader",
|
"D4RLReader",
|
||||||
"DatasetReader",
|
"DatasetReader",
|
||||||
"DatasetWriter",
|
"DatasetWriter",
|
||||||
|
"get_dataset_and_shards",
|
||||||
|
"get_offline_io_resource_bundles",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,14 +1,74 @@
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
import ray.data
|
import ray.data
|
||||||
from ray.rllib.offline.input_reader import InputReader
|
from ray.rllib.offline.input_reader import InputReader
|
||||||
from ray.rllib.offline.io_context import IOContext
|
from ray.rllib.offline.io_context import IOContext
|
||||||
from ray.rllib.offline.json_reader import from_json_data
|
from ray.rllib.offline.json_reader import from_json_data
|
||||||
from ray.rllib.utils.annotations import override, PublicAPI
|
from ray.rllib.utils.annotations import override, PublicAPI
|
||||||
from ray.rllib.utils.typing import SampleBatchType
|
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_NUM_CPUS_PER_TASK = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def get_resource_bundles(config: TrainerConfigDict):
    """Compute the CPU resource bundle needed by the dataset read tasks.

    Args:
        config: Trainer config dict. The optional ``"input_config"`` sub-dict
            may carry ``"parallelism"`` (number of read tasks) and
            ``"num_cpus_per_read_task"``. When ``"parallelism"`` is absent,
            ``config["num_workers"]`` is used (1 if that is absent too); when
            ``"num_cpus_per_read_task"`` is absent,
            ``DEFAULT_NUM_CPUS_PER_TASK`` is used.

    Returns:
        A single-element list holding one ``{"CPU": n}`` bundle, where ``n``
        is ``parallelism * cpus_per_task`` rounded up to a whole CPU.
    """
    reader_options = config.get("input_config", {})
    fallback_parallelism = config.get("num_workers", 1)

    num_read_tasks = reader_options.get("parallelism", fallback_parallelism)
    cpus_per_read = reader_options.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    # Fractional CPU totals are rounded up to a whole CPU.
    total_cpus = math.ceil(num_read_tasks * cpus_per_read)
    return [{"CPU": total_cpus}]
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_and_shards(
    config: TrainerConfigDict, num_workers: int, local_worker: bool
) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
    """Read the offline dataset described by ``config`` and shard it.

    Args:
        config: Trainer config dict. ``config["input"]`` must be ``"dataset"``
            and ``config["input_config"]`` must provide non-empty ``"format"``
            (``"json"`` or ``"parquet"``) and ``"path"`` entries. Optional
            keys: ``"parallelism"`` (number of read tasks) and
            ``"num_cpus_per_read_task"`` (defaults to
            ``DEFAULT_NUM_CPUS_PER_TASK``).
        num_workers: Number of remote workers; each one gets one shard.
        local_worker: Whether a local worker exists. It only samples when
            ``num_workers`` is 0.

    Returns:
        A ``(dataset, shards)`` tuple. If the local worker does the sampling
        (``local_worker`` is true and ``num_workers == 0``), ``shards`` is
        ``[dataset]``. Otherwise ``shards`` is ``[None]`` (placeholder for the
        local worker) followed by ``num_workers`` splits of the dataset.

    Raises:
        AssertionError: If ``config["input"]`` is not ``"dataset"`` or
            ``"input_config"`` is missing from ``config``.
        ValueError: If ``format`` or ``path`` is missing/empty, or the format
            is not supported.
    """
    assert config["input"] == "dataset"
    assert (
        "input_config" in config
    ), "Must specify input_config dict if using Dataset input."

    input_config = config["input_config"]
    # Named `data_format` so the builtin `format` is not shadowed.
    data_format = input_config.get("format", None)
    path = input_config.get("path", None)
    if not data_format or not path:
        raise ValueError(
            "Must specify format and path via input_config key"
            " when using Ray dataset input."
        )

    # `num_workers` is 0 when only the local worker samples; a read still
    # needs at least one task, so fall back to 1 in that case.
    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    if data_format == "json":
        read_fn = ray.data.read_json
    elif data_format == "parquet":
        read_fn = ray.data.read_parquet
    else:
        # Single formatted message (the previous two-argument form rendered
        # as a tuple).
        raise ValueError(f"Un-supported Ray dataset format: {data_format}")

    dataset = read_fn(
        path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
    )

    # Local worker will be responsible for sampling.
    if local_worker and num_workers == 0:
        # Dataset is the only shard we need.
        return dataset, [dataset]

    # Remote workers are responsible for sampling:
    # Each remote worker gets 1 shard.
    # The first None shard is for the local worker, which
    # shouldn't be doing rollout work anyways.
    remote_shards = dataset.repartition(
        num_blocks=num_workers, shuffle=False
    ).split(num_workers)
    return dataset, [None] + remote_shards
|
||||||
|
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
class DatasetReader(InputReader):
|
class DatasetReader(InputReader):
|
||||||
|
@ -20,6 +80,9 @@ class DatasetReader(InputReader):
|
||||||
"input_config"={
|
"input_config"={
|
||||||
"format": "json",
|
"format": "json",
|
||||||
"path": "/tmp/sample_batches/",
|
"path": "/tmp/sample_batches/",
|
||||||
|
# By default, parallelism=num_workers.
|
||||||
|
"parallelism": 3,
|
||||||
|
"num_cpus_per_read_task": 0.5,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
16
rllib/offline/resource.py
Normal file
16
rllib/offline/resource.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from ray.rllib.offline.dataset_reader import (
|
||||||
|
get_resource_bundles as dataset_reader_get_resource_bundles,
|
||||||
|
)
|
||||||
|
from ray.rllib.utils.typing import PartialTrainerConfigDict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
def get_offline_io_resource_bundles(
    config: PartialTrainerConfigDict,
) -> List[Dict[str, float]]:
    """Return the resource bundles required by offline I/O components.

    Args:
        config: Trainer config dict; must contain the ``"input"`` key.

    Returns:
        The bundles computed by the dataset reader when
        ``config["input"] == "dataset"``, otherwise an empty list.
    """
    # DatasetReader is the only offline I/O component today that
    # requires compute resources.
    if config["input"] == "dataset":
        # The dataset-reader helper looks up "input_config" and "num_workers"
        # on the dict it receives, so it must be given the full trainer
        # config. Passing only config["input_config"] made it ignore the
        # user's "parallelism" / "num_cpus_per_read_task" settings.
        return dataset_reader_get_resource_bundles(config)
    else:
        return []
|
Loading…
Add table
Reference in a new issue