import glob
import json
import logging
import math

import numpy as np
import os
from pathlib import Path
import random
import re
import tree  # pip install dm_tree
from typing import List, Optional, TYPE_CHECKING, Union
from urllib.parse import urlparse
import zipfile

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
    concat_samples,
)
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.spaces.space_utils import clip_action, normalize_action
from ray.rllib.utils.typing import Any, FileType, SampleBatchType

if TYPE_CHECKING:
    from ray.rllib.evaluation import RolloutWorker

logger = logging.getLogger(__name__)

WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
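# Note: On Windows, an absolute local path such as "c:/data/file.json" is
# parsed by `urlparse()` as having scheme "c". Single drive letters are thus
# allowed alongside the empty scheme when deciding below whether a given
# input is a plain local file or a URI that requires `smart_open`.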


def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    """Handle nested action/observation spaces for policies.

    Translates nested lists/dicts from the json into proper np.ndarrays,
    according to the (nested) observation- and action-spaces of the given
    policy.

    Providing nested lists w/o this preprocessing step would confuse a
    SampleBatch constructor.
    """
    for k, v in json_data.items():
        data_col = (
            policy.view_requirements[k].data_col
            if k in policy.view_requirements
            else ""
        )
        if policy.config.get("_disable_action_flattening") and (
            k == SampleBatch.ACTIONS
            or data_col == SampleBatch.ACTIONS
            or k == SampleBatch.PREV_ACTIONS
            or data_col == SampleBatch.PREV_ACTIONS
        ):
            json_data[k] = tree.map_structure_up_to(
                policy.action_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
        elif policy.config.get("_disable_preprocessor_api") and (
            k == SampleBatch.OBS
            or data_col == SampleBatch.OBS
            or k == SampleBatch.NEXT_OBS
            or data_col == SampleBatch.NEXT_OBS
        ):
            json_data[k] = tree.map_structure_up_to(
                policy.observation_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
    return json_data


@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
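    """Clips and/or un-normalizes the actions in `batch`, based on `ioctx.config`.

    If `clip_actions` is True in the config, all action values are clipped
    into their policy's action space bounds. If the actions were recorded
    un-normalized (`actions_in_input_normalized=False`) while
    `normalize_actions=True`, they are re-normalized from the env's bounds
    into the zero-centered range the policy expects.
    """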
    # Clip actions (from any values into env's bounds), if necessary.
    cfg = ioctx.config
    # TODO(jungong): We should not clip_action in the input reader.
    #  Use a connector to handle this.
    if cfg.get("clip_actions"):
        if ioctx.worker is None:
            raise ValueError(
                "clip_actions is True but cannot clip actions since no workers exist"
            )

        if isinstance(batch, SampleBatch):
            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            batch[SampleBatch.ACTIONS] = clip_action(
                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                b[SampleBatch.ACTIONS] = clip_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )
    # Re-normalize actions (from env's bounds to zero-centered), if necessary.
    if (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    ):
        if ioctx.worker is None:
            raise ValueError(
                "actions_in_input_normalized is False but "
                "cannot normalize actions since no workers exist"
            )

        # If we have a complex action space and actions were flattened
        # and we have to normalize -> Error.
        error_msg = (
            "Normalization of offline actions that are flattened is not "
            "supported! Make sure that you record actions into the offline "
            "file with the `_disable_action_flattening=True` flag OR "
            "as already normalized (between -1.0 and 1.0) values. "
            "Also, when reading already normalized action values from "
            "offline files, make sure to set "
            "`actions_in_input_normalized=True` so that RLlib will not "
            "perform normalization on top."
        )

        if isinstance(batch, SampleBatch):
            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            if isinstance(
                pol.action_space_struct, (tuple, dict)
            ) and not pol.config.get("_disable_action_flattening"):
                raise ValueError(error_msg)
            batch[SampleBatch.ACTIONS] = normalize_action(
                batch[SampleBatch.ACTIONS], pol.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                pol = ioctx.worker.policy_map[pid]
                if isinstance(
                    pol.action_space_struct, (tuple, dict)
                ) and not pol.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)
                b[SampleBatch.ACTIONS] = normalize_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )

    return batch


@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
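    """Builds a SampleBatch or MultiAgentBatch from one decoded json record.

    The record's "type" field determines which batch type is constructed.
    Compressed columns are unpacked and, if a worker is given, nested
    obs/action structures are converted to np.ndarrays according to the
    respective policy's spaces.
    """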
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" in json_data:
        data_type = json_data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!"
            )
        for k, v in json_data.items():
            json_data[k] = unpack_if_needed(v)
        if worker is not None:
            policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, policy)
        return SampleBatch(json_data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in json_data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                inner[k] = unpack_if_needed(v)
            if worker is not None:
                policy = worker.policy_map[policy_id]
                inner = _adjust_obs_actions_for_policy(inner, policy)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, json_data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']", data_type
        )


# TODO(jungong): Use DatasetReader to back JsonReader, so we reduce
#  codebase complexity without losing existing functionality.
@PublicAPI
class JsonReader(InputReader):
    """Reader object that loads experiences from JSON file chunks.

    The input files are read in random order.
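
    A minimal usage sketch (the path below is hypothetical; within RLlib,
    readers are usually constructed for you by setting the `input` config
    option rather than instantiated directly):

    .. code-block:: python

        reader = JsonReader("/tmp/demo-out/*.json")
        # Pull one train batch (rows are concatenated up to the configured
        # batch size).
        batch = reader.next()
        # Or iterate over every stored row exactly once.
        for b in reader.read_all_files():
            print(b.count)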
    """

    @PublicAPI
    def __init__(
        self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
    ):
        """Initializes a JsonReader instance.

        Args:
            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
                or a list of single file paths or URIs, e.g.,
                ["s3://bucket/file.json", "s3://bucket/file2.json"].
            ioctx: Current IO context object or None.
        """
        logger.info(
            "You are using JsonReader. It is recommended to use "
            "DatasetReader instead for better sharding support."
        )

        self.ioctx = ioctx or IOContext()
        self.default_policy = self.policy_map = None
        self.batch_size = 1
        if self.ioctx:
            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
            num_workers = self.ioctx.config.get("num_workers", 0)
            if num_workers:
                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
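                # E.g., train_batch_size=4000 split across num_workers=3
                # gives each worker a read batch_size of ceil(4000/3) = 1334.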

            if self.ioctx.worker is not None:
                self.policy_map = self.ioctx.worker.policy_map
                self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)

        if isinstance(inputs, str):
            inputs = os.path.abspath(os.path.expanduser(inputs))
            if os.path.isdir(inputs):
                inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
                logger.warning(f"Treating input directory as glob patterns: {inputs}")
            else:
                inputs = [inputs]

            if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
                raise ValueError(
                    "Don't know how to glob over `{}`, ".format(inputs)
                    + "please specify a list of files to read instead."
                )
            else:
                self.files = []
                for i in inputs:
                    self.files.extend(glob.glob(i))
        elif isinstance(inputs, (list, tuple)):
            self.files = list(inputs)
        else:
            raise ValueError(
                "type of inputs must be list or str, not {}".format(type(inputs))
            )
        if self.files:
            logger.info("Found {} input files.".format(len(self.files)))
        else:
            raise ValueError("No files found matching {}".format(inputs))
        self.cur_file = None

    @override(InputReader)
    def next(self) -> SampleBatchType:
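        """Returns the next batch of experiences read from the input files.

        Keeps reading and parsing lines (skipping empty or corrupt records)
        until at least `self.batch_size` timesteps have been collected, then
        returns them concatenated into a single batch.
        """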
        ret = []
        count = 0
        while count < self.batch_size:
            batch = self._try_parse(self._next_line())
            tries = 0
            while not batch and tries < 100:
                tries += 1
                logger.debug("Skipping empty line in {}".format(self.cur_file))
                batch = self._try_parse(self._next_line())
            if not batch:
                raise ValueError(
                    "Failed to read valid experience batch from file: {}".format(
                        self.cur_file
                    )
                )
            batch = self._postprocess_if_needed(batch)
            count += batch.count
            ret.append(batch)
        ret = concat_samples(ret)
        return ret

    def read_all_files(self) -> SampleBatchType:
        """Reads through all input files once, in order.

        Unlike `next()`, this method does not wrap around to the beginning
        once the end of the last file is reached.

        Yields:
            One SampleBatch or MultiAgentBatch per line in all input files.
        """
        for path in self.files:
            file = self._try_open_file(path)
            while True:
                line = file.readline()
                if not line:
                    break
                batch = self._try_parse(line)
                if batch is None:
                    break
                yield batch

    def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
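        """Runs `Policy.postprocess_trajectory()` on each episode in `batch`.

        Only active if `postprocess_inputs=True` in the IO context's config;
        otherwise, `batch` is returned unchanged.
        """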
        if not self.ioctx.config.get("postprocess_inputs"):
            return batch

        if isinstance(batch, SampleBatch):
            out = []
            for sub_batch in batch.split_by_episode():
                out.append(self.default_policy.postprocess_trajectory(sub_batch))
            return concat_samples(out)
        else:
            # TODO(ekl): This is trickier since the alignments between agent
            #  trajectories in the episode are not available anymore.
            raise NotImplementedError(
                "Postprocessing of multi-agent data not implemented yet."
            )

    def _try_open_file(self, path):
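        """Opens a single data file, local or remote.

        Remote URIs are opened via `smart_open` (if installed). Local paths
        may start with "~/" or be relative to the rllib directory; local
        .zip files are extracted first, and the contained .json file of the
        same name is opened instead.
        """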
        if urlparse(path).scheme not in [""] + WINDOWS_DRIVES:
            if smart_open is None:
                raise ValueError(
                    "You must install the `smart_open` module to read "
                    "from URIs like {}".format(path)
                )
            ctx = smart_open
        else:
            # Allow shortcut for home directory ("~/" -> env[HOME]).
            if path.startswith("~/"):
                path = os.path.join(os.environ.get("HOME", ""), path[2:])

            # If path doesn't exist, try to interpret it as relative to the
            # rllib directory (located ../../ from this very module).
            path_orig = path
            if not os.path.exists(path):
                path = os.path.join(Path(__file__).parent.parent, path)
                if not os.path.exists(path):
                    raise FileNotFoundError(f"Offline file {path_orig} not found!")

            # Unzip files, if necessary, and re-point to the extracted json file.
            if re.search("\\.zip$", path):
                with zipfile.ZipFile(path, "r") as zip_ref:
                    zip_ref.extractall(Path(path).parent)
                path = re.sub("\\.zip$", ".json", path)
                assert os.path.exists(path)
            ctx = open
        file = ctx(path, "r")
        return file

    def _try_parse(self, line: str) -> Optional[SampleBatchType]:
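        """Parses one json line into a batch, or returns None if not possible.

        Corrupt records are logged and skipped (instead of raising), so one
        bad line does not abort reading. Valid batches are run through
        `postprocess_actions()` before being returned.
        """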
        line = line.strip()
        if not line:
            return None
        try:
            batch = self._from_json(line)
        except Exception:
            logger.exception(
                "Ignoring corrupt json record in {}: {}".format(self.cur_file, line)
            )
            return None

        batch = postprocess_actions(batch, self.ioctx)

        return batch

    def _next_line(self) -> str:
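        """Returns the next line read from the current (or a further) file.

        Advances to new (randomly picked) files whenever the current file is
        exhausted. Raises a ValueError after 100 consecutive files have
        yielded no line to read.
        """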
        if not self.cur_file:
            self.cur_file = self._next_file()
        line = self.cur_file.readline()
        tries = 0
        while not line and tries < 100:
            tries += 1
            if hasattr(self.cur_file, "close"):  # legacy smart_open impls
                self.cur_file.close()
            self.cur_file = self._next_file()
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file {}".format(self.cur_file))
        if not line:
            raise ValueError(
                "Failed to read next line from files: {}".format(self.files)
            )
        return line

    def _next_file(self) -> FileType:
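        """Opens and returns the next input file to read from.

        The very first file is picked deterministically based on this
        worker's index, so parallel workers start out on different files;
        all subsequent files are picked uniformly at random.
        """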
        # If this is the first time we open a file, make sure all workers
        # start with a different one if possible.
        if self.cur_file is None and self.ioctx.worker is not None:
            idx = self.ioctx.worker.worker_index
            total = self.ioctx.worker.num_workers or 1
            path = self.files[round((len(self.files) - 1) * (idx / total))]
        # After the first file, pick all others randomly.
        else:
            path = random.choice(self.files)
        return self._try_open_file(path)

    def _from_json(self, data: str) -> SampleBatchType:
        if isinstance(data, bytes):  # smart_open S3 doesn't respect "r"
            data = data.decode("utf-8")
        json_data = json.loads(data)
        return from_json_data(json_data, self.ioctx.worker)