from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time

# `smart_open` is optional; it is only needed when writing to remote URIs.
try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.types import FileType, SampleBatchType
from typing import Any, List

logger = logging.getLogger(__name__)


@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(self,
                 path: str,
                 ioctx: IOContext = None,
                 max_file_size: int = 64 * 1024 * 1024,
                 compress_columns: List[str] = frozenset(["obs", "new_obs"])):
        """Initialize a JsonWriter.

        Arguments:
            path (str): Path/URI of the output directory to save files in.
            ioctx (IOContext): Current IO context object.
            max_file_size (int): Max size of single files before rolling over.
            compress_columns (list): List of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        # Any non-empty scheme (other than a Windows "c:" drive letter, which
        # urlparse reports as a scheme) is treated as a remote URI.
        if urlparse(path).scheme not in ["", "c"]:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist.
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f,
            time.time() - start))

    def _get_file(self) -> FileType:
        # Open a new output file on first use, and roll over to a fresh file
        # once the current one reaches the size limit.
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.data.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns)
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.data.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return json.dumps(out)
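

# A minimal usage sketch (not part of the mirrored module): write one small
# SampleBatch to a local directory. The SampleBatch import path and the
# dict-style constructor are assumptions based on the RLlib version mirrored
# here; the output directory name is arbitrary.
if __name__ == "__main__":
    from ray.rllib.policy.sample_batch import SampleBatch

    writer = JsonWriter("/tmp/json-writer-demo")
    batch = SampleBatch({
        "obs": np.zeros((2, 4)),
        "actions": np.array([0, 1]),
        "rewards": np.array([0.5, 1.0]),
    })
    # Appends one JSON line to the current chunk file; a new file is started
    # automatically once max_file_size bytes have been written.
    writer.write(batch)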