[RLlib, Offline] Make the dataset and json readers batchable (#26055)

Make the dataset and json readers batchable: DatasetReader.next() and JsonReader.next() now keep reading and concatenating sample batches until roughly train_batch_size / num_workers rows have been collected, instead of returning a single batch per call.
Avnish Narayan 2022-06-29 14:52:40 -04:00 committed by GitHub
parent 5043cc1a82
commit 1f9282a496
5 changed files with 138 additions and 19 deletions
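Both readers now derive a per-worker batch size from the IOContext config: the train_batch_size is split across the rollout workers, with a floor of one row. A minimal, runnable sketch of that sizing logic (the standalone helper name is illustrative, not part of the diff):

import math

def per_worker_batch_size(config: dict) -> int:
    # Mirrors the sizing logic added to DatasetReader and JsonReader below:
    # each rollout worker reads roughly train_batch_size / num_workers rows
    # per call to next().
    batch_size = config.get("train_batch_size", 1)
    num_workers = config.get("num_workers", 0)
    if num_workers:
        batch_size = max(math.ceil(batch_size / num_workers), 1)
    return batch_size

# E.g. train_batch_size=1200 read by 3 workers -> 400 rows per next() call;
# with num_workers=0 (local worker only), the full 1200 rows are returned.
assert per_worker_batch_size({"train_batch_size": 1200, "num_workers": 3}) == 400
assert per_worker_batch_size({"train_batch_size": 1200}) == 1200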

rllib/BUILD

@@ -1647,6 +1647,21 @@ py_test(
srcs = ["offline/estimators/tests/test_ope.py"]
)
py_test(
name = "test_json_reader",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_json_reader.py"]
)
py_test(
name = "test_dataset_reader",
tags = ["team:rllib", "offline"],
size = "small",
srcs = ["offline/tests/test_dataset_reader.py"]
)
# --------------------------------------------------------------------
# Policies
# rllib/policy/

rllib/offline/dataset_reader.py

@@ -5,10 +5,12 @@ import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.policy.sample_batch import concat_samples
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
from typing import List
logger = logging.getLogger(__name__)
DEFAULT_NUM_CPUS_PER_TASK = 0.5
@@ -113,6 +115,12 @@ class DatasetReader(InputReader):
"""
self._ioctx = ioctx
self._dataset = ds
# The number of rows to return per call to next().
self.batch_size = 1
if self._ioctx:
    self.batch_size = ioctx.config.get("train_batch_size", 1)
    num_workers = ioctx.config.get("num_workers", 0)
    if num_workers:
        self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
# We allow the creation of a non-functioning None DatasetReader.
# It's useful for example for a non-rollout local worker.
if ds:
@@ -127,9 +135,13 @@ class DatasetReader(InputReader):
def next(self) -> SampleBatchType:
# next() should not get called on None DatasetReader.
assert self._iter is not None
d = next(self._iter).as_pydict()
# Columns like obs are compressed when written by DatasetWriter.
d = from_json_data(d, self._ioctx.worker)
return d
ret = []
count = 0
while count < self.batch_size:
d = next(self._iter).as_pydict()
# Columns like obs are compressed when written by DatasetWriter.
d = from_json_data(d, self._ioctx.worker)
count += d.count
ret.append(d)
ret = concat_samples(ret)
return ret
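The rewritten next() keeps pulling rows from the dataset iterator, decodes each one with from_json_data, and appends the resulting SampleBatches until their combined count reaches batch_size, then merges them with concat_samples. A small illustration of the concat_samples behavior the loop relies on (the toy column shapes are made up for the example):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples

# Two small batches with 3 and 2 rows, respectively.
b1 = SampleBatch({"obs": np.zeros((3, 4)), "rewards": np.zeros(3)})
b2 = SampleBatch({"obs": np.ones((2, 4)), "rewards": np.ones(2)})

merged = concat_samples([b1, b2])
# Counts add up, which is what the `while count < self.batch_size` loop tracks.
assert merged.count == b1.count + b2.count == 5
assert merged["obs"].shape == (5, 4)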

rllib/offline/json_reader.py

@@ -1,6 +1,8 @@
import glob
import json
import logging
import math
import numpy as np
import os
from pathlib import Path
@@ -23,6 +25,7 @@ from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
SampleBatch,
concat_samples,
)
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import unpack_if_needed
@@ -145,6 +148,13 @@ class JsonReader(InputReader):
self.ioctx = ioctx or IOContext()
self.default_policy = self.policy_map = None
self.batch_size = 1
if self.ioctx:
self.batch_size = self.ioctx.config.get("train_batch_size", 1)
num_workers = self.ioctx.config.get("num_workers", 0)
if num_workers:
self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
if self.ioctx.worker is not None:
self.policy_map = self.ioctx.worker.policy_map
self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
@@ -180,20 +190,26 @@
@override(InputReader)
def next(self) -> SampleBatchType:
batch = self._try_parse(self._next_line())
tries = 0
while not batch and tries < 100:
tries += 1
logger.debug("Skipping empty line in {}".format(self.cur_file))
ret = []
count = 0
while count < self.batch_size:
batch = self._try_parse(self._next_line())
if not batch:
raise ValueError(
"Failed to read valid experience batch from file: {}".format(
self.cur_file
tries = 0
while not batch and tries < 100:
tries += 1
logger.debug("Skipping empty line in {}".format(self.cur_file))
batch = self._try_parse(self._next_line())
if not batch:
raise ValueError(
"Failed to read valid experience batch from file: {}".format(
self.cur_file
)
)
)
return self._postprocess_if_needed(batch)
batch = self._postprocess_if_needed(batch)
count += batch.count
ret.append(batch)
ret = concat_samples(ret)
return ret
def read_all_files(self) -> SampleBatchType:
"""Reads through all files and yields one SampleBatchType per line.
@@ -223,7 +239,7 @@ class JsonReader(InputReader):
out = []
for sub_batch in batch.split_by_episode():
out.append(self.default_policy.postprocess_trajectory(sub_batch))
return SampleBatch.concat_samples(out)
return concat_samples(out)
else:
# TODO(ekl) this is trickier since the alignments between agent
# trajectories in the episode are not available any more.
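With the constructor change above, a JsonReader built from an IOContext that carries train_batch_size and num_workers returns roughly one worker's share of the train batch per next() call, by concatenating whole per-line batches until the target count is reached. A hedged usage sketch (the data path below is hypothetical):

from ray.rllib.offline import IOContext
from ray.rllib.offline.json_reader import JsonReader

# Hypothetical offline data file in RLlib's JSON output format.
ioctx = IOContext(
    config={"train_batch_size": 1200, "num_workers": 3},
    worker_index=1,
)
reader = JsonReader(["/tmp/rllib-offline/data.json"], ioctx)

batch = reader.next()
# At least ceil(1200 / 3) = 400 rows; the result can overshoot slightly
# because whole per-line batches are concatenated, never split.
assert len(batch) >= 400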

rllib/offline/tests/test_dataset_reader.py

@@ -0,0 +1,40 @@
import os
import unittest
from pathlib import Path
import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
class TestDatasetReader(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_itr_batches(self):
"""Test that the json reader iterates over batches of rows correctly."""
rllib_dir = Path(__file__).parent.parent.parent.parent
print("rllib dir={}".format(rllib_dir))
data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
input_config = {"format": "json", "path": data_file}
dataset, _ = get_dataset_and_shards(
{"input": "dataset", "input_config": input_config}, 0, True
)
ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
reader = DatasetReader(ioctx, dataset)
assert len(reader.next()) == 1200
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))

rllib/offline/tests/test_json_reader.py

@@ -0,0 +1,36 @@
import os
import unittest
from pathlib import Path
import ray
from ray.rllib.offline import IOContext
from ray.rllib.offline.json_reader import JsonReader
class TestJsonReader(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()
def test_itr_batches(self):
"""Test that the json reader iterates over batches of rows correctly."""
rllib_dir = Path(__file__).parent.parent.parent.parent
print("rllib dir={}".format(rllib_dir))
data_file = os.path.join(rllib_dir, "rllib/tests/data/pendulum/large.json")
print("data_file={} exists={}".format(data_file, os.path.isfile(data_file)))
ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
reader = JsonReader([data_file], ioctx)
assert len(reader.next()) == 1200
if __name__ == "__main__":
import sys
import pytest
sys.exit(pytest.main(["-v", __file__]))
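In both tests worker_index is 0 and num_workers is not set, so each reader's batch_size resolves to the full train_batch_size of 1200 rows, which is what the length assertions check. The modules invoke pytest on themselves, so besides the new Bazel targets in rllib/BUILD they can also be run together from a source checkout, for example (illustrative invocation, not part of the diff):

import sys
import pytest

# Run both new reader tests with verbose output, mirroring what the
# __main__ blocks above do for a single file.
sys.exit(pytest.main([
    "-v",
    "rllib/offline/tests/test_dataset_reader.py",
    "rllib/offline/tests/test_json_reader.py",
]))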