# ray/rllib/offline/dataset_reader.py

import logging
import math
from pathlib import Path
import re
from typing import List, Tuple
import zipfile

import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.policy.sample_batch import (
    concat_samples,
    SampleBatch,
    DEFAULT_POLICY_ID,
)
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict

logger = logging.getLogger(__name__)

DEFAULT_NUM_CPUS_PER_TASK = 0.5


def _get_resource_bundles(config: AlgorithmConfigDict):
    input_config = config.get("input_config", {})
    parallelism = input_config.get("parallelism", config.get("num_workers", 1))
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )
    return [{"CPU": math.ceil(parallelism * cpus_per_task)}]
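
# Hypothetical example: with input_config {"parallelism": 3} and the default of
# 0.5 CPUs per read task, _get_resource_bundles() reserves
# [{"CPU": math.ceil(3 * 0.5)}], i.e. one bundle with 2 CPUs.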


def _unzip_if_needed(paths: List[str], format: str):
    """If a path in `paths` is a zip file, unzip it and use the path of the
    unzipped file instead."""
    ret = []
    for path in paths:
        fpath = Path(path).absolute()
        if not fpath.exists():
            fpath = Path(__file__).parent.parent / path
            if not fpath.exists():
                raise FileNotFoundError(f"File not found: {path}")
        if re.search(r"\.zip$", str(fpath)):
            with zipfile.ZipFile(str(fpath), "r") as zip_ref:
                zip_ref.extractall(str(fpath.parent))
            fpath = re.sub(r"\.zip$", f".{format}", str(fpath))
        fpath = str(fpath)
        ret.append(fpath)
    return ret
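
# Hypothetical example: given the path "/tmp/batches/data.zip" and format "json",
# _unzip_if_needed() extracts the archive into "/tmp/batches/" and returns
# ["/tmp/batches/data.json"] (assuming the archive contains a file of that name).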


@PublicAPI
def get_dataset_and_shards(
    config: AlgorithmConfigDict, num_workers: int, local_worker: bool
) -> Tuple[ray.data.dataset.Dataset, List[ray.data.dataset.Dataset]]:
    assert (
        config["input"] == "dataset"
    ), "Must specify input as dataset if calling `get_dataset_and_shards`."
    assert (
        "input_config" in config
    ), "Must specify input_config dict if using Dataset input."
    input_config = config["input_config"]

    format = input_config.get("format")
    paths = input_config.get("paths")
    loader_fn = input_config.get("loader_fn")

    if loader_fn and (format or paths):
        raise ValueError(
            "When using a `loader_fn`, you cannot specify a `format` or `paths`."
        )
    if not (format and paths) and not loader_fn:
        raise ValueError(
            "Must specify `format` and `paths`, or a `loader_fn`, via the "
            "`input_config` key when using Ray dataset input."
        )
    assert loader_fn or (format and paths), (
        f"Either a `loader_fn` ({loader_fn}) that constructs a dataset, or both "
        f"`format` ({format}) and `paths` ({paths}), must be specified (but not "
        "both a `loader_fn` and `format`/`paths`)."
    )

    # `format`/`paths` are only validated and pre-processed if no `loader_fn`
    # is given (a `loader_fn` constructs the dataset itself).
    if not loader_fn:
        assert format in ("json", "parquet"), (
            "Offline input data format must be parquet or json."
        )
        if not isinstance(paths, (list, str)):
            raise ValueError(
                "Paths must be a list of path strings or a path string."
            )
        if isinstance(paths, str):
            paths = [paths]
        paths = _unzip_if_needed(paths, format)

    parallelism = input_config.get("parallelism", num_workers or 1)
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )

    if loader_fn:
        dataset = loader_fn()
    elif format == "json":
        dataset = ray.data.read_json(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    elif format == "parquet":
        dataset = ray.data.read_parquet(
            paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task}
        )
    else:
        raise ValueError(f"Unsupported Ray dataset format: {format}")

    # Local worker will be responsible for sampling.
    if local_worker and num_workers == 0:
        # Dataset is the only shard we need.
        return dataset, [dataset]
    # Remote workers are responsible for sampling:
    else:
        # Each remote worker gets 1 shard.
        # The first None shard is for the local worker, which
        # shouldn't be doing rollout work anyways.
        return dataset, [None] + dataset.repartition(
            num_blocks=num_workers, shuffle=False
        ).split(num_workers)
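
# Usage sketch (hypothetical values; for illustration only):
#
#     config = {
#         "input": "dataset",
#         "input_config": {"format": "json", "paths": "/tmp/sample_batches/"},
#     }
#     dataset, shards = get_dataset_and_shards(config, num_workers=2, local_worker=True)
#     # len(shards) == 3: shards[0] is None (the local worker does no rollouts);
#     # shards[1] and shards[2] are the two splits handed to the remote workers.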


@PublicAPI
class DatasetReader(InputReader):
    """Reader object that loads data from a Ray Dataset.

    Examples:
        config = {
            "input": "dataset",
            "input_config": {
                "format": "json",
                # A single data file, a directory, or anything
                # that ray.data.dataset recognizes.
                "paths": "/tmp/sample_batches/",
                # By default, parallelism=num_workers.
                "parallelism": 3,
                # Dataset allocates 0.5 CPU for each reader by default.
                # Adjust this value based on the size of your offline dataset.
                "num_cpus_per_read_task": 0.5,
            },
        }
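
        A reader is then typically constructed from one of the shards returned
        by `get_dataset_and_shards()` (sketch; `ioctx` is the `IOContext`
        RLlib passes to input creators):
            _, shards = get_dataset_and_shards(
                config, num_workers=0, local_worker=True
            )
            reader = DatasetReader(ioctx, shards[0])
            batch = reader.next()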
"""

    @PublicAPI
    def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
        """Initializes a DatasetReader instance.

        Args:
            ioctx: The IOContext to use.
            ds: Ray dataset to sample from.
        """
        self._ioctx = ioctx
        self._default_policy = self._policy_map = None
        self._dataset = ds
        self.count = None if not self._dataset else self._dataset.count()
        # Do this to disable the Ray Data stdout logging.
        ray.data.set_progress_bars(enabled=False)

        # The number of rows to return per call to next().
        if self._ioctx:
            self.batch_size = self._ioctx.config.get("train_batch_size", 1)
            num_workers = self._ioctx.config.get("num_workers", 0)
            seed = self._ioctx.config.get("seed", None)
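            # Each rollout worker only reads its share of the train batch, so
            # divide `train_batch_size` by the number of remote workers (but
            # return at least 1 row per call).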
            if num_workers:
                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
        # We allow the creation of a non-functioning None DatasetReader.
        # It's useful, for example, for a non-rollout local worker.
        if ds:
            if self._ioctx.worker is not None:
                self._policy_map = self._ioctx.worker.policy_map
                self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID)
            # `random_shuffle()` returns a new (shuffled) dataset, so re-assign.
            self._dataset = self._dataset.random_shuffle(seed=seed)
            print(
                f"DatasetReader {self._ioctx.worker_index} has {ds.count()} samples."
            )
            # TODO: @avnishn make this call seeded.
            # Calling random_shuffle_each_window() re-shuffles the dataset each
            # time the whole dataset has been read.
            self._iter = (
                self._dataset.repeat().random_shuffle_each_window().iter_rows()
            )
        else:
            self._iter = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # next() should not get called on None DatasetReader.
        assert self._iter is not None
        ret = []
        count = 0
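        # Gather rows until at least `batch_size` timesteps have been collected;
        # the last appended batch may make the result slightly larger than
        # `batch_size`.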
        while count < self.batch_size:
            d = next(self._iter).as_pydict()
            # Columns like obs are compressed when written by DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            count += d.count
            ret.append(self._postprocess_if_needed(d))
        ret = concat_samples(ret)
        return ret

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
        if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self._default_policy.postprocess_trajectory(sub_batch))
            return concat_samples(out)
        else:
            # TODO(ekl) this is trickier since the alignments between agent
            # trajectories in the episode are not available any more.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )