import logging
import os
import time

from ray import data
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_writer import _to_json_dict
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType
from typing import Dict, List

logger = logging.getLogger(__name__)


@PublicAPI
class DatasetWriter(OutputWriter):
    """Writer object that saves experiences using Datasets."""

    @PublicAPI
    def __init__(
        self,
        ioctx: IOContext = None,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a DatasetWriter instance.

        Examples:
            config = {
                "output": "dataset",
                "output_config": {
                    "format": "json",
                    "path": "/tmp/test_samples/",
                    "max_num_samples_per_file": 100000,
                }
            }
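
            # Hypothetical direct construction, shown for illustration only:
            # RLlib normally creates the writer itself when ``output`` is set
            # to "dataset", and the ``config`` keyword of IOContext is an
            # assumption of this sketch.
            ioctx = IOContext(config=config)
            writer = DatasetWriter(ioctx)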

        Args:
            ioctx: Current IO context object.
            compress_columns: List of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()

        # Read the output settings from the resolved context, not from the
        # possibly-None ``ioctx`` argument.
        output_config: Dict = self.ioctx.output_config
        assert (
            "format" in output_config
        ), "output_config.format must be specified when using Dataset output."
        assert (
            "path" in output_config
        ), "output_config.path must be specified when using Dataset output."

        self.format = output_config["format"]
        self.path = os.path.abspath(os.path.expanduser(output_config["path"]))
        self.max_num_samples_per_file = output_config.get(
            "max_num_samples_per_file", 100000
        )
        self.compress_columns = compress_columns

        # In-memory buffer of sample dicts, flushed to disk by `write()`.
        self.samples = []

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()

        # Make sure columns like obs are compressed and writable.
        d = _to_json_dict(sample_batch, self.compress_columns)
        self.samples.append(d)

        # TODO: We should flush at the end of sampling, even if this
        #  condition was not reached.
        if len(self.samples) >= self.max_num_samples_per_file:
            # Repartition into a single block so each flush emits one file.
            ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False)
            if self.format == "json":
                ds.write_json(self.path, try_create_dir=True)
            elif self.format == "parquet":
                ds.write_parquet(self.path, try_create_dir=True)
            else:
                raise ValueError(f"Unknown output type: {self.format}")
            self.samples = []
        logger.debug("Wrote dataset in {}s".format(time.time() - start))
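

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # IOContext accepts a ``config`` dict that backs its ``output_config``
    # attribute, and that SampleBatch is importable from
    # ray.rllib.policy.sample_batch; adjust both to your Ray version.
    import numpy as np

    from ray.rllib.policy.sample_batch import SampleBatch

    ioctx = IOContext(
        config={
            "output_config": {
                "format": "json",
                "path": "/tmp/test_samples/",
                # Flush after every batch so the write is observable.
                "max_num_samples_per_file": 1,
            }
        }
    )
    writer = DatasetWriter(ioctx)
    writer.write(SampleBatch({"obs": np.zeros((2, 4)), "new_obs": np.zeros((2, 4))}))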