ray/rllib/offline/dataset_writer.py

import logging
import os
import time
from typing import Dict, List, Optional

from ray import data
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_writer import _to_json_dict
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)


@PublicAPI
class DatasetWriter(OutputWriter):
"""Writer object that saves experiences using Datasets."""
@PublicAPI
def __init__(
self,
ioctx: IOContext = None,
compress_columns: List[str] = frozenset(["obs", "new_obs"]),
):
"""Initializes a DatasetWriter instance.
Examples:
config = {
"output": "dataset",
"output_config": {
"format": "json",
"path": "/tmp/test_samples/",
"max_num_samples_per_file": 100000,
}
}
Args:
ioctx: current IO context object.
compress_columns: list of sample batch columns to compress.
"""
        self.ioctx = ioctx or IOContext()
        # Read the output config via `self.ioctx` (never the raw `ioctx`
        # argument, which may be None).
        output_config: Dict = self.ioctx.output_config
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns
        self.samples = []
    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)
        # TODO: Also flush any remaining samples at the end of sampling, even
        #  if this threshold has not been reached yet.
        if len(self.samples) >= self.max_num_samples_per_file:
            # Repartition into a single block so each flush is written out as
            # one file.
            ds = data.from_items(self.samples).repartition(
                num_blocks=1, shuffle=False
            )
            if self.format == "json":
                ds.write_json(self.path, try_create_dir=True)
            elif self.format == "parquet":
                ds.write_parquet(self.path, try_create_dir=True)
            else:
                raise ValueError(f"Unknown output format: {self.format}")
            self.samples = []
            logger.debug("Wrote dataset in {}s".format(time.time() - start))
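

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: build a
    # DatasetWriter from an IOContext and write enough batches to trigger a
    # flush. Assumption (hedged): on this Ray version, IOContext accepts a
    # plain config dict and exposes its "output_config" entry via the
    # `IOContext.output_config` property; adjust for your Ray version.
    from ray.rllib.policy.sample_batch import SampleBatch

    example_config = {
        "output": "dataset",
        "output_config": {
            "format": "json",
            "path": "/tmp/test_samples/",
            # Flush after a single batch so this example writes immediately.
            "max_num_samples_per_file": 1,
        },
    }
    writer = DatasetWriter(ioctx=IOContext(config=example_config))
    writer.write(SampleBatch({"obs": [0, 1], "new_obs": [1, 2], "actions": [0, 1]}))

    # Read the flushed file(s) back to verify the round trip.
    print(data.read_json("/tmp/test_samples/").take(1))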