[RLlib, Offline] Make the dataset and json readers batchable (#26055)

Make the dataset and json readers batchable.
Avnish Narayan 2022-06-29 14:52:40 -04:00 committed by GitHub
parent 5043cc1a82
commit 1f9282a496
5 changed files with 138 additions and 19 deletions
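
Both readers now size their output from the algorithm config: each call to next() accumulates rows until it has at least train_batch_size timesteps, split evenly across rollout workers. A minimal sketch of the sizing rule shared by the two constructors below (values are illustrative, not from the commit):

import math

# Per-worker batch size, as computed in DatasetReader and JsonReader from
# ioctx.config.get("train_batch_size", 1) and ioctx.config.get("num_workers", 0).
train_batch_size = 1200
num_workers = 4

batch_size = train_batch_size
if num_workers:
    batch_size = max(math.ceil(batch_size / num_workers), 1)

print(batch_size)  # 300 timesteps per worker per next() call

With num_workers left at 0, each reader falls back to full train_batch_size batches, which is what the new tests in this commit assert.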

rllib/BUILD

@@ -1647,6 +1647,21 @@ py_test(
     srcs = ["offline/estimators/tests/test_ope.py"]
 )
 
+py_test(
+    name = "test_json_reader",
+    tags = ["team:rllib", "offline"],
+    size = "small",
+    srcs = ["offline/tests/test_json_reader.py"]
+)
+
+py_test(
+    name = "test_dataset_reader",
+    tags = ["team:rllib", "offline"],
+    size = "small",
+    srcs = ["offline/tests/test_dataset_reader.py"]
+)
+
 # --------------------------------------------------------------------
 # Policies
 # rllib/policy/

rllib/offline/dataset_reader.py

@@ -5,10 +5,12 @@ import ray.data
+import math
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
 from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.policy.sample_batch import concat_samples
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
 from typing import List
 
 logger = logging.getLogger(__name__)
 
 DEFAULT_NUM_CPUS_PER_TASK = 0.5
@@ -113,6 +115,12 @@ class DatasetReader(InputReader):
         """
         self._ioctx = ioctx
         self._dataset = ds
+        # the number of rows to return per call to next()
+        if self._ioctx:
+            self.batch_size = ioctx.config.get("train_batch_size", 1)
+            num_workers = ioctx.config.get("num_workers", 0)
+            if num_workers:
+                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
         # We allow the creation of a non-functioning None DatasetReader.
         # It's useful for example for a non-rollout local worker.
         if ds:
@@ -127,9 +135,13 @@ class DatasetReader(InputReader):
     def next(self) -> SampleBatchType:
         # next() should not get called on None DatasetReader.
         assert self._iter is not None
-
-        d = next(self._iter).as_pydict()
-        # Columns like obs are compressed when written by DatasetWriter.
-        d = from_json_data(d, self._ioctx.worker)
-
-        return d
+        ret = []
+        count = 0
+        while count < self.batch_size:
+            d = next(self._iter).as_pydict()
+            # Columns like obs are compressed when written by DatasetWriter.
+            d = from_json_data(d, self._ioctx.worker)
+            count += d.count
+            ret.append(d)
+        ret = concat_samples(ret)
+        return ret
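
With the two hunks above, a DatasetReader constructed from an IOContext returns one concatenated SampleBatch per next() call instead of a single row. A hedged usage sketch (the data path is illustrative; the positional arguments to get_dataset_and_shards follow the new test added later in this commit):

import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards

ray.init()

# Any RLlib-written JSON file works here; this path is illustrative.
input_config = {"format": "json", "path": "rllib/tests/data/pendulum/large.json"}
dataset, _ = get_dataset_and_shards(
    {"input": "dataset", "input_config": input_config}, 0, True
)

ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
reader = DatasetReader(ioctx, dataset)

# next() loops until at least batch_size timesteps are collected, so the
# returned batch may slightly overshoot the requested size.
batch = reader.next()
assert batch.count >= 1200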

rllib/offline/json_reader.py

@@ -1,6 +1,8 @@
 import glob
 import json
 import logging
+import math
+
 import numpy as np
 import os
 from pathlib import Path
@@ -23,6 +25,7 @@ from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
     MultiAgentBatch,
     SampleBatch,
+    concat_samples,
 )
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
 from ray.rllib.utils.compression import unpack_if_needed
@@ -145,6 +148,13 @@ class JsonReader(InputReader):
         self.ioctx = ioctx or IOContext()
         self.default_policy = self.policy_map = None
 
+        self.batch_size = 1
+        if self.ioctx:
+            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
+            num_workers = self.ioctx.config.get("num_workers", 0)
+            if num_workers:
+                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
+
         if self.ioctx.worker is not None:
             self.policy_map = self.ioctx.worker.policy_map
             self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
@@ -180,20 +190,26 @@
     @override(InputReader)
     def next(self) -> SampleBatchType:
-        batch = self._try_parse(self._next_line())
-        tries = 0
-        while not batch and tries < 100:
-            tries += 1
-            logger.debug("Skipping empty line in {}".format(self.cur_file))
-            batch = self._try_parse(self._next_line())
-        if not batch:
-            raise ValueError(
-                "Failed to read valid experience batch from file: {}".format(
-                    self.cur_file
-                )
-            )
-
-        return self._postprocess_if_needed(batch)
+        ret = []
+        count = 0
+        while count < self.batch_size:
+            batch = self._try_parse(self._next_line())
+            tries = 0
+            while not batch and tries < 100:
+                tries += 1
+                logger.debug("Skipping empty line in {}".format(self.cur_file))
+                batch = self._try_parse(self._next_line())
+            if not batch:
+                raise ValueError(
+                    "Failed to read valid experience batch from file: {}".format(
+                        self.cur_file
+                    )
+                )
+            batch = self._postprocess_if_needed(batch)
+            count += batch.count
+            ret.append(batch)
+        ret = concat_samples(ret)
+        return ret
 
     def read_all_files(self) -> SampleBatchType:
         """Reads through all files and yields one SampleBatchType per line.
@@ -223,7 +239,7 @@
             out = []
             for sub_batch in batch.split_by_episode():
                 out.append(self.default_policy.postprocess_trajectory(sub_batch))
-            return SampleBatch.concat_samples(out)
+            return concat_samples(out)
         else:
             # TODO(ekl) this is trickier since the alignments between agent
             # trajectories in the episode are not available any more.
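
The same contract now holds for JsonReader: next() keeps parsing lines (skipping up to 100 empty ones per attempt) and postprocessing them until the accumulated count reaches batch_size, then returns the pieces concatenated. A short usage sketch (the input path is illustrative):

from ray.rllib.offline import IOContext
from ray.rllib.offline.json_reader import JsonReader

# num_workers is unset here, so batch_size falls back to train_batch_size.
ioctx = IOContext(config={"train_batch_size": 256}, worker_index=0)
reader = JsonReader(["/tmp/rllib-out/output.json"], ioctx)  # illustrative path

batch = reader.next()
# Line-batches are concatenated until count >= 256; expect mild overshoot
# when the last parsed line pushes the total past the target.
assert batch.count >= 256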

rllib/offline/tests/test_dataset_reader.py (new file)

@@ -0,0 +1,40 @@
+import os
+import unittest
+from pathlib import Path
+
+import ray
+from ray.rllib.offline import IOContext
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+
+
+class TestDatasetReader(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_itr_batches(self):
+        """Test that the dataset reader iterates over batches of rows correctly."""
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
+        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
+
+        input_config = {"format": "json", "path": data_file}
+        dataset, _ = get_dataset_and_shards(
+            {"input": "dataset", "input_config": input_config}, 0, True
+        )
+
+        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
+        reader = DatasetReader(ioctx, dataset)
+        assert len(reader.next()) == 1200
+
+
+if __name__ == "__main__":
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main(["-v", __file__]))

rllib/offline/tests/test_json_reader.py (new file)

@@ -0,0 +1,36 @@
+import os
+import unittest
+from pathlib import Path
+
+import ray
+from ray.rllib.offline import IOContext
+from ray.rllib.offline.json_reader import JsonReader
+
+
+class TestJsonReader(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
+    def test_itr_batches(self):
+        """Test that the json reader iterates over batches of rows correctly."""
+        rllib_dir = Path(__file__).parent.parent.parent.parent
+        print("rllib dir={}".format(rllib_dir))
+        data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
+        print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
+
+        ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
+        reader = JsonReader([data_file], ioctx)
+        assert len(reader.next()) == 1200
+
+
+if __name__ == "__main__":
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main(["-v", __file__]))
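
For reference, the concat_samples helper that both readers (and _postprocess_if_needed) now call merges a list of sample batches and sums their counts; a minimal standalone sketch:

import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

b1 = SampleBatch({"obs": np.zeros((3, 2)), "rewards": np.zeros(3)})  # count == 3
b2 = SampleBatch({"obs": np.ones((2, 2)), "rewards": np.ones(2)})    # count == 2

merged = concat_samples([b1, b2])
assert merged.count == 5  # counts add up: 3 + 2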