mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00

* Remove all __future__ imports from RLlib. * Remove (object) again from tf_run_builder.py::TFRunBuilder. * Fix 2xLINT warnings. * Fix broken appo_policy import (must be appo_tf_policy). * Remove future imports from all other ray files (not just RLlib). * Remove future imports from all other ray files (not just RLlib). * Remove future import blocks that contain `unicode_literals` as well. Revert appo_tf_policy.py to appo_policy.py (belongs to another PR). * Add two empty lines before Schedule class. * Put back __future__ imports into determine_tests_to_run.py; otherwise it fails with a py2 print-related error.
43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
import numpy as np
|
|
|
|
from ray.rllib.offline.input_reader import InputReader
|
|
from ray.rllib.offline.json_reader import JsonReader
|
|
from ray.rllib.utils.annotations import override, DeveloperAPI
|
|
|
|
|
|
@DeveloperAPI
class MixedInput(InputReader):
    """Mixes input from a number of other input sources.

    On each call to ``next()``, one source is drawn at random according to
    the configured probabilities, and that source's ``next()`` result is
    returned.

    Examples:
        >>> MixedInput({
        ...     "sampler": 0.4,
        ...     "/tmp/experiences/*.json": 0.4,
        ...     "s3://bucket/expert.json": 0.2,
        ... }, ioctx)
    """

    @DeveloperAPI
    def __init__(self, dist, ioctx):
        """Initialize a MixedInput.

        Arguments:
            dist (dict): dict mapping JSONReader paths or "sampler" to
                probabilities. The probabilities must be non-negative and
                sum to 1.0 (up to floating-point rounding).
            ioctx (IOContext): current IO context object.

        Raises:
            ValueError: if any probability is negative, or if the
                probabilities do not sum to 1.0.
        """
        probs = list(dist.values())
        # Reject negative weights here instead of letting np.random.choice
        # fail later, on the first call to next().
        if any(v < 0 for v in probs):
            raise ValueError("Values must be non-negative: {}".format(dist))
        # An exact `== 1.0` test rejects valid configs because of float
        # rounding (e.g. ten weights of 0.1 sum to 0.9999999999999999).
        # The 1e-8 tolerance is tighter than the one np.random.choice applies
        # to `p`, so anything accepted here is also accepted when sampling.
        if not np.isclose(sum(probs), 1.0, rtol=0.0, atol=1e-8):
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        # self.choices[i] is drawn with probability self.p[i].
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                self.choices.append(ioctx.default_sampler_input())
            else:
                self.choices.append(JsonReader(k))
            self.p.append(v)

    @override(InputReader)
    def next(self):
        """Return ``next()`` of an input source chosen according to ``self.p``."""
        # Sample an index rather than the reader objects themselves. This
        # avoids numpy converting `self.choices` into an array on every call,
        # and avoids surprises if a reader happens to be sequence-like.
        idx = np.random.choice(len(self.choices), p=self.p)
        return self.choices[idx].next()
|