# ray/rllib/offline/mixed_input.py

from types import FunctionType
from typing import Dict
import numpy as np
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import SampleBatchType
from ray.tune.registry import registry_get_input, registry_contains_input
@DeveloperAPI
class MixedInput(InputReader):
    """Mixes input from a number of other input sources.

    Every call to ``next()`` randomly picks one of the configured sources,
    weighted by the probability it was registered with, and returns the
    result of that source's own ``next()``.

    Examples:
        >>> from ray.rllib.offline.io_context import IOContext
        >>> from ray.rllib.offline.mixed_input import MixedInput
        >>> ioctx = IOContext(...) # doctest: +SKIP
        >>> MixedInput({ # doctest: +SKIP
        ...     "sampler": 0.4, # doctest: +SKIP
        ...     "/tmp/experiences/*.json": 0.4, # doctest: +SKIP
        ...     "s3://bucket/expert.json": 0.2, # doctest: +SKIP
        ... }, ioctx) # doctest: +SKIP
    """

    # Largest accepted deviation of the probabilities' sum from 1.0.
    # Comparing floats with `!=` rejects valid inputs such as
    # 0.1 + 0.2 + 0.7 (== 0.9999999999999999). The tolerance is kept small
    # so that accepted distributions should also satisfy the sum check
    # `np.random.choice` applies to `p`.
    _PROB_SUM_TOLERANCE = 1e-8

    @DeveloperAPI
    def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
        """Initialize a MixedInput.

        Args:
            dist: dict mapping input sources to probabilities. A key may
                be the string "sampler" (uses
                ``ioctx.default_sampler_input()``), a plain function that
                is called with ``ioctx`` and returns a reader, the name
                of an input registered in the tune registry, or anything
                else, which is handed to ``JsonReader`` as its inputs
                (e.g. a file path or glob). The probabilities must sum
                to 1.0 (within a small floating-point tolerance).
            ioctx: current IO context object.

        Raises:
            ValueError: If the probabilities do not sum to 1.0.
        """
        total = sum(dist.values())
        # Written as `not (... <= tol)` so that a NaN sum is rejected too
        # (any comparison with NaN is False).
        if not (abs(total - 1.0) <= self._PROB_SUM_TOLERANCE):
            raise ValueError("Values must sum to 1.0: {}".format(dist))

        # `choices[i]` is sampled with probability `p[i]`.
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                self.choices.append(ioctx.default_sampler_input())
            elif isinstance(k, FunctionType):
                self.choices.append(k(ioctx))
            elif isinstance(k, str) and registry_contains_input(k):
                input_creator = registry_get_input(k)
                self.choices.append(input_creator(ioctx))
            else:
                self.choices.append(JsonReader(k, ioctx))
            self.p.append(v)

    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Return the next batch from a randomly selected input source.

        Returns:
            Whatever the selected source's ``next()`` returns.
        """
        # Sample an index rather than the reader objects themselves, so
        # NumPy never has to coerce the readers into an ndarray.
        idx = np.random.choice(len(self.choices), p=self.p)
        return self.choices[idx].next()