mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLLib]: Make IOContext optional for DatasetReader (#26694)
This commit is contained in:
parent
51ecc04ccb
commit
2b13ac85f9
3 changed files with 19 additions and 11 deletions
|
@ -588,7 +588,7 @@ class WorkerSet:
|
|||
# Input dataset shards should have already been prepared.
|
||||
# We just need to take the proper shard here.
|
||||
input_creator = lambda ioctx: DatasetReader(
|
||||
ioctx, self._ds_shards[worker_index]
|
||||
self._ds_shards[worker_index], ioctx
|
||||
)
|
||||
# Dict: Mix of different input methods with different ratios.
|
||||
elif isinstance(config["input"], dict):
|
||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import math
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
from typing import List, Tuple, Optional
|
||||
import zipfile
|
||||
|
||||
import ray.data
|
||||
|
@ -205,24 +205,23 @@ class DatasetReader(InputReader):
|
|||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, ioctx: IOContext, ds: ray.data.Dataset):
|
||||
def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None):
|
||||
"""Initializes a DatasetReader instance.
|
||||
|
||||
Args:
|
||||
ds: Ray dataset to sample from.
|
||||
"""
|
||||
self._ioctx = ioctx
|
||||
self._ioctx = ioctx or IOContext()
|
||||
self._default_policy = self.policy_map = None
|
||||
self._dataset = ds
|
||||
self.count = None if not self._dataset else self._dataset.count()
|
||||
# do this to disable the ray data stdout logging
|
||||
ray.data.set_progress_bars(enabled=False)
|
||||
|
||||
# the number of rows to return per call to next()
|
||||
if self._ioctx:
|
||||
self.batch_size = self._ioctx.config.get("train_batch_size", 1)
|
||||
num_workers = self._ioctx.config.get("num_workers", 0)
|
||||
seed = self._ioctx.config.get("seed", None)
|
||||
# the number of steps to return per call to next()
|
||||
self.batch_size = self._ioctx.config.get("train_batch_size", 1)
|
||||
num_workers = self._ioctx.config.get("num_workers", 0)
|
||||
seed = self._ioctx.config.get("seed", None)
|
||||
if num_workers:
|
||||
self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)
|
||||
# We allow the creation of a non-functioning None DatasetReader.
|
||||
|
@ -258,7 +257,7 @@ class DatasetReader(InputReader):
|
|||
return ret
|
||||
|
||||
def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType:
|
||||
if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"):
|
||||
if not self._ioctx.config.get("postprocess_inputs"):
|
||||
return batch
|
||||
|
||||
if isinstance(batch, SampleBatch):
|
||||
|
|
|
@ -36,7 +36,7 @@ class TestDatasetReader(unittest.TestCase):
|
|||
)
|
||||
|
||||
ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0)
|
||||
reader = DatasetReader(ioctx, dataset)
|
||||
reader = DatasetReader(dataset, ioctx)
|
||||
assert len(reader.next()) >= 1200
|
||||
|
||||
def test_dataset_shard_with_only_local(self):
|
||||
|
@ -132,6 +132,15 @@ class TestDatasetReader(unittest.TestCase):
|
|||
with self.assertRaises(ValueError):
|
||||
get_dataset_and_shards(config)
|
||||
|
||||
def test_default_ioctx(self):
|
||||
# Test DatasetReader without passing in IOContext
|
||||
input_config = {"format": "json", "paths": self.dset_path}
|
||||
dataset, _ = get_dataset_and_shards(
|
||||
{"input": "dataset", "input_config": input_config}
|
||||
)
|
||||
reader = DatasetReader(dataset)
|
||||
reader.next()
|
||||
|
||||
|
||||
class TestUnzipIfNeeded(unittest.TestCase):
|
||||
@classmethod
|
||||
|
|
Loading…
Add table
Reference in a new issue