[RLlib] Request CPU resources in Trainer.default_resource_request() if using dataset input. (#21948)

Jun Gong 2022-02-02 01:20:37 -08:00 committed by GitHub
parent a55258eb9c
commit 87fe033f7b
5 changed files with 91 additions and 42 deletions
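For context, a minimal sketch of the kind of trainer config this change targets. The input, input_config, format, path, parallelism, and num_cpus_per_read_task keys are taken from the diff below; the concrete values are illustrative only, not defaults enforced by this commit:

    # Illustrative config for Ray dataset input (example values only).
    config = {
        "input": "dataset",
        "input_config": {
            "format": "json",               # or "parquet"
            "path": "/tmp/sample_batches/",
            "parallelism": 3,               # defaults to num_workers
            "num_cpus_per_read_task": 0.5,  # defaults to 0.5
        },
        "num_workers": 2,
    }

With this commit, Trainer.default_resource_request() appends one extra CPU bundle for the dataset read tasks to the placement group (see rllib/offline/resource.py below), instead of assuming the reads are free.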

rllib/agents/trainer.py

@@ -55,6 +55,7 @@ from ray.rllib.execution.train_ops import (
     multi_gpu_train_one_step,
 )
 from ray.rllib.models import MODEL_DEFAULTS
+from ray.rllib.offline import get_offline_io_resource_bundles
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
 from ray.rllib.utils import deep_update, FilterManager, merge_dicts
@@ -2071,7 +2072,10 @@ class Trainer(Trainable):
                 ]
                 if cf["evaluation_interval"]
                 else []
-            ),
+            )
+            +
+            # In case our I/O reader/writer requires compute resources.
+            get_offline_io_resource_bundles(cf),
             strategy=config.get("placement_strategy", "PACK"),
         )

rllib/evaluation/worker_set.py

@@ -5,7 +5,6 @@ from types import FunctionType
 from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import ray
-from ray import data
 from ray.actor import ActorHandle
 from ray.rllib.evaluation.rollout_worker import RolloutWorker
 from ray.rllib.env.base_env import BaseEnv
@@ -19,6 +18,7 @@ from ray.rllib.offline import (
     D4RLReader,
     DatasetReader,
     DatasetWriter,
+    get_dataset_and_shards,
 )
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.utils import merge_dicts
@@ -106,7 +106,7 @@ class WorkerSet:
         if trainer_config["input"] == "dataset":
             # Create the set of dataset readers to be shared by all the
             # rollout workers.
-            self._ds, self._ds_shards = self._get_dataset_and_shards(
+            self._ds, self._ds_shards = get_dataset_and_shards(
                 trainer_config, num_workers, local_worker
             )
         else:
@@ -438,43 +438,6 @@ class WorkerSet:
         workers._remote_workers = remote_workers or []
         return workers

-    def _get_dataset_and_shards(
-        self, config: TrainerConfigDict, num_workers: int, local_worker: bool
-    ) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
-        assert config["input"] == "dataset"
-        assert (
-            "input_config" in config
-        ), "Must specify input_config dict if using Dataset input."
-        input_config = config["input_config"]
-
-        if not input_config.get("format", None) or not input_config.get("path", None):
-            raise ValueError(
-                "Must specify format and path via input_config key"
-                " when using Ray dataset input."
-            )
-
-        format = input_config["format"]
-        path = input_config["path"]
-        if format == "json":
-            dataset = data.read_json(path)
-        elif format == "parquet":
-            dataset = data.read_parquet(path)
-        else:
-            raise ValueError("Un-supported Ray dataset format: ", format)
-
-        # Local worker will be responsible for sampling.
-        if local_worker and num_workers == 0:
-            # Dataset is the only shard we need.
-            return dataset, [dataset]
-        # Remote workers are responsible for sampling:
-        else:
-            # Each remote worker gets 1 shard.
-            # The first None shard is for the local worker, which
-            # shouldn't be doing rollout work anyways.
-            return dataset, [None] + dataset.repartition(
-                num_blocks=num_workers, shuffle=False
-            ).split(num_workers)
-
     def _make_worker(
         self,
         *,

rllib/offline/__init__.py

@@ -1,5 +1,5 @@
 from ray.rllib.offline.d4rl_reader import D4RLReader
-from ray.rllib.offline.dataset_reader import DatasetReader
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
 from ray.rllib.offline.dataset_writer import DatasetWriter
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.input_reader import InputReader
@@ -7,6 +7,7 @@ from ray.rllib.offline.mixed_input import MixedInput
 from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.offline.json_writer import JsonWriter
 from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
+from ray.rllib.offline.resource import get_offline_io_resource_bundles
 from ray.rllib.offline.shuffled_input import ShuffledInput

 __all__ = [
@@ -21,4 +22,6 @@ __all__ = [
     "D4RLReader",
     "DatasetReader",
     "DatasetWriter",
+    "get_dataset_and_shards",
+    "get_offline_io_resource_bundles",
 ]

rllib/offline/dataset_reader.py

@@ -1,14 +1,74 @@
 import logging
+import math

 import ray.data

 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data
 from ray.rllib.utils.annotations import override, PublicAPI
-from ray.rllib.utils.typing import SampleBatchType
+from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
+from typing import List

 logger = logging.getLogger(__name__)

+DEFAULT_NUM_CPUS_PER_TASK = 0.5
+
+
+def get_resource_bundles(config: TrainerConfigDict):
+    input_config = config.get("input_config", {})
+    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
+    cpus_per_task = input_config.get(
+        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
+    )
+    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
+
+
+def get_dataset_and_shards(
+    config: TrainerConfigDict, num_workers: int, local_worker: bool
+) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
+    assert config["input"] == "dataset"
+    assert (
+        "input_config" in config
+    ), "Must specify input_config dict if using Dataset input."
+    input_config = config["input_config"]
+
+    if not input_config.get("format", None) or not input_config.get("path", None):
+        raise ValueError(
+            "Must specify format and path via input_config key"
+            " when using Ray dataset input."
+        )
+
+    parallelism = input_config.get("parallelism", num_workers)
+    cpus_per_task = input_config.get(
+        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
+    )
+
+    format = input_config["format"]
+    path = input_config["path"]
+    if format == "json":
+        dataset = ray.data.read_json(
+            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+        )
+    elif format == "parquet":
+        dataset = ray.data.read_parquet(
+            path, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
+        )
+    else:
+        raise ValueError("Un-supported Ray dataset format: ", format)
+
+    # Local worker will be responsible for sampling.
+    if local_worker and num_workers == 0:
+        # Dataset is the only shard we need.
+        return dataset, [dataset]
+    # Remote workers are responsible for sampling:
+    else:
+        # Each remote worker gets 1 shard.
+        # The first None shard is for the local worker, which
+        # shouldn't be doing rollout work anyways.
+        return dataset, [None] + dataset.repartition(
+            num_blocks=num_workers, shuffle=False
+        ).split(num_workers)
+

 @PublicAPI
 class DatasetReader(InputReader):
@@ -20,6 +80,9 @@ class DatasetReader(InputReader):
             "input_config"={
                 "format": "json",
                 "path": "/tmp/sample_batches/",
+                # By default, parallelism=num_workers.
+                "parallelism": 3,
+                "num_cpus_per_read_task": 0.5,
             }
         }
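The sharding contract of get_dataset_and_shards() is easy to miss in the hunk above: with no remote workers the local worker samples from the single dataset, otherwise each remote worker gets one shard and slot 0 is None because the local worker should not be doing rollouts. A standalone sketch of that return shape (illustrative only, not RLlib API):

    def shard_layout(num_workers, local_worker=True):
        # Mirrors the branching in get_dataset_and_shards() above.
        if local_worker and num_workers == 0:
            return ["<whole dataset>"]  # local worker samples everything
        return [None] + [f"<shard {i}>" for i in range(num_workers)]

    print(shard_layout(0))  # ['<whole dataset>']
    print(shard_layout(3))  # [None, '<shard 0>', '<shard 1>', '<shard 2>']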

rllib/offline/resource.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from ray.rllib.offline.dataset_reader import (
+    get_resource_bundles as dataset_reader_get_resource_bundles,
+)
+from ray.rllib.utils.typing import PartialTrainerConfigDict
+from typing import Dict, List
+
+
+def get_offline_io_resource_bundles(
+    config: PartialTrainerConfigDict,
+) -> List[Dict[str, float]]:
+    # DatasetReader is the only offline I/O component today that
+    # requires compute resources.
+    if config["input"] == "dataset":
+        return dataset_reader_get_resource_bundles(config["input_config"])
+    else:
+        return []
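To make the resource math concrete, here is the arithmetic the new code performs for the illustrative config at the top of this page (parallelism=3, num_cpus_per_read_task=0.5). This just mirrors get_resource_bundles() in dataset_reader.py and is not part of the diff:

    import math

    parallelism = 3        # input_config["parallelism"], defaults to num_workers
    cpus_per_task = 0.5    # input_config["num_cpus_per_read_task"], default 0.5
    print([{"CPU": math.ceil(parallelism * cpus_per_task)}])  # [{'CPU': 2}]

For any input other than "dataset", get_offline_io_resource_bundles() returns an empty list, so the placement group requested by Trainer.default_resource_request() is unchanged.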