mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
import logging
|
|
import math
|
|
|
|
import ray.data
|
|
from ray.rllib.offline.input_reader import InputReader
|
|
from ray.rllib.offline.io_context import IOContext
|
|
from ray.rllib.offline.json_reader import from_json_data
|
|
from ray.rllib.utils.annotations import override, PublicAPI
|
|
from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
|
|
from typing import List
|
|
|
|
logger = logging.getLogger(__name__)

# Default number of CPUs reserved for each Ray Dataset read task. Can be
# overridden via `input_config["num_cpus_per_read_task"]`.
DEFAULT_NUM_CPUS_PER_TASK = 0.5


def get_resource_bundles(config: "TrainerConfigDict") -> List[dict]:
    """Returns the resource bundles needed by the Dataset read tasks.

    The number of read tasks is `input_config["parallelism"]` if given,
    otherwise the number of rollout workers (at least 1). Each read task
    reserves `input_config["num_cpus_per_read_task"]` CPUs
    (DEFAULT_NUM_CPUS_PER_TASK if unset); the total is rounded up to a
    whole number of CPUs.

    Args:
        config: Trainer config dict. Only the "input_config" and
            "num_workers" keys are read.

    Returns:
        A single-element list holding a {"CPU": n} resource bundle.
    """
    # Treat a missing and an explicit-None input_config the same way.
    input_config = config.get("input_config") or {}
    # `or 1` mirrors get_dataset_and_shards(): with num_workers=0 the local
    # worker still reads the dataset with parallelism 1, so a zero-CPU
    # bundle would under-provision the read tasks.
    default_parallelism = config.get("num_workers") or 1
    parallelism = input_config.get("parallelism", default_parallelism)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )
    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
|
|
|
|
|
|
def get_dataset_and_shards(
    config: TrainerConfigDict, num_workers: int, local_worker: bool
) -> (ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]):
    """Reads the offline Ray Dataset described by `config["input_config"]`
    and splits it into per-worker shards.

    Args:
        config: Trainer config dict. `config["input"]` must be "dataset" and
            `config["input_config"]` must provide "format" ("json" or
            "parquet") and "path". Optional keys: "parallelism" (defaults to
            `num_workers`, or 1 if that is 0) and "num_cpus_per_read_task"
            (defaults to DEFAULT_NUM_CPUS_PER_TASK).
        num_workers: Number of remote rollout workers.
        local_worker: Whether the local worker is the one doing the sampling.

    Returns:
        A tuple (dataset, shards). If the local worker samples and there are
        no remote workers, shards is [dataset]. Otherwise shards is [None]
        followed by `num_workers` shards, one per remote worker; the leading
        None slot belongs to the local worker.

    Raises:
        ValueError: If format or path is missing, or the format is not
            supported.
    """
    assert config["input"] == "dataset"
    assert (
        "input_config" in config
    ), "Must specify input_config dict if using Dataset input."

    # An explicit None input_config falls through to the ValueError below.
    input_config = config["input_config"] or {}
    if not input_config.get("format", None) or not input_config.get("path", None):
        raise ValueError(
            "Must specify format and path via input_config key"
            " when using Ray dataset input."
        )

    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    # Map each supported input format to its Ray Data reader.
    readers = {
        "json": ray.data.read_json,
        "parquet": ray.data.read_parquet,
    }
    data_format = input_config["format"]
    reader = readers.get(data_format)
    if reader is None:
        # Build a single message string; passing several args to ValueError
        # would render as a tuple.
        raise ValueError(f"Un-supported Ray dataset format: {data_format}")
    dataset = reader(
        input_config["path"],
        parallelism=parallelism,
        ray_remote_args={"num_cpus": cpus_per_task},
    )

    # Local worker will be responsible for sampling.
    if local_worker and num_workers == 0:
        # Dataset is the only shard we need.
        return dataset, [dataset]

    # Remote workers are responsible for sampling:
    # Each remote worker gets 1 shard.
    # The first None shard is for the local worker, which
    # shouldn't be doing rollout work anyways.
    return dataset, [None] + dataset.repartition(
        num_blocks=num_workers, shuffle=False
    ).split(num_workers)
|
|
|
|
|
|
@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.dataset recognizes.
                "path": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            }
        }
    """

    @PublicAPI
    def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
        """Initializes a DatasetReader instance.

        Args:
            ioctx: IO context of the owning worker. Its `worker_index` is
                used in the log message and its `worker` is passed to
                `from_json_data` when decoding rows.
            ds: Ray dataset to sample from. May be None to create a
                non-functioning reader; calling `next()` on it then fails.
        """
        self._ioctx = ioctx
        self._dataset = ds
        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful for example for a non-rollout local worker.
        # Check for None explicitly: `if ds:` would depend on the Dataset
        # object's truthiness (which may go through __len__) rather than on
        # whether a dataset was passed at all.
        if ds is not None:
            # Use the module logger instead of print(): this is library code.
            logger.info(
                "DatasetReader %s has %s samples.", ioctx.worker_index, ds.count()
            )
            # repeat() without a count so that next() keeps cycling over the
            # data instead of running out.
            self._iter = self._dataset.repeat().iter_rows()
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Returns the next dataset row, decoded with `from_json_data`."""
        # next() should not get called on None DatasetReader.
        assert (
            self._iter is not None
        ), "next() called on a DatasetReader that was created without a dataset."

        d = next(self._iter).as_pydict()
        # Columns like obs are compressed when written by DatasetWriter.
        d = from_json_data(d, self._ioctx.worker)

        return d
|