[RLlib]: Make IOContext optional for DatasetReader (#26694)

Rohan Potdar 2022-07-21 13:05:00 -07:00 committed by GitHub
parent 51ecc04ccb
commit 2b13ac85f9
3 changed files with 19 additions and 11 deletions
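
For quick reference, a minimal before/after sketch of the new calling convention, pieced together from the diffs and tests below. The import paths follow RLlib's layout at the time of this commit; the JSON path and batch size are illustrative values, not part of the change.

from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
from ray.rllib.offline.io_context import IOContext

# Build a Ray Dataset from offline JSON data (the path is illustrative).
dataset, _ = get_dataset_and_shards(
    {"input": "dataset", "input_config": {"format": "json", "paths": "/tmp/offline_data.json"}}
)

# Before this commit, the IOContext was a required first argument:
#     reader = DatasetReader(ioctx, dataset)

# After this commit, the dataset comes first and the IOContext is optional:
reader = DatasetReader(dataset, IOContext(config={"train_batch_size": 1200}))
reader = DatasetReader(dataset)  # no IOContext given; a default IOContext() is created internally
batch = reader.next()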


@@ -588,7 +588,7 @@ class WorkerSet:
             # Input dataset shards should have already been prepared.
             # We just need to take the proper shard here.
             input_creator = lambda ioctx: DatasetReader(
-                ioctx, self._ds_shards[worker_index]
+                self._ds_shards[worker_index], ioctx
             )
         # Dict: Mix of different input methods with different ratios.
         elif isinstance(config["input"], dict):


@@ -2,7 +2,7 @@ import logging
 import math
 from pathlib import Path
 import re
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import zipfile

 import ray.data
@@ -205,24 +205,23 @@ class DatasetReader(InputReader):
     """

     @PublicAPI
-    def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
+    def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
         """Initializes a DatasetReader instance.

         Args:
             ds: Ray dataset to sample from.
         """
-        self._ioctx = ioctx
+        self._ioctx = ioctx or IOContext()
         self._default_policy = self.policy_map = None
         self._dataset = ds
         self.count = None if not self._dataset else self._dataset.count()
         # do this to disable the ray data stdout logging
         ray.data.set_progress_bars(enabled=False)

-        # the number of rows to return per call to next()
-        if self._ioctx:
-            self.batch_size = self._ioctx.config.get("train_batch_size", 1)
-            num_workers = self._ioctx.config.get("num_workers", 0)
-            seed = self._ioctx.config.get("seed", None)
+        # the number of steps to return per call to next()
+        self.batch_size = self._ioctx.config.get("train_batch_size", 1)
+        num_workers = self._ioctx.config.get("num_workers", 0)
+        seed = self._ioctx.config.get("seed", None)
         if num_workers:
             self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
         # We allow the creation of a non-functioning None DatasetReader.
@@ -258,7 +257,7 @@ class DatasetReader(InputReader):
         return ret

     def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
-        if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
+        if not self._ioctx.config.get("postprocess_inputs"):
             return batch

         if isinstance(batch, SampleBatch):
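
The `ioctx or IOContext()` fallback in the constructor is what lets the `not self._ioctx` guard be dropped here: config lookups on a default IOContext simply return their defaults. A minimal sketch, assuming a bare IOContext() starts with an empty config dict (as in the RLlib version this patch targets):

from ray.rllib.offline.io_context import IOContext

ioctx = IOContext()  # what DatasetReader now builds when no IOContext is passed
# Lookups degrade to their defaults instead of requiring a None check.
assert ioctx.config.get("postprocess_inputs") is None
assert ioctx.config.get("train_batch_size", 1) == 1
assert ioctx.config.get("num_workers", 0) == 0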


@@ -36,7 +36,7 @@ class TestDatasetReader(unittest.TestCase):
         )

         ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
-        reader = DatasetReader(ioctx, dataset)
+        reader = DatasetReader(dataset, ioctx)
         assert len(reader.next()) >= 1200

     def test_dataset_shard_with_only_local(self):
@@ -132,6 +132,15 @@ class TestDatasetReader(unittest.TestCase):
         with self.assertRaises(ValueError):
             get_dataset_and_shards(config)

+    def test_default_ioctx(self):
+        # Test DatasetReader without passing in IOContext
+        input_config = {"format": "json", "paths": self.dset_path}
+        dataset, _ = get_dataset_and_shards(
+            {"input": "dataset", "input_config": input_config}
+        )
+        reader = DatasetReader(dataset)
+        reader.next()
+

 class TestUnzipIfNeeded(unittest.TestCase):
     @classmethod