[RLlib]: Make IOContext optional for DatasetReader (#26694)

This commit is contained in:
Rohan Potdar 2022-07-21 13:05:00 -07:00 committed by GitHub
parent 51ecc04ccb
commit 2b13ac85f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 11 deletions

View file

@@ -588,7 +588,7 @@ class WorkerSet:
             # Input dataset shards should have already been prepared.
             # We just need to take the proper shard here.
             input_creator = lambda ioctx: DatasetReader(
-                ioctx, self._ds_shards[worker_index]
+                self._ds_shards[worker_index], ioctx
             )
         # Dict: Mix of different input methods with different ratios.
         elif isinstance(config["input"], dict):

View file

@@ -2,7 +2,7 @@ import logging
 import math
 from pathlib import Path
 import re
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 import zipfile
 import ray.data
@@ -205,21 +205,20 @@ class DatasetReader(InputReader):
     """

     @PublicAPI
-    def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
+    def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
         """Initializes a DatasetReader instance.

         Args:
             ds: Ray dataset to sample from.
         """
-        self._ioctx = ioctx
+        self._ioctx = ioctx or IOContext()
         self._default_policy = self.policy_map = None
         self._dataset = ds
         self.count = None if not self._dataset else self._dataset.count()
         # do this to disable the ray data stdout logging
         ray.data.set_progress_bars(enabled=False)
-        # the number of rows to return per call to next()
-        if self._ioctx:
-            self.batch_size = self._ioctx.config.get("train_batch_size", 1)
-            num_workers = self._ioctx.config.get("num_workers", 0)
-            seed = self._ioctx.config.get("seed", None)
+        # the number of steps to return per call to next()
+        self.batch_size = self._ioctx.config.get("train_batch_size", 1)
+        num_workers = self._ioctx.config.get("num_workers", 0)
+        seed = self._ioctx.config.get("seed", None)
@@ -258,7 +257,7 @@ class DatasetReader(InputReader):
         return ret

     def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
-        if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
+        if not self._ioctx.config.get("postprocess_inputs"):
             return batch

         if isinstance(batch, SampleBatch):

View file

@@ -36,7 +36,7 @@ class TestDatasetReader(unittest.TestCase):
         )
         ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
-        reader = DatasetReader(ioctx, dataset)
+        reader = DatasetReader(dataset, ioctx)
         assert len(reader.next()) >= 1200

     def test_dataset_shard_with_only_local(self):
@@ -132,6 +132,15 @@ class TestDatasetReader(unittest.TestCase):
         with self.assertRaises(ValueError):
             get_dataset_and_shards(config)

+    def test_default_ioctx(self):
+        # Test DatasetReader without passing in IOContext
+        input_config = {"format": "json", "paths": self.dset_path}
+        dataset, _ = get_dataset_and_shards(
+            {"input": "dataset", "input_config": input_config}
+        )
+        reader = DatasetReader(dataset)
+        reader.next()
+

 class TestUnzipIfNeeded(unittest.TestCase):
     @classmethod