import logging
import math
from pathlib import Path
import re
from typing import List, Tuple, Optional
import zipfile

import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    SampleBatch,
    concat_samples,
)
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict

logger = logging.getLogger(__name__)

DEFAULT_NUM_CPUS_PER_TASK = 0.5


# TODO: @avnishn what is the use of this function anymore?
def _get_resource_bundles(config: AlgorithmConfigDict):
    input_config = config.get("input_config", {})
    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )
    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]


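# For illustration only (hypothetical values, not part of this module's API):
# with input_config={"parallelism": 4, "num_cpus_per_read_task": 0.5}, the helper
# above requests a single resource bundle of ceil(4 * 0.5) = 2 CPUs, i.e.
# [{"CPU": 2}].

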
def _unzip_this_path(fpath: Path, extract_path: str):
    with zipfile.ZipFile(str(fpath), "r") as zip_ref:
        zip_ref.extractall(extract_path)


def _unzip_if_needed(paths: List[str], format: str):
    """If a path in paths is a zip file, unzip it and use the path of the unzipped file."""
    ret_paths = []
    for path in paths:
        if re.search("\\.zip$", str(path)):
            # TODO: We need to add unzip support for s3
            if str(path).startswith("s3://"):
                raise ValueError(
                    "unzip_if_needed currently does not support remote paths from s3"
                )
            extract_path = "./"
            try:
                _unzip_this_path(str(path), extract_path)
            except FileNotFoundError:
                # Interpreted as a relative path to the rllib folder.
                try:
                    # TODO: remove this later when we replace all tests with s3 paths
                    _unzip_this_path(Path(__file__).parent.parent / path, extract_path)
                except FileNotFoundError:
                    raise FileNotFoundError(f"File not found: {path}")

            unzipped_path = str(
                Path(extract_path).absolute() / f"{Path(path).stem}.{format}"
            )
            ret_paths.append(unzipped_path)
        else:
            # TODO: We can get rid of this logic when we replace all tests with s3 paths
            if str(path).startswith("s3://"):
                ret_paths.append(path)
            else:
                if not Path(path).exists():
                    relative_path = str(Path(__file__).parent.parent / path)
                    if not Path(relative_path).exists():
                        raise FileNotFoundError(f"File not found: {path}")
                    path = relative_path
                ret_paths.append(path)
    return ret_paths


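# A usage sketch (hypothetical local path, for illustration only):
# _unzip_if_needed(["/tmp/cartpole-out.zip"], "json") extracts the archive into
# "./" and returns the absolute path of "./cartpole-out.json"; non-zip paths are
# passed through (resolved relative to the rllib folder if needed).

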
@PublicAPI
def get_dataset_and_shards(
    config: AlgorithmConfigDict, num_workers: int = 0
) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
    """Returns a dataset and a list of shards.

    This function uses the algorithm config to create a dataset and a list of
    shards. The following config keys are used to create the dataset:

    input: The input type, which must be "dataset".
    input_config: A dict containing the following keys and values:
        `format`: str, specifies the format of the input data. This must be a
            format supported by Ray Datasets (see ray.data.dataset.Dataset);
            only "parquet" and "json" are supported for now.
        `paths`: str, a single path string or a list of path strings. Each
            string is a path to a file or a directory holding the dataset. It
            can be either a local path or a remote path (e.g. to an s3 bucket).
        `loader_fn`: Callable[[], ray.data.dataset.Dataset]. Instead of
            specifying paths and format, you can specify a function that loads
            and returns the dataset.
        `parallelism`: int, the number of tasks to use for loading the dataset.
            If not specified, it will be set to the number of workers.
        `num_cpus_per_read_task`: float, the number of CPUs to use for each
            read task. If not specified, it will be set to 0.5.

    Args:
        config: The config dict for the algorithm.
        num_workers: The number of shards to create for remote workers.

    Returns:
        dataset: The dataset object.
        shards: A list of dataset shards. For num_workers > 0, the first
            returned shard is a dummy None shard for the local worker.
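
    Examples:
        A minimal usage sketch (the local path below is hypothetical):

        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                "paths": "/tmp/sample_batches/",
            },
        }
        # With num_workers=2, shards == [None, shard_1, shard_2].
        dataset, shards = get_dataset_and_shards(config, num_workers=2)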
"""
|
|
    # Check the input and input_config keys.
    assert config["input"] == "dataset", (
        "Must specify input as dataset if calling `get_dataset_and_shards`. "
        f"Got {config['input']}."
    )
    assert (
        "input_config" in config
    ), "Must specify input_config dict if using Dataset input."

    # Check the input_config format.
    input_config = config["input_config"]
    format = input_config.get("format")

    supported_fmts = ["json", "parquet"]
    if format is not None and format not in supported_fmts:
        raise ValueError(
            f"Unsupported format {format}. Supported formats are {supported_fmts}"
        )

    # Check paths and loader_fn, since exactly one of them is required.
    paths = input_config.get("paths")
    loader_fn = input_config.get("loader_fn")
    if loader_fn and (format or paths):
        raise ValueError(
            "When using a `loader_fn`, you cannot specify a `format` or `paths`."
        )

    # Check that at least loader_fn or format + paths is specified.
    if not (format and paths) and not loader_fn:
        raise ValueError(
            "You must specify either a `loader_fn` that constructs the dataset, "
            f"or both `format` and `paths`. Got loader_fn: {loader_fn}, "
            f"format: {format}, paths: {paths}."
        )

    # Check that paths is a str or a list of str, if not None.
    if paths is not None:
        if isinstance(paths, str):
            paths = [paths]
        elif isinstance(paths, list):
            assert isinstance(paths[0], str), "Paths must be a list of path strings."
        else:
            raise ValueError("Paths must be a path string or a list of path strings.")
        paths = _unzip_if_needed(paths, format)

    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    if loader_fn:
        dataset = loader_fn()
    elif format == "json":
        dataset = ray.data.read_json(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    elif format == "parquet":
        dataset = ray.data.read_parquet(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    else:
        raise ValueError(f"Unsupported Ray dataset format: {format}")

    # The local worker will be responsible for sampling.
    if num_workers == 0:
        # The dataset is the only shard we need.
        return dataset, [dataset]
    # Remote workers are responsible for sampling:
    else:
        # Each remote worker gets 1 shard.
        remote_shards = dataset.repartition(
            num_blocks=num_workers, shuffle=False
        ).split(num_workers)

        # The first None shard is for the local worker, which
        # shouldn't be doing rollout work anyway.
        return dataset, [None] + remote_shards


@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from a Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.dataset recognizes.
                "paths": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            }
        }
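
        A rough construction sketch (assuming the config above; the exact
        IOContext arguments may differ in your setup):

        ds, _ = get_dataset_and_shards(config)
        reader = DatasetReader(ds, ioctx=IOContext(config=config))
        batch = reader.next()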
"""
|
|
|
|
    @PublicAPI
    def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
        """Initializes a DatasetReader instance.

        Args:
            ds: Ray dataset to sample from.
            ioctx: The IO context to use. If None, a default IOContext is created.
        """
        self._ioctx = ioctx or IOContext()
        self._default_policy = self.policy_map = None
        self._dataset = ds
        self.count = None if not self._dataset else self._dataset.count()
        # Do this to disable the Ray Data stdout logging.
        ray.data.set_progress_bars(enabled=False)

        # The number of steps to return per call to next().
        self.batch_size = self._ioctx.config.get("train_batch_size", 1)
        num_workers = self._ioctx.config.get("num_workers", 0)
        seed = self._ioctx.config.get("seed", None)
        if num_workers:
            self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
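        # For illustration (hypothetical numbers): with train_batch_size=1000 and
        # num_workers=4, each remote reader returns ceil(1000 / 4) = 250 steps
        # per call to next().
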
        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful, for example, for a non-rollout local worker.
        if ds:
            if self._ioctx.worker is not None:
                self._policy_map = self._ioctx.worker.policy_map
                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
            self._dataset = self._dataset.random_shuffle(seed=seed)
            logger.info(
                f"DatasetReader {self._ioctx.worker_index} has {ds.count()} samples."
            )
            # TODO: @avnishn make this call seeded.
            # Calling random_shuffle_each_window shuffles the dataset again
            # each time the whole dataset has been read.
            self._iter = (
                self._dataset.repeat().random_shuffle_each_window().iter_rows()
            )
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # next() should not get called on a None DatasetReader.
        assert self._iter is not None
        ret = []
        count = 0
        while count < self.batch_size:
            d = next(self._iter).as_pydict()
            # Columns like obs are compressed when written by the DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            count += d.count
            ret.append(self._postprocess_if_needed(d))
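        # Note: Since whole rows (each one a SampleBatch) are appended, the
        # returned batch may slightly exceed self.batch_size.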
        ret = concat_samples(ret)
        return ret

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        if not self._ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                if self._default_policy is not None:
                    out.append(self._default_policy.postprocess_trajectory(sub_batch))
                else:
                    out.append(sub_batch)
            return concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available anymore.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )