From 2b13ac85f979fabae334870b0f269c79ec9f8a6c Mon Sep 17 00:00:00 2001 From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Date: Thu, 21 Jul 2022 13:05:00 -0700 Subject: [PATCH] [RLLib]: Make IOContext optional for DatasetReader (#26694) --- rllib/evaluation/worker_set.py | 2 +- rllib/offline/dataset_reader.py | 17 ++++++++--------- rllib/offline/tests/test_dataset_reader.py | 11 ++++++++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index e7e07ec24..a37d411aa 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -588,7 +588,7 @@ class WorkerSet: # Input dataset shards should have already been prepared. # We just need to take the proper shard here. input_creator = lambda ioctx: DatasetReader( - ioctx, self._ds_shards[worker_index] + self._ds_shards[worker_index], ioctx ) # Dict: Mix of different input methods with different ratios. elif isinstance(config["input"], dict): diff --git a/rllib/offline/dataset_reader.py b/rllib/offline/dataset_reader.py index d0190661a..edf06a845 100644 --- a/rllib/offline/dataset_reader.py +++ b/rllib/offline/dataset_reader.py @@ -2,7 +2,7 @@ import logging import math from pathlib import Path import re -from typing import List, Tuple +from typing import List, Tuple, Optional import zipfile import ray.data @@ -205,24 +205,23 @@ class DatasetReader(InputReader): """ @PublicAPI - def __init__(self, ioctx: IOContext, ds: ray.data.Dataset): + def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None): """Initializes a DatasetReader instance. Args: ds: Ray dataset to sample from. """ - self._ioctx = ioctx + self._ioctx = ioctx or IOContext() self._default_policy = self.policy_map = None self._dataset = ds self.count = None if not self._dataset else self._dataset.count() # do this to disable the ray data stdout logging ray.data.set_progress_bars(enabled=False) - # the number of rows to return per call to next() - if self._ioctx: - self.batch_size = self._ioctx.config.get("train_batch_size", 1) - num_workers = self._ioctx.config.get("num_workers", 0) - seed = self._ioctx.config.get("seed", None) + # the number of steps to return per call to next() + self.batch_size = self._ioctx.config.get("train_batch_size", 1) + num_workers = self._ioctx.config.get("num_workers", 0) + seed = self._ioctx.config.get("seed", None) if num_workers: self.batch_size = max(math.ceil(self.batch_size / num_workers), 1) # We allow the creation of a non-functioning None DatasetReader. @@ -258,7 +257,7 @@ class DatasetReader(InputReader): return ret def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: - if not self._ioctx or not self._ioctx.config.get("postprocess_inputs"): + if not self._ioctx.config.get("postprocess_inputs"): return batch if isinstance(batch, SampleBatch): diff --git a/rllib/offline/tests/test_dataset_reader.py b/rllib/offline/tests/test_dataset_reader.py index ef2f00b7e..77ec567df 100644 --- a/rllib/offline/tests/test_dataset_reader.py +++ b/rllib/offline/tests/test_dataset_reader.py @@ -36,7 +36,7 @@ class TestDatasetReader(unittest.TestCase): ) ioctx = IOContext(config={"train_batch_size": 1200}, worker_index=0) - reader = DatasetReader(ioctx, dataset) + reader = DatasetReader(dataset, ioctx) assert len(reader.next()) >= 1200 def test_dataset_shard_with_only_local(self): @@ -132,6 +132,15 @@ class TestDatasetReader(unittest.TestCase): with self.assertRaises(ValueError): get_dataset_and_shards(config) + def test_default_ioctx(self): + # Test DatasetReader without passing in IOContext + input_config = {"format": "json", "paths": self.dset_path} + dataset, _ = get_dataset_and_shards( + {"input": "dataset", "input_config": input_config} + ) + reader = DatasetReader(dataset) + reader.next() + class TestUnzipIfNeeded(unittest.TestCase): @classmethod