ray/rllib/offline/mixed_input.py
Sven 60d4d5e1aa Remove future imports (#6724)
* Remove all __future__ imports from RLlib.

* Remove (object) again from tf_run_builder.py::TFRunBuilder.

* Fix 2xLINT warnings.

* Fix broken appo_policy import (must be appo_tf_policy)

* Remove future imports from all other ray files (not just RLlib).

* Remove future imports from all other ray files (not just RLlib).

* Remove future import blocks that contain `unicode_literals` as well.
Revert appo_tf_policy.py to appo_policy.py (belongs to another PR).

* Add two empty lines before Schedule class.

* Put back __future__ imports into determine_tests_to_run.py. Otherwise, it fails with a py2/print-related error.
2020-01-09 00:15:48 -08:00

43 lines
1.3 KiB
Python

import math

import numpy as np

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
class MixedInput(InputReader):
    """Mixes input from a number of other input sources.

    Each call to ``next()`` picks one of the configured sources at random,
    weighted by its probability, and returns that source's ``next()`` result.

    Examples:
        >>> MixedInput({
        ...    "sampler": 0.4,
        ...    "/tmp/experiences/*.json": 0.4,
        ...    "s3://bucket/expert.json": 0.2,
        ... }, ioctx)
    """

    @DeveloperAPI
    def __init__(self, dist, ioctx):
        """Initialize a MixedInput.

        Arguments:
            dist (dict): dict mapping JsonReader paths or "sampler" to
                probabilities. The probabilities must be non-negative and
                sum to 1.0 (up to floating-point rounding).
            ioctx (IOContext): current IO context object. Its
                ``default_sampler_input()`` is used for the "sampler" key.

        Raises:
            ValueError: if the probabilities do not sum to 1.0 or any of
                them is negative.
        """
        # Compare with a tolerance: exact float equality rejects valid
        # inputs such as {a: 0.1, b: 0.2, c: 0.7}, which sum to
        # 1.0000000000000002. math.isclose's default tolerance is tighter
        # than np.random.choice's own sum check, so anything accepted here
        # is also accepted when sampling in next().
        if not math.isclose(sum(dist.values()), 1.0):
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        # Fail fast here instead of inside np.random.choice on first next().
        if any(v < 0 for v in dist.values()):
            raise ValueError(
                "Values must be non-negative: {}".format(dist))

        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                self.choices.append(ioctx.default_sampler_input())
            else:
                self.choices.append(JsonReader(k))
            self.p.append(v)

    @override(InputReader)
    def next(self):
        """Return the next batch from a randomly chosen input source."""
        # Sample an index rather than passing the reader objects themselves:
        # np.random.choice would otherwise convert self.choices into a numpy
        # object array on every call, which is wasteful and may misbehave if
        # a reader happens to look like a sequence. The index draw consumes
        # the same random numbers, so seeded runs pick the same sources.
        idx = np.random.choice(len(self.choices), p=self.p)
        return self.choices[idx].next()