ray/rllib/offline/dataset_reader.py

import logging
import math
from pathlib import Path
import re
from typing import List, Tuple
import zipfile

import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.policy.sample_batch import (
    concat_samples,
    SampleBatch,
    DEFAULT_POLICY_ID,
)
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict

logger = logging.getLogger(__name__)

DEFAULT_NUM_CPUS_PER_TASK = 0.5


def _get_resource_bundles(config: AlgorithmConfigDict):
    input_config = config.get("input_config", {})
    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )
    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]


def _unzip_if_needed(paths: List[str], format: str):
    """If a path in `paths` is a zip file, unzip it and use the path of the unzipped file."""
    ret = []
    for path in paths:
        fpath = Path(path).absolute()
        if not fpath.exists():
            fpath = Path(__file__).parent.parent / path
        if not fpath.exists():
            raise FileNotFoundError(f"File not found: {path}")
        if re.search("\\.zip$", str(fpath)):
            with zipfile.ZipFile(str(fpath), "r") as zip_ref:
                zip_ref.extractall(str(fpath.parent))
            fpath = re.sub("\\.zip$", f".{format}", str(fpath))
        fpath = str(fpath)
        ret.append(fpath)
    return ret
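

# Example (illustrative, with a placeholder path): given
# paths=["/tmp/cartpole/batches.zip"] and format="json", _unzip_if_needed()
# extracts the archive into /tmp/cartpole/ and returns
# ["/tmp/cartpole/batches.json"], i.e. the ".zip" suffix is replaced by the
# data format.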


@PublicAPI
def get_dataset_and_shards(
    config: AlgorithmConfigDict, num_workers: int, local_worker: bool
) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
    assert config["input"] == "dataset", (
        "Must specify input as dataset if calling `get_dataset_and_shards`."
    )
    assert (
        "input_config" in config
    ), "Must specify input_config dict if using Dataset input."

    input_config = config["input_config"]

    format = input_config.get("format")
    assert format in ("json", "parquet"), (
        "Offline input data format must be parquet or json."
    )
    paths = input_config.get("paths")
    loader_fn = input_config.get("loader_fn")

    if loader_fn and (format or paths):
        raise ValueError(
            "When using a `loader_fn`, you cannot specify a `format` or `paths`."
        )

    if not (format and paths) and not loader_fn:
        raise ValueError(
            "Must specify format and paths, or a loader_fn via the input_config key"
            " when using Ray dataset input."
        )

    if not isinstance(paths, (list, str)):
        raise ValueError("Paths must be a list of path strings or a path string.")
    if isinstance(paths, str):
        paths = [paths]
    paths = _unzip_if_needed(paths, format)

    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    assert loader_fn or (format and paths), (
        f"If no `loader_fn` is given, both format: {format} and paths: {paths} "
        "must be specified. If format and paths are specified, a loader_fn must "
        "not be specified."
    )

    if loader_fn:
        dataset = loader_fn()
    elif format == "json":
        dataset = ray.data.read_json(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    elif format == "parquet":
        dataset = ray.data.read_parquet(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    else:
        raise ValueError(f"Unsupported Ray dataset format: {format}")

    # The local worker is responsible for sampling.
    if local_worker and num_workers == 0:
        # The dataset is the only shard we need.
        return dataset, [dataset]
    # Remote workers are responsible for sampling:
    else:
        # Each remote worker gets one shard.
        # The first None shard is for the local worker, which
        # shouldn't be doing rollout work anyway.
        return dataset, [None] + dataset.repartition(
            num_blocks=num_workers, shuffle=False
        ).split(num_workers)
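

# Usage sketch (illustrative; the path and worker count below are placeholders):
#
#     config = {
#         "input": "dataset",
#         "input_config": {"format": "json", "paths": "/tmp/sample_batches/"},
#     }
#     dataset, shards = get_dataset_and_shards(
#         config, num_workers=2, local_worker=True
#     )
#     # shards == [None, shard_for_worker_1, shard_for_worker_2]; the local
#     # worker gets None because the remote workers do the sampling.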


@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from a Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.Dataset recognizes.
                "paths": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            },
        }
    """

    @PublicAPI
    def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
        """Initializes a DatasetReader instance.

        Args:
            ioctx: The IOContext that provides the config for this reader.
            ds: The Ray dataset to sample from.
        """
        self._ioctx = ioctx
        self._default_policy = self.policy_map = None
        self._dataset = ds
        self.count = None if not self._dataset else self._dataset.count()
        # Do this to disable the Ray Data stdout logging.
        ray.data.set_progress_bars(enabled=False)

        # The number of rows to return per call to next().
        if self._ioctx:
            self.batch_size = self._ioctx.config.get("train_batch_size", 1)
            num_workers = self._ioctx.config.get("num_workers", 0)
            seed = self._ioctx.config.get("seed", None)
        if num_workers:
            # Each of the num_workers readers only needs to produce its share
            # of the train batch.
            self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful, for example, for a non-rollout local worker.
        if ds:
            if self._ioctx.worker is not None:
                self._policy_map = self._ioctx.worker.policy_map
                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
            self._dataset.random_shuffle(seed=seed)
            print(
                f"DatasetReader {self._ioctx.worker_index} has {ds.count()} samples."
            )
            # TODO: @avnishn make this call seeded.
            # Calling random_shuffle_each_window() re-shuffles the dataset each
            # time the whole dataset has been read.
            self._iter = (
                self._dataset.repeat().random_shuffle_each_window().iter_rows()
            )
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # next() should not get called on a None DatasetReader.
        assert self._iter is not None
        ret = []
        count = 0
        while count < self.batch_size:
            d = next(self._iter).as_pydict()
            # Columns like obs are compressed when written by the DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            count += d.count
            ret.append(self._postprocess_if_needed(d))
        ret = concat_samples(ret)
        return ret

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self._default_policy.postprocess_trajectory(sub_batch))
            return SampleBatch.concat_samples(out)
        else:
            # TODO(ekl): This is trickier since the alignments between agent
            # trajectories in the episode are no longer available.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )
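

# Usage sketch (illustrative; the dataset path and config values below are
# placeholders, and the exact IOContext constructor arguments may differ by
# Ray version):
#
#     import ray
#     from ray.rllib.offline.io_context import IOContext
#
#     ray.init()
#     ds = ray.data.read_json("/tmp/sample_batches/")
#     ioctx = IOContext(config={"train_batch_size": 256, "num_workers": 0})
#     reader = DatasetReader(ioctx, ds)
#     batch = reader.next()  # A SampleBatch with at least train_batch_size rows.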