import logging
from typing import Dict

import gym

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType

logger = logging.getLogger(__name__)


@PublicAPI
class D4RLReader(InputReader):
    """Reader object that loads a dataset from the D4RL offline RL benchmark."""

    @PublicAPI
    def __init__(self, inputs: str, ioctx: IOContext = None):
        """Initializes a D4RLReader.

        Args:
            inputs (str): Name of the D4RL environment whose dataset
                should be loaded.
            ioctx (IOContext): Current IO context object.
        """
        # Importing d4rl also registers the D4RL environments with gym,
        # which the `gym.make()` call below relies on.
        import d4rl

        self.env = gym.make(inputs)
        # Load the full (obs, action, next_obs, reward, done) dataset
        # into a single SampleBatch.
        self.dataset = convert_to_batch(d4rl.qlearning_dataset(self.env))
        assert self.dataset.count >= 1
        # Shuffle once up front so sequential reads are in random order.
        self.dataset.shuffle()
        self.counter = 0

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # Once every transition has been served, reshuffle and restart
        # from the beginning of the dataset.
        if self.counter >= self.dataset.count:
            self.counter = 0
            self.dataset.shuffle()

        # Slice out a single-row batch, then advance the cursor. Slicing
        # before incrementing avoids skipping row 0 and reading one row
        # past the end of the dataset.
        batch = self.dataset.slice(start=self.counter, end=self.counter + 1)
        self.counter += 1
        return batch
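

# A hedged sketch of how a reader like this plugs into RLlib: the "input"
# config option accepts a callable taking an IOContext and returning an
# InputReader, so an input creator such as the one below can be passed to a
# Trainer. The dataset name "hopper-medium-v0" is only an illustrative
# placeholder for any valid D4RL environment name.
def _example_d4rl_input_creator(ioctx: IOContext) -> InputReader:
    """Illustrative input creator for use as `config["input"]` (sketch only)."""
    return D4RLReader("hopper-medium-v0", ioctx)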


def convert_to_batch(dataset: Dict) -> SampleBatchType:
    """Converts a D4RL dataset dict into a SampleBatch."""
    # Map D4RL's column names onto RLlib's standard SampleBatch keys.
    d = {}
    d[SampleBatch.OBS] = dataset["observations"]
    d[SampleBatch.ACTIONS] = dataset["actions"]
    d[SampleBatch.NEXT_OBS] = dataset["next_observations"]
    d[SampleBatch.REWARDS] = dataset["rewards"]
    d[SampleBatch.DONES] = dataset["terminals"]

    return SampleBatch(d)
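

# A minimal usage sketch, assuming the optional `d4rl` package (and the
# MuJoCo dependencies of its environments) is installed; the dataset name
# "hopper-medium-v0" is again only illustrative.
if __name__ == "__main__":
    reader = D4RLReader("hopper-medium-v0")
    for _ in range(3):
        batch = reader.next()
        print("Read a batch of", batch.count, "transition(s)")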