from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.air._internal.json import SafeFallbackEncoder
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.typing import FileType, SampleBatchType
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
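# Note: a bare Windows path such as ``C:\some\dir`` parses with
# ``urlparse(...).scheme == "c"``, so single drive letters are accepted
# alongside the empty scheme (i.e. treated as local paths) in ``__init__``.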


# TODO(jungong) : use DatasetWriter to back JsonWriter, so we reduce
# codebase complexity without losing existing functionality.
@PublicAPI
class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    @PublicAPI
    def __init__(
        self,
        path: str,
        ioctx: IOContext = None,
        max_file_size: int = 64 * 1024 * 1024,
        compress_columns: List[str] = frozenset(["obs", "new_obs"]),
    ):
        """Initializes a JsonWriter instance.

        Args:
            path: a path/URI of the output directory to save files in.
            ioctx: current IO context object.
            max_file_size: max size of single files before rolling over.
            compress_columns: list of sample batch columns to compress.
        """
        logger.info(
            "You are using JsonWriter. It is recommended to use "
            + "DatasetWriter instead."
        )

        self.ioctx = ioctx or IOContext()
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            self.path_is_uri = True
        else:
            path = os.path.abspath(os.path.expanduser(path))
            # Try to create local dirs if they don't exist.
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.path = path
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug(
            "Wrote {} bytes to {} in {}s".format(len(data), f, time.time() - start)
        )

    def _get_file(self) -> FileType:
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path,
                "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index
                ),
            )
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path)
                    )
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file


def _to_jsonable(v, compress: bool) -> Any:
    if compress and compression_supported():
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json_dict(batch: SampleBatchType, compress_columns: List[str]) -> Dict:
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns
                )
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return out
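

# Each record emitted by JsonWriter is a single line of JSON built by
# `_to_json_dict` above: either
#   {"type": "SampleBatch", "<column>": <values>, ...}
# or
#   {"type": "MultiAgentBatch", "count": <n>, "policy_batches": {...}},
# with columns listed in `compress_columns` stored as packed strings instead
# of nested lists (when compression is available).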
def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    out = _to_json_dict(batch, compress_columns)
    return json.dumps(out, cls=SafeFallbackEncoder)
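

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module above): how a
# JsonWriter might be driven by hand. The output directory is a hypothetical
# example path; any writable local directory or URI works.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ray.rllib.policy.sample_batch import SampleBatch

    demo_batch = SampleBatch(
        {
            "obs": np.zeros((2, 4), dtype=np.float32),
            "new_obs": np.ones((2, 4), dtype=np.float32),
            "actions": np.array([0, 1]),
            "rewards": np.array([1.0, 0.5]),
            "dones": np.array([False, True]),
        }
    )
    writer = JsonWriter("/tmp/rllib-offline-demo")
    # Each call appends one JSON record (one line) to the current chunk file;
    # a new chunk is started once `max_file_size` bytes have been written.
    writer.write(demo_batch)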