# ray/rllib/offline/json_writer.py

from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time
try:
    from smart_open import smart_open
except ImportError:
    smart_open = None
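# smart_open is optional: it is only needed when writing to remote URIs
# (e.g. s3:// buckets); plain local-directory output works without it.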
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.typing import FileType, SampleBatchType
from typing import Any, List
logger = logging.getLogger(__name__)
WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
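# Drive letters are listed in WINDOWS_DRIVES because urlparse() reports a
# Windows drive prefix such as "c:" as the URL scheme (lower-cased); treating
# those schemes as local paths keeps Windows-style paths from being mistaken
# for URIs in JsonWriter.__init__ below.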


@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(self,
                 path: str,
                 ioctx: IOContext = None,
                 max_file_size: int = 64 * 1024 * 1024,
                 compress_columns: List[str] = frozenset(["obs", "new_obs"])):
"""Initialize a JsonWriter.
        Args:
            path (str): a path/URI of the output directory to save files in.
            ioctx (IOContext): current IO context object.
            max_file_size (int): max size of single files before rolling over.
            compress_columns (list): list of sample batch columns to compress.
        """
        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist.
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f,
            time.time() - start))
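
    # Note: each write() call above emits exactly one JSON object followed by
    # a newline, so output files are in JSON Lines format and can be read back
    # line by line (e.g. by ray.rllib.offline.JsonReader).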

    def _get_file(self) -> FileType:
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file
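
    # For illustration only: a file created by _get_file() might be named
    # "output-2021-03-01_12-00-00_worker-0_3.json"; the timestamp, worker
    # index, and file index are runtime values, so actual names will differ.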


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v
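
# Note on compression: when compression is available, pack() (from
# ray.rllib.utils.compression) returns, roughly speaking, an LZ4-compressed,
# base64-encoded blob, so compressed columns appear as opaque strings in the
# JSON output rather than as nested lists.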


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.data.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns)
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return json.dumps(out)
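
# Illustrative output shapes (keys depend on the actual batch contents and on
# compress_columns; the values shown here are made up):
#   {"type": "SampleBatch", "obs": "<packed>", "actions": [0, 1, 0], ...}
#   {"type": "MultiAgentBatch", "count": 200, "policy_batches": {...}}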