ray/rllib/offline/dataset_writer.py
Kai Fricke 262d6121bb
[rllib] Fix error messages and example for dataset writer (#23419)
Currently the error message and example refer to a field type that is actually format.
2022-03-28 19:53:12 +01:00

82 lines
2.7 KiB
Python

import logging
import os
import time
from ray import data
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_writer import _to_json_dict
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType
from typing import Dict, List
logger = logging.getLogger(__name__)
@PublicAPI
class DatasetWriter(OutputWriter):
    """Writer object that saves experiences using Datasets.

    Each call to ``write()`` converts the given sample batch into a
    JSON-compatible dict and buffers it in memory. Once the number of
    buffered batches reaches ``max_num_samples_per_file``, the buffer is
    turned into a single-block dataset, written to ``path`` in the
    configured ``format`` and then cleared.
    """

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext = None,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a DatasetWriter instance.

        Examples:
            config = {
                "output": "dataset",
                "output_config": {
                    "format": "json",
                    "path": "/tmp/test_samples/",
                    "max_num_samples_per_file": 100000,
                }
            }

        Args:
            ioctx: current IO context object. If omitted, a default
                ``IOContext()`` is created; its ``output_config`` must still
                provide the required ``format`` and ``path`` keys.
            compress_columns: list of sample batch columns to compress.

        Raises:
            AssertionError: if ``format`` or ``path`` is missing from the
                output config.
        """
        self.ioctx = ioctx or IOContext()

        # Read the config from `self.ioctx`, not from the raw `ioctx`
        # argument: the argument may be None, in which case the fallback
        # context created above must be used.
        output_config: Dict = self.ioctx.output_config
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        # Normalize the target location: expand `~` and make it absolute.
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        # Number of buffered `write()` calls that triggers a flush.
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns

        # In-memory buffer of converted sample batches awaiting a flush.
        self.samples = []

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        """Buffers a sample batch and flushes the buffer once it is full.

        Args:
            sample_batch: the batch of experiences to record.

        Raises:
            ValueError: if a flush is due and the configured format is
                neither "json" nor "parquet".
        """
        start = time.time()

        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)

        # Todo: We should flush at the end of sampling even if this
        # condition was not reached.
        if len(self.samples) >= self.max_num_samples_per_file:
            # Collapse everything into one block before writing.
            ds = data.from_items(self.samples).repartition(
                num_blocks=1, shuffle=False
            )
            if self.format == "json":
                ds.write_json(self.path, try_create_dir=True)
            elif self.format == "parquet":
                ds.write_parquet(self.path, try_create_dir=True)
            else:
                # Single formatted message (not a tuple of args) so the
                # error text is readable and names the offending format.
                raise ValueError(f"Unknown output format: {self.format}")
            self.samples = []
            logger.debug("Wrote dataset in %ss", time.time() - start)