[RLlib, Offline] Make the dataset and json readers batchable (#26055)
Make the dataset and json readers batchable.
parent 5043cc1a82
commit 1f9282a496
5 changed files with 138 additions and 19 deletions
rllib/BUILD (15 lines changed)

@@ -1647,6 +1647,21 @@ py_test(
     srcs = ["offline/estimators/tests/test_ope.py"]
 )
+
+py_test(
+    name = "test_json_reader",
+    tags = ["team:rllib", "offline"],
+    size = "small",
+    srcs = ["offline/tests/test_json_reader.py"]
+)
+
+py_test(
+    name = "test_dataset_reader",
+    tags = ["team:rllib", "offline"],
+    size = "small",
+    srcs = ["offline/tests/test_dataset_reader.py"]
+)
+
 
 # --------------------------------------------------------------------
 # Policies
 # rllib/policy/
rllib/offline/dataset_reader.py

@@ -5,10 +5,12 @@ import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.policy.sample_batch import concat_samples
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
 from typing import List
 
 
 logger = logging.getLogger(__name__)
 
 DEFAULT_NUM_CPUS_PER_TASK = 0.5
@@ -113,6 +115,12 @@ class DatasetReader(InputReader):
         """
         self._ioctx = ioctx
         self._dataset = ds
+        # the number of rows to return per call to next()
+        if self._ioctx:
+            self.batch_size = ioctx.config.get("train_batch_size", 1)
+            num_workers = ioctx.config.get("num_workers", 0)
+            if num_workers:
+                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
         # We allow the creation of a non-functioning None DatasetReader.
         # It's useful for example for a non-rollout local worker.
         if ds:
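The formula above gives each rollout worker its share of the configured train batch: ceil(train_batch_size / num_workers) rows per next() call, never fewer than one. A quick worked example of that arithmetic (the config values here are illustrative, not from the patch):

import math

train_batch_size = 1200  # illustrative config value
num_workers = 4          # illustrative
batch_size = max(math.ceil(train_batch_size / num_workers), 1)
assert batch_size == 300  # each of the 4 workers reads 300 rows per call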
@@ -127,9 +135,13 @@
     def next(self) -> SampleBatchType:
         # next() should not get called on None DatasetReader.
         assert self._iter is not None
-        d = next(self._iter).as_pydict()
-        # Columns like obs are compressed when written by DatasetWriter.
-        d = from_json_data(d, self._ioctx.worker)
-        return d
+        ret = []
+        count = 0
+        while count < self.batch_size:
+            d = next(self._iter).as_pydict()
+            # Columns like obs are compressed when written by DatasetWriter.
+            d = from_json_data(d, self._ioctx.worker)
+            count += d.count
+            ret.append(d)
+        ret = concat_samples(ret)
+        return ret
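Note that next() appends whole sub-batches until the running count reaches batch_size, so a call can return slightly more than batch_size timesteps. A minimal sketch of that accumulation pattern (take_batch is a hypothetical stand-in for the reader's loop):

from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

def take_batch(it, batch_size):
    # Keep appending whole sub-batches until at least batch_size timesteps.
    ret, count = [], 0
    while count < batch_size:
        d = next(it)
        count += d.count
        ret.append(d)
    return concat_samples(ret)

chunks = iter([SampleBatch({"obs": [0, 1, 2]}) for _ in range(4)])
out = take_batch(chunks, 5)
assert out.count == 6  # overshoots 5: sub-batches arrive in units of 3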
rllib/offline/json_reader.py

@@ -1,6 +1,8 @@
 import glob
 import json
 import logging
+import math
+
 import numpy as np
 import os
 from pathlib import Path
@@ -23,6 +25,7 @@ from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
     MultiAgentBatch,
     SampleBatch,
+    concat_samples,
 )
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
 from ray.rllib.utils.compression import unpack_if_needed
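For reference, a minimal sketch of what the newly imported module-level concat_samples does with plain SampleBatches (values illustrative):

from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

b1 = SampleBatch({"obs": [1, 2]})
b2 = SampleBatch({"obs": [3]})
merged = concat_samples([b1, b2])
assert merged.count == 3  # counts add up; columns are concatenated in order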
@@ -145,6 +148,13 @@ class JsonReader(InputReader):
 
         self.ioctx = ioctx or IOContext()
         self.default_policy = self.policy_map = None
+        self.batch_size = 1
+        if self.ioctx:
+            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
+            num_workers = self.ioctx.config.get("num_workers", 0)
+            if num_workers:
+                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
+
         if self.ioctx.worker is not None:
             self.policy_map = self.ioctx.worker.policy_map
             self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
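Because batch_size defaults to 1 when the config carries no "train_batch_size", a JsonReader built with a bare IOContext keeps the old one-batch-per-call behavior. A hedged illustration of that fallback:

from ray.rllib.offline import IOContext

ioctx = IOContext(config={})
# Same lookup the constructor performs; falls back to 1.
batch_size = ioctx.config.get("train_batch_size", 1)
assert batch_size == 1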
@@ -180,20 +190,26 @@
 
     @override(InputReader)
     def next(self) -> SampleBatchType:
-        batch = self._try_parse(self._next_line())
-        tries = 0
-        while not batch and tries < 100:
-            tries += 1
-            logger.debug("Skipping empty line in {}".format(self.cur_file))
+        ret = []
+        count = 0
+        while count < self.batch_size:
             batch = self._try_parse(self._next_line())
-        if not batch:
-            raise ValueError(
-                "Failed to read valid experience batch from file: {}".format(
-                    self.cur_file
+            tries = 0
+            while not batch and tries < 100:
+                tries += 1
+                logger.debug("Skipping empty line in {}".format(self.cur_file))
+                batch = self._try_parse(self._next_line())
+            if not batch:
+                raise ValueError(
+                    "Failed to read valid experience batch from file: {}".format(
+                        self.cur_file
+                    )
                 )
-            )
-        return self._postprocess_if_needed(batch)
+            batch = self._postprocess_if_needed(batch)
+            count += batch.count
+            ret.append(batch)
+        ret = concat_samples(ret)
+        return ret
 
     def read_all_files(self) -> SampleBatchType:
         """Reads through all files and yields one SampleBatchType per line.
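Usage mirrors the new test further down; the returned batch holds at least batch_size timesteps and may overshoot when the last parsed line crosses the threshold. A sketch, with an illustrative input path:

from ray.rllib.offline import IOContext
from ray.rllib.offline.json_reader import JsonReader

ioctx = IOContext(config={"train_batch_size": 256, "num_workers": 0})
reader = JsonReader(["/tmp/episodes.json"], ioctx)  # path is illustrative
batch = reader.next()
assert batch.count >= reader.batch_size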
@@ -223,7 +239,7 @@
             out = []
             for sub_batch in batch.split_by_episode():
                 out.append(self.default_policy.postprocess_trajectory(sub_batch))
-            return SampleBatch.concat_samples(out)
+            return concat_samples(out)
         else:
             # TODO(ekl) this is trickier since the alignments between agent
             # trajectories in the episode are not available any more.
rllib/offline/tests/test_dataset_reader.py (new file, 40 lines)

@@ -0,0 +1,40 @@
+import os
+import unittest
+from pathlib import Path
+
+import ray
+from ray.rllib.offline import IOContext
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+
+
+class TestDatasetReader(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_itr_batches(self):
+        """Test that the dataset reader iterates over batches of rows correctly."""
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
+        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
+        input_config = {"format": "json", "path": data_file}
+        dataset, _ = get_dataset_and_shards(
+            {"input": "dataset", "input_config": input_config}, 0, True
+        )
+
+        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
+        reader = DatasetReader(ioctx, dataset)
+        assert len(reader.next()) == 1200
+
+
+if __name__ == "__main__":
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main(["-v", __file__]))
rllib/offline/tests/test_json_reader.py (new file, 36 lines)

@@ -0,0 +1,36 @@
+import os
+import unittest
+from pathlib import Path
+
+import ray
+from ray.rllib.offline import IOContext
+from ray.rllib.offline.json_reader import JsonReader
+
+
+class TestJsonReader(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_itr_batches(self):
+        """Test that the json reader iterates over batches of rows correctly."""
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
+        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
+
+        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
+        reader = JsonReader([data_file], ioctx)
+        assert len(reader.next()) == 1200
+
+
+if __name__ == "__main__":
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main(["-v", __file__]))