From 87fe033f7bd6e3f721cad19e7be041b97fb43269 Mon Sep 17 00:00:00 2001
From: Jun Gong
Date: Wed, 2 Feb 2022 01:20:37 -0800
Subject: [PATCH] [RLlib] Request CPU resources in
 `Trainer.default_resource_request()` if using dataset input. (#21948)

---
 rllib/agents/trainer.py         |  6 ++-
 rllib/evaluation/worker_set.py  | 41 +--------------------
 rllib/offline/__init__.py       |  5 ++-
 rllib/offline/dataset_reader.py | 65 ++++++++++++++++++++++++++++++++-
 rllib/offline/resource.py       | 16 ++++++++
 5 files changed, 91 insertions(+), 42 deletions(-)
 create mode 100644 rllib/offline/resource.py

diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 76344db6f..27e171740 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -55,6 +55,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.models import MODEL_DEFAULTS
+from ray.rllib.offline import get_offline_io_resource_bundles
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.utils import deep_update, FilterManager, merge_dicts
@@ -2071,7 +2072,10 @@ class Trainer(Trainable):
                 ]
                 if cf["evaluation_interval"]
                 else []
-            ),
+            )
+
+            # In case our I/O reader/writer requires compute resources.
+            + get_offline_io_resource_bundles(cf),
             strategy=config.get("placement_strategy", "PACK"),
         )
 
diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py
index ef0102348..9999a4eed 100644
--- a/rllib/evaluation/worker_set.py
+++ b/rllib/evaluation/worker_set.py
@@ -5,7 +5,6 @@ from types import FunctionType
 from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import ray
-from ray import data
 from ray.actor import ActorHandle
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.env.base_env import BaseEnv
@@ -19,6 +18,7 @@ from ray.rllib.offline import (
     D4RLReader,
     DatasetReader,
     DatasetWriter,
+    get_dataset_and_shards,
 )
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.utils import merge_dicts
@@ -106,7 +106,7 @@ class WorkerSet:
         if trainer_config["input"] == "dataset":
             # Create the set of dataset readers to be shared by all the
             # rollout workers.
-            self._ds, self._ds_shards = self._get_dataset_and_shards(
+            self._ds, self._ds_shards = get_dataset_and_shards(
                 trainer_config, num_workers, local_worker
             )
         else:
@@ -438,43 +438,6 @@
         workers._remote_workers = remote_workers or []
         return workers
 
-    def _get_dataset_and_shards(
-        self, config: TrainerConfigDict, num_workers: int, local_worker: bool
-    ) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
-        assert config["input"] == "dataset"
-        assert (
-            "input_config" in config
-        ), "Must specify input_config dict if using Dataset input."
-
-        input_config = config["input_config"]
-        if not input_config.get("format", None) or not input_config.get("path", None):
-            raise ValueError(
-                "Must specify format and path via input_config key"
-                " when using Ray dataset input."
-            )
-
-        format = input_config["format"]
-        path = input_config["path"]
-        if format == "json":
-            dataset = data.read_json(path)
-        elif format == "parquet":
-            dataset = data.read_parquet(path)
-        else:
-            raise ValueError("Un-supported Ray dataset format: ", format)
-
-        # Local worker will be responsible for sampling.
-        if local_worker and num_workers == 0:
-            # Dataset is the only shard we need.
-            return dataset, [dataset]
-        # Remote workers are responsible for sampling:
-        else:
-            # Each remote worker gets 1 shard.
-            # The first None shard is for the local worker, which
-            # shouldn't be doing rollout work anyways.
-            return dataset, [None] + dataset.repartition(
-                num_blocks=num_workers, shuffle=False
-            ).split(num_workers)
-
     def _make_worker(
         self,
         *,
diff --git a/rllib/offline/__init__.py b/rllib/offline/__init__.py
index c6de093ac..33f917d66 100644
--- a/rllib/offline/__init__.py
+++ b/rllib/offline/__init__.py
@@ -1,5 +1,5 @@
 from ray.rllib.offline.d4rl_reader import D4RLReader
-from ray.rllib.offline.dataset_reader import DatasetReader
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
 from ray.rllib.offline.dataset_writer import DatasetWriter
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.input_reader import InputReader
@@ -7,6 +7,7 @@ from ray.rllib.offline.mixed_input import MixedInput
 from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.offline.json_writer import JsonWriter
 from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
+from ray.rllib.offline.resource import get_offline_io_resource_bundles
 from ray.rllib.offline.shuffled_input import ShuffledInput
 
 __all__ = [
@@ -21,4 +22,6 @@ __all__ = [
     "D4RLReader",
     "DatasetReader",
     "DatasetWriter",
+    "get_dataset_and_shards",
+    "get_offline_io_resource_bundles",
 ]
diff --git a/rllib/offline/dataset_reader.py b/rllib/offline/dataset_reader.py
index 463192346..864ffb296 100644
--- a/rllib/offline/dataset_reader.py
+++ b/rllib/offline/dataset_reader.py
@@ -1,14 +1,74 @@
 import logging
+import math
 
 import ray.data
 
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data
 from ray.rllib.utils.annotations import override, PublicAPI
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
+from typing import List, Tuple
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_NUM_CPUS_PER_TASK = 0.5
+
+
+def get_resource_bundles(config: TrainerConfigDict):
+    input_config = config.get("input_config", {})
+    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
+    cpus_per_task = input_config.get(
+        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
+    )
+    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
+
+
+def get_dataset_and_shards(
+    config: TrainerConfigDict, num_workers: int, local_worker: bool
+) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
+    assert config["input"] == "dataset"
+    assert (
+        "input_config" in config
+    ), "Must specify input_config dict if using Dataset input."
+
+    input_config = config["input_config"]
+    if not input_config.get("format", None) or not input_config.get("path", None):
+        raise ValueError(
+            "Must specify format and path via input_config key"
+            " when using Ray dataset input."
+        )
+
+    parallelism = input_config.get("parallelism", num_workers)
+    cpus_per_task = input_config.get(
+        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
+    )
+
+    format = input_config["format"]
+    path = input_config["path"]
+    if format == "json":
+        dataset = ray.data.read_json(
+            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+        )
+    elif format == "parquet":
+        dataset = ray.data.read_parquet(
+            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+        )
+    else:
+        raise ValueError(f"Unsupported Ray dataset format: {format}")
+
+    # Local worker will be responsible for sampling.
+    if local_worker and num_workers == 0:
+        # Dataset is the only shard we need.
+        return dataset, [dataset]
+    # Remote workers are responsible for sampling:
+    else:
+        # Each remote worker gets 1 shard.
+        # The first None shard is for the local worker, which
+        # shouldn't be doing rollout work anyways.
+        return dataset, [None] + dataset.repartition(
+            num_blocks=num_workers, shuffle=False
+        ).split(num_workers)
+
 
 @PublicAPI
 class DatasetReader(InputReader):
@@ -20,6 +80,9 @@ class DatasetReader(InputReader):
         "input_config"={
             "format": "json",
             "path": "/tmp/sample_batches/",
+            # By default, parallelism=num_workers.
+            "parallelism": 3,
+            "num_cpus_per_read_task": 0.5,
         }
     }
 
diff --git a/rllib/offline/resource.py b/rllib/offline/resource.py
new file mode 100644
index 000000000..d176395e4
--- /dev/null
+++ b/rllib/offline/resource.py
@@ -0,0 +1,16 @@
+from ray.rllib.offline.dataset_reader import (
+    get_resource_bundles as dataset_reader_get_resource_bundles,
+)
+from ray.rllib.utils.typing import PartialTrainerConfigDict
+from typing import Dict, List
+
+
+def get_offline_io_resource_bundles(
+    config: PartialTrainerConfigDict,
+) -> List[Dict[str, float]]:
+    # DatasetReader is the only offline I/O component today that
+    # requires compute resources.
+    if config["input"] == "dataset":
+        return dataset_reader_get_resource_bundles(config)
+    else:
+        return []
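
---

Note (not part of the patch): a minimal sketch of how the new `input_config`
keys translate into the extra CPU bundle that `Trainer.default_resource_request()`
now appends to the placement group. The config values below are made up for
illustration; the arithmetic mirrors `get_resource_bundles()` above.

    import math

    # Same default as in rllib/offline/dataset_reader.py.
    DEFAULT_NUM_CPUS_PER_TASK = 0.5

    config = {
        "input": "dataset",
        "num_workers": 3,
        "input_config": {
            "format": "json",
            "path": "/tmp/sample_batches/",
            "parallelism": 3,  # defaults to num_workers if omitted
            "num_cpus_per_read_task": 0.5,  # defaults to 0.5
        },
    }

    input_config = config.get("input_config", {})
    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    # 3 read tasks * 0.5 CPU each -> math.ceil(1.5) == 2 extra CPUs
    # reserved for the dataset reader.
    print([{"CPU": math.ceil(parallelism * cpus_per_task)}])  # [{'CPU': 2}]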
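And a self-contained sketch of the sharding contract in
`get_dataset_and_shards()`: with remote workers present, shard 0 is `None`
(the local worker shouldn't be doing rollout work) and each remote worker
receives exactly one split. The toy dataset and worker count here are
hypothetical stand-ins.

    import ray

    ray.init()

    num_workers = 2  # hypothetical number of remote rollout workers
    ds = ray.data.range(8)  # stand-in for the read_json()/read_parquet() result

    if num_workers == 0:
        # Local-only sampling: the whole dataset is the single shard.
        shards = [ds]
    else:
        # One equal split per remote worker; None for the local worker.
        shards = [None] + ds.repartition(
            num_blocks=num_workers, shuffle=False
        ).split(num_workers)

    assert shards[0] is None and len(shards) == num_workers + 1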