[rllib] Basic Offline Data IO API (#3473)

Eric Liang 2018-12-12 13:57:48 -08:00 committed by GitHub
parent cc8f7db246
commit 32473cf22e
18 changed files with 801 additions and 48 deletions

View file

@ -5,7 +5,7 @@ FROM ray-project/deploy
# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout smart_open
RUN pip install -U h5py # Mutes FutureWarnings
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
RUN conda install pytorch-cpu torchvision-cpu -c pytorch

View file

@ -10,8 +10,10 @@ import pickle
import six
import tempfile
import tensorflow as tf
from types import FunctionType
import ray
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@ -122,11 +124,45 @@ COMMON_CONFIG = {
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
},
# Whether to LZ4 compress observations
# Whether to LZ4 compress individual observations
"compress_observations": False,
# Drop metric batches from unresponsive workers after this many seconds
"collect_metrics_timeout": 180,
# === Offline Data Input / Output ===
# Specify how to generate experiences:
# - "sampler": generate experiences via online simulation (default)
# - a local directory or file glob expression (e.g., "/tmp/*.json")
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
# "s3://bucket/2.json"])
# - a dict with string keys and sampling probabilities as values (e.g.,
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns an rllib.offline.InputReader
"input": "sampler",
# Specify how to evaluate the current policy. This only makes sense to set
# when the input is not already generating simulation data:
# - None: don't evaluate the policy. The episode reward and other
# metrics will be NaN if using offline data.
# - "simulation": run the environment in the background, but use
# this data for evaluation only and not for learning.
# - "counterfactual": use counterfactual policy evaluation to estimate
# performance (this option is not implemented yet).
"input_evaluation": None,
# Specify where experiences should be saved:
# - None: don't save any experiences
# - "logdir" to save to the agent log dir
# - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
# - a function that returns an rllib.offline.OutputWriter
"output": None,
# What sample batch columns to LZ4 compress in the output data.
"output_compress_columns": ["obs", "new_obs"],
# Max output file size before rolling over to a new file.
"output_max_file_size": 64 * 1024 * 1024,
# Whether to run postprocess_trajectory() on the trajectory fragments from
# offline inputs. Whether this makes sense is algorithm-specific.
# TODO(ekl) implement this and multi-agent batch handling
# "postprocess_inputs": False,
# === Multiagent ===
"multiagent": {
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
@ -179,7 +215,6 @@ class Agent(Trainable):
"""
config = config or {}
Agent._validate_config(config)
# Vars to synchronize to evaluators on each train call
self.global_vars = {"timestep": 0}
@ -267,6 +302,7 @@ class Agent(Trainable):
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.config = merged_config
Agent._validate_config(self.config)
if self.config.get("log_level"):
logging.getLogger("ray.rllib").setLevel(self.config["log_level"])
@ -391,6 +427,32 @@ class Agent(Trainable):
self.config) for i in range(count)
]
@classmethod
def resource_help(cls, config):
return ("\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers` and other configs. See the "
"DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: {}".format(config))
@staticmethod
def _validate_config(config):
if "gpu" in config:
raise ValueError(
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
"instead.")
if "gpu_fraction" in config:
raise ValueError(
"The `gpu_fraction` config is deprecated, please use "
"`num_gpus=<fraction>` instead.")
if "use_gpu_for_workers" in config:
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
if (config["input"] == "sampler"
and config["input_evaluation"] is not None):
raise ValueError(
"`input_evaluation` should not be set when input=sampler")
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def session_creator():
@ -399,6 +461,32 @@ class Agent(Trainable):
return tf.Session(
config=tf.ConfigProto(**config["tf_session_args"]))
if isinstance(config["input"], FunctionType):
input_creator = config["input"]
elif config["input"] == "sampler":
input_creator = (lambda ioctx: ioctx.default_sampler_input())
elif isinstance(config["input"], dict):
input_creator = (lambda ioctx: MixedInput(ioctx, config["input"]))
else:
input_creator = (lambda ioctx: JsonReader(ioctx, config["input"]))
if isinstance(config["output"], FunctionType):
output_creator = config["output"]
elif config["output"] is None:
output_creator = (lambda ioctx: NoopOutput())
elif config["output"] == "logdir":
output_creator = (lambda ioctx: JsonWriter(
ioctx,
ioctx.log_dir,
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
else:
output_creator = (lambda ioctx: JsonWriter(
ioctx,
config["output"],
max_file_size=config["output_max_file_size"],
compress_columns=config["output_compress_columns"]))
return cls(
env_creator,
self.config["multiagent"]["policy_graphs"] or policy_graph,
@ -421,30 +509,12 @@ class Agent(Trainable):
policy_config=config,
worker_index=worker_index,
monitor_path=self.logdir if config["monitor"] else None,
log_dir=self.logdir,
log_level=config["log_level"],
callbacks=config["callbacks"])
@classmethod
def resource_help(cls, config):
return ("\n\nYou can adjust the resource requests of RLlib agents by "
"setting `num_workers` and other configs. See the "
"DEFAULT_CONFIG defined by each agent for more info.\n\n"
"The config of this agent is: {}".format(config))
@staticmethod
def _validate_config(config):
if "gpu" in config:
raise ValueError(
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
"instead.")
if "gpu_fraction" in config:
raise ValueError(
"The `gpu_fraction` config is deprecated, please use "
"`num_gpus=<fraction>` instead.")
if "use_gpu_for_workers" in config:
raise ValueError(
"The `use_gpu_for_workers` config is deprecated, please use "
"`num_gpus_per_worker=1` instead.")
callbacks=config["callbacks"],
input_creator=input_creator,
input_evaluation_method=config["input_evaluation"],
output_creator=output_creator)
def __getstate__(self):
state = {}
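
Usage sketch for the new offline IO config keys (illustrative only; CartPole, the call to ray.init(), and the glob over output-*.json files are assumptions, not part of this diff): one agent saves its experiences, a second agent trains from the saved files.

import ray
from ray.rllib.agents.pg import PGAgent

ray.init()

# Phase 1: sample online and write all experiences to the agent's log dir.
collector = PGAgent(env="CartPole-v0", config={"output": "logdir"})
collector.train()

# Phase 2: train purely from the saved JSON files. Background simulation is
# used only to report evaluation metrics, never for learning.
learner = PGAgent(
    env="CartPole-v0",
    config={
        "input": collector.logdir + "/output-*.json",
        "input_evaluation": "simulation",
    })
learner.train()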

View file

@ -36,9 +36,11 @@ class PGAgent(Agent):
self.env_creator, self._policy_graph)
self.remote_evaluators = self.make_remote_evaluators(
self.env_creator, self._policy_graph, self.config["num_workers"])
self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
self.remote_evaluators,
self.config["optimizer"])
optimizer_config = dict(
self.config["optimizer"],
**{"train_batch_size": self.config["train_batch_size"]})
self.optimizer = SyncSamplesOptimizer(
self.local_evaluator, self.remote_evaluators, optimizer_config)
@override(Agent)
def _train(self):

View file

@ -18,6 +18,7 @@ from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
@ -108,8 +109,12 @@ class PolicyEvaluator(EvaluatorInterface):
policy_config=None,
worker_index=0,
monitor_path=None,
log_dir=None,
log_level=None,
callbacks=None):
callbacks=None,
input_creator=lambda ioctx: ioctx.default_sampler_input(),
input_evaluation_method=None,
output_creator=lambda ioctx: NoopOutput()):
"""Initialize a policy evaluator.
Arguments:
@ -170,8 +175,22 @@ class PolicyEvaluator(EvaluatorInterface):
through EnvContext so that envs can be configured per worker.
monitor_path (str): Write out episode stats and videos to this
directory if specified.
log_dir (str): Directory where logs can be placed.
log_level (str): Set the root log level on creation.
callbacks (dict): Dict of custom debug callbacks.
input_creator (func): Function that returns an InputReader object
for loading previously generated experiences.
input_evaluation_method (str): How to evaluate the current policy.
This only applies when the input is reading offline data.
Options are:
- None: don't evaluate the policy. The episode reward and
other metrics will be NaN.
- "simulation": run the environment in the background, but
use this data for evaluation only and never for learning.
- "counterfactual": use counterfactual policy evaluation to
estimate performance.
output_creator (func): Function that returns an OutputWriter object
for saving generated experiences.
"""
if log_level:
@ -279,6 +298,20 @@ class PolicyEvaluator(EvaluatorInterface):
else:
raise ValueError("Unsupported batch mode: {}".format(
self.batch_mode))
if input_evaluation_method == "simulation":
logger.warn(
"Requested 'simulation' input evaluation method: "
"will discard all sampler outputs and keep only metrics.")
sample_async = True
elif input_evaluation_method == "counterfactual":
raise NotImplementedError
elif input_evaluation_method is None:
pass
else:
raise ValueError("Unknown evaluation method: {}".format(
input_evaluation_method))
if sample_async:
self.sampler = AsyncSampler(
self.async_env,
@ -292,7 +325,8 @@ class PolicyEvaluator(EvaluatorInterface):
horizon=episode_horizon,
pack=pack_episodes,
tf_sess=self.tf_sess,
clip_actions=clip_actions)
clip_actions=clip_actions,
blackhole_outputs=input_evaluation_method == "simulation")
self.sampler.start()
else:
self.sampler = SyncSampler(
@ -309,6 +343,12 @@ class PolicyEvaluator(EvaluatorInterface):
tf_sess=self.tf_sess,
clip_actions=clip_actions)
self.io_context = IOContext(log_dir, policy_config, worker_index, self)
self.input_reader = input_creator(self.io_context)
assert isinstance(self.input_reader, InputReader), self.input_reader
self.output_writer = output_creator(self.io_context)
assert isinstance(self.output_writer, OutputWriter), self.output_writer
logger.debug("Created evaluator with env {} ({}), policies {}".format(
self.async_env, self.env, self.policy_map))
@ -320,7 +360,7 @@ class PolicyEvaluator(EvaluatorInterface):
SampleBatch|MultiAgentBatch from evaluating the current policies.
"""
batches = [self.sampler.get_data()]
batches = [self.input_reader.next()]
steps_so_far = batches[0].count
# In truncate_episodes mode, never pull more than 1 batch per env.
@ -332,10 +372,9 @@ class PolicyEvaluator(EvaluatorInterface):
while steps_so_far < self.sample_batch_size and len(
batches) < max_batches:
batch = self.sampler.get_data()
batch = self.input_reader.next()
steps_so_far += batch.count
batches.append(batch)
batches.extend(self.sampler.get_extra_batches())
batch = batches[0].concat_samples(batches)
if self.callbacks.get("on_sample_end"):
@ -353,6 +392,7 @@ class PolicyEvaluator(EvaluatorInterface):
batch["obs"] = [pack(o) for o in batch["obs"]]
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
self.output_writer.write(batch)
return batch
@ray.method(num_return_vals=2)
@ -531,6 +571,10 @@ class PolicyEvaluator(EvaluatorInterface):
policy_map[name] = cls(obs_space, act_space, merged_conf)
return policy_map, preprocessors
def __del__(self):
if isinstance(self.sampler, AsyncSampler):
self.sampler.shutdown = True
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
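
A sketch of driving the new constructor arguments directly, outside of an Agent. The PGPolicyGraph import path and the /tmp paths are assumptions for illustration; input_creator, input_evaluation_method and output_creator are the arguments added in this diff.

import gym
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph  # assumed path
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.offline import JsonReader, JsonWriter

evaluator = PolicyEvaluator(
    lambda config: gym.make("CartPole-v0"),  # env_creator
    PGPolicyGraph,                           # policy_graph
    # Read previously saved experiences instead of pulling from the sampler.
    input_creator=lambda ioctx: JsonReader(ioctx, "/tmp/demos/*.json"),
    # Keep simulating in the background, but only for metrics reporting.
    input_evaluation_method="simulation",
    # Mirror every returned batch out to another directory.
    output_creator=lambda ioctx: JsonWriter(ioctx, "/tmp/out"))
batch = evaluator.sample()  # comes from the JsonReader, then gets written out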

View file

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import collections
import numpy as np
@ -233,8 +234,9 @@ class SampleBatch(object):
self.data = dict(*args, **kwargs)
lengths = []
for k, v in self.data.copy().items():
assert type(k) == str, self
assert isinstance(k, six.string_types), self
lengths.append(len(v))
self.data[k] = np.array(v, copy=False)
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, "data columns must be same length"

View file

@ -112,7 +112,8 @@ class AsyncSampler(threading.Thread):
horizon=None,
pack=False,
tf_sess=None,
clip_actions=True):
clip_actions=True,
blackhole_outputs=False):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
@ -133,6 +134,8 @@ class AsyncSampler(threading.Thread):
self.tf_sess = tf_sess
self.callbacks = callbacks
self.clip_actions = clip_actions
self.blackhole_outputs = blackhole_outputs
self.shutdown = False
def run(self):
try:
@ -142,12 +145,19 @@ class AsyncSampler(threading.Thread):
raise e
def _run(self):
if self.blackhole_outputs:
queue_putter = (lambda x: None)
extra_batches_putter = (lambda x: None)
else:
queue_putter = self.queue.put
extra_batches_putter = (
lambda x: self.extra_batches.put(x, timeout=600.0))
rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.async_vector_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
while True:
while not self.shutdown:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
# set to some large number. This is an empirical observation.
@ -155,7 +165,7 @@ class AsyncSampler(threading.Thread):
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
queue_putter(item)
def get_data(self):
rollout = self.queue.get(timeout=600.0)
@ -246,7 +256,7 @@ def _env_runner(async_vector_env,
horizon = (
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
except Exception:
logger.warn("no episode horizon specified, assuming inf")
logger.debug("no episode horizon specified, assuming inf")
if not horizon:
horizon = float("inf")
@ -332,12 +342,12 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
"More than {} observations for {} env steps ".format(
episode.batch_builder.total(),
episode.batch_builder.count) + "are buffered in "
"the sampler. If this is not intentional, check that the "
"the `horizon` config is set correctly, or consider setting "
"`batch_mode` to 'truncate_episodes'. Note that in "
"multi-agent environments, `sample_batch_size` sets the "
"batch size based on environment steps, not the steps of "
"individual agents.")
"the sampler. If this is more than you expected, check that "
"that you set a horizon on your environment correctly. Note "
"that in multi-agent environments, `sample_batch_size` sets "
"the batch size based on environment steps, not the steps of "
"individual agents, which can result in unexpectedly large "
"batches.")
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:

View file

@ -0,0 +1,20 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.mixed_input import MixedInput
__all__ = [
"IOContext",
"JsonReader",
"JsonWriter",
"NoopOutput",
"OutputWriter",
"InputReader",
"MixedInput",
]

View file

@ -0,0 +1,30 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
class InputReader(object):
"""Input object for loading experiences in policy evaluation."""
def next(self):
"""Return the next batch of experiences read."""
raise NotImplementedError
class SamplerInput(InputReader):
"""Reads input experiences from an existing sampler."""
def __init__(self, sampler):
self.sampler = sampler
@override(InputReader)
def next(self):
batches = [self.sampler.get_data()]
batches.extend(self.sampler.get_extra_batches())
if len(batches) > 1:
return batches[0].concat_samples(batches)
else:
return batches[0]
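
The interface is intentionally small: a custom source only needs to implement next() and return a SampleBatch. A toy sketch (the fixed observation shape and column set are made up) that could be wired in through the function-valued `input` config described in agent.py:

import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.utils.annotations import override

class ConstantBatchReader(InputReader):
    """Toy reader that returns the same fabricated batch every time."""

    def __init__(self, batch_size=32):
        self.batch = SampleBatch({
            "obs": np.zeros((batch_size, 4)),
            "actions": np.zeros((batch_size, ), dtype=np.int64),
        })

    @override(InputReader)
    def next(self):
        return self.batch

With the dispatch in agent.py, this could then be plugged in as
"input": lambda ioctx: ConstantBatchReader() in an agent config.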

View file

@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.offline.input_reader import SamplerInput
class IOContext(object):
"""Attributes to pass to input / output class constructors.
RLlib auto-sets these attributes when constructing input / output classes.
Attributes:
log_dir (str): Default logging directory.
config (dict): Configuration of the agent.
worker_index (int): When there are multiple workers created, this
uniquely identifies the current worker.
evaluator (PolicyEvaluator): policy evaluator object reference.
"""
def __init__(self, log_dir, config, worker_index, evaluator):
self.log_dir = log_dir
self.config = config
self.worker_index = worker_index
self.evaluator = evaluator
def default_sampler_input(self):
return SamplerInput(self.evaluator.sampler)

View file

@ -0,0 +1,126 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import json
import logging
import os
import random
import six
from six.moves.urllib.parse import urlparse
try:
from smart_open import smart_open
except ImportError:
smart_open = None
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import unpack_if_needed
logger = logging.getLogger(__name__)
class JsonReader(InputReader):
"""Reader object that loads experiences from JSON file chunks.
The input files will be read in a random order."""
def __init__(self, ioctx, inputs):
"""Initialize a JsonReader.
Arguments:
ioctx (IOContext): current IO context object.
inputs (str|list): either a glob expression for files, e.g.,
"/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
["s3://bucket/file.json", "s3://bucket/file2.json"].
"""
self.ioctx = ioctx
if isinstance(inputs, six.string_types):
if os.path.isdir(inputs):
inputs = os.path.join(inputs, "*.json")
logger.warn(
"Treating input directory as glob pattern: {}".format(
inputs))
if urlparse(inputs).scheme:
raise ValueError(
"Don't know how to glob over `{}`, ".format(inputs) +
"please specify a list of files to read instead.")
else:
self.files = glob.glob(inputs)
elif type(inputs) is list:
self.files = inputs
else:
raise ValueError(
"type of inputs must be list or str, not {}".format(inputs))
if self.files:
logger.info("Found {} input files.".format(len(self.files)))
else:
raise ValueError("No files found matching {}".format(inputs))
self.cur_file = None
@override(InputReader)
def next(self):
batch = self._try_parse(self._next_line())
tries = 0
while not batch and tries < 100:
tries += 1
logger.debug("Skipping empty line in {}".format(self.cur_file))
batch = self._try_parse(self._next_line())
if not batch:
raise ValueError(
"Failed to read valid experience batch from file: {}".format(
self.cur_file))
return batch
def _try_parse(self, line):
line = line.strip()
if not line:
return None
try:
return _from_json(line)
except Exception:
logger.exception("Ignoring corrupt json record in {}: {}".format(
self.cur_file, line))
return None
def _next_line(self):
if not self.cur_file:
self.cur_file = self._next_file()
line = self.cur_file.readline()
tries = 0
while not line and tries < 100:
tries += 1
if hasattr(self.cur_file, "close"): # legacy smart_open impls
self.cur_file.close()
self.cur_file = self._next_file()
line = self.cur_file.readline()
if not line:
logger.debug("Ignoring empty file {}".format(self.cur_file))
if not line:
raise ValueError("Failed to read next line from files: {}".format(
self.files))
return line
def _next_file(self):
path = random.choice(self.files)
if urlparse(path).scheme:
if smart_open is None:
raise ValueError(
"You must install the `smart_open` module to read "
"from URIs like {}".format(path))
return smart_open(path, "r")
else:
return open(path, "r")
def _from_json(batch):
if isinstance(batch, bytes): # smart_open S3 doesn't respect "r"
batch = batch.decode("utf-8")
data = json.loads(batch)
for k, v in data.items():
data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)]
return SampleBatch(data)
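
For standalone use (e.g., inspecting saved data), a stub IOContext is enough, mirroring the unit tests further below; the /tmp path is a placeholder and assumes matching files already exist.

from ray.rllib.offline import IOContext, JsonReader

# Arguments: log_dir, config, worker_index, evaluator. The evaluator
# reference is only needed when the "sampler" input source is requested.
ioctx = IOContext("/tmp/demos", {}, 0, None)
reader = JsonReader(ioctx, "/tmp/demos/*.json")
batch = reader.next()  # one SampleBatch parsed from a stored JSON line
print(batch.count)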

View file

@ -0,0 +1,108 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time
try:
from smart_open import smart_open
except ImportError:
smart_open = None
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack
logger = logging.getLogger(__name__)
class JsonWriter(OutputWriter):
"""Writer object that saves experiences in JSON file chunks."""
def __init__(self,
ioctx,
path,
max_file_size=64 * 1024 * 1024,
compress_columns=frozenset(["obs", "new_obs"])):
"""Initialize a JsonWriter.
Arguments:
ioctx (IOContext): current IO context object.
path (str): a path/URI of the output directory to save files in.
max_file_size (int): max size of single files before rolling over.
compress_columns (list): list of sample batch columns to compress.
"""
self.ioctx = ioctx
self.path = path
self.max_file_size = max_file_size
self.compress_columns = compress_columns
if urlparse(path).scheme:
self.path_is_uri = True
else:
# Try to create local dirs if they don't exist
try:
os.makedirs(path)
except OSError:
pass # already exists
assert os.path.exists(path), "Failed to create {}".format(path)
self.path_is_uri = False
self.file_index = 0
self.bytes_written = 0
self.cur_file = None
@override(OutputWriter)
def write(self, sample_batch):
start = time.time()
data = _to_json(sample_batch, self.compress_columns)
f = self._get_file()
f.write(data)
f.write("\n")
if hasattr(f, "flush"): # legacy smart_open impls
f.flush()
self.bytes_written += len(data)
logger.debug("Wrote {} bytes to {} in {}s".format(
len(data), f,
time.time() - start))
def _get_file(self):
if not self.cur_file or self.bytes_written >= self.max_file_size:
if self.cur_file:
self.cur_file.close()
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
path = os.path.join(
self.path, "output-{}_worker-{}_{}.json".format(
timestr, self.ioctx.worker_index, self.file_index))
if self.path_is_uri:
if smart_open is None:
raise ValueError(
"You must install the `smart_open` module to write "
"to URIs like {}".format(path))
self.cur_file = smart_open(path, "w")
else:
self.cur_file = open(path, "w")
self.file_index += 1
self.bytes_written = 0
logger.info("Writing to new output file {}".format(self.cur_file))
return self.cur_file
def _to_jsonable(v, compress):
if compress:
return str(pack(v))
elif isinstance(v, np.ndarray):
return v.tolist()
return v
def _to_json(batch, compress_columns):
return json.dumps({
k: _to_jsonable(v, compress=k in compress_columns)
for k, v in batch.data.items()
})
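
The writer can likewise be used standalone; a minimal sketch, assuming a throwaway output directory and a hand-built batch:

import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter

ioctx = IOContext("/tmp/demos", {}, 0, None)
writer = JsonWriter(
    ioctx, "/tmp/demos",
    max_file_size=64 * 1024 * 1024,
    compress_columns=["obs", "new_obs"])
writer.write(SampleBatch({
    "obs": np.zeros((2, 4)),
    "new_obs": np.ones((2, 4)),
    "actions": np.array([0, 1]),
    "rewards": np.array([1.0, 0.5]),
}))

Each write() appends one JSON line; a new file is started once max_file_size bytes have been written to the current one.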

View file

@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override
class MixedInput(InputReader):
"""Mixes input from a number of other input sources.
Examples:
>>> MixedInput(ioctx, {
"sampler": 0.4,
"/tmp/experiences/*.json": 0.4,
"s3://bucket/expert.json": 0.2,
})
"""
def __init__(self, ioctx, dist):
"""Initialize a MixedInput.
Arguments:
ioctx (IOContext): current IO context object.
dist (dict): dict mapping JsonReader paths or "sampler" to
probabilities. The probabilities must sum to 1.0.
"""
if sum(dist.values()) != 1.0:
raise ValueError("Values must sum to 1.0: {}".format(dist))
self.choices = []
self.p = []
for k, v in dist.items():
if k == "sampler":
self.choices.append(ioctx.default_sampler_input())
else:
self.choices.append(JsonReader(ioctx, k))
self.p.append(v)
@override(InputReader)
def next(self):
source = np.random.choice(self.choices, p=self.p)
return source.next()

View file

@ -0,0 +1,25 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
class OutputWriter(object):
"""Writer object for saving experiences from policy evaluation."""
def write(self, sample_batch):
"""Save a batch of experiences.
Arguments:
sample_batch: SampleBatch or MultiAgentBatch to save.
"""
raise NotImplementedError
class NoopOutput(OutputWriter):
"""Output writer that discards its outputs."""
@override(OutputWriter)
def write(self, sample_batch):
pass
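
As with InputReader, a custom sink only needs to implement write(). A hypothetical sketch that just counts emitted rows, usable via the function-valued `output` config:

from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override

class CountingOutput(OutputWriter):
    """Toy writer that tracks how many experience rows were produced."""

    def __init__(self):
        self.rows_written = 0

    @override(OutputWriter)
    def write(self, sample_batch):
        # Works for SampleBatch and MultiAgentBatch, both of which expose count.
        self.rows_written += sample_batch.count

For example, "output": lambda ioctx: CountingOutput() in an agent config.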

View file

@ -0,0 +1,235 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import numpy as np
import os
import shutil
import tempfile
import time
import unittest
import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
from ray.rllib.offline.json_writer import _to_json
SAMPLES = SampleBatch({
"actions": np.array([1, 2, 3]),
"obs": np.array([4, 5, 6])
})
def make_sample_batch(i):
return SampleBatch({
"actions": np.array([i, i, i]),
"obs": np.array([i, i, i])
})
class AgentIOTest(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.test_dir)
def writeOutputs(self, output):
agent = PGAgent(
env="CartPole-v0",
config={
"output": output,
"sample_batch_size": 250,
})
agent.train()
return agent
def testAgentOutputOk(self):
self.writeOutputs(self.test_dir)
self.assertEqual(len(os.listdir(self.test_dir)), 1)
ioctx = IOContext(self.test_dir, {}, 0, None)
reader = JsonReader(ioctx, self.test_dir + "/*.json")
reader.next()
def testAgentOutputLogdir(self):
agent = self.writeOutputs("logdir")
self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 1)
def testAgentInputDir(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
env="CartPole-v0",
config={
"input": self.test_dir,
"input_evaluation": None,
})
result = agent.train()
self.assertEqual(result["timesteps_total"], 250) # read from input
self.assertTrue(np.isnan(result["episode_reward_mean"]))
def testAgentInputEvalSim(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
env="CartPole-v0",
config={
"input": self.test_dir,
"input_evaluation": "simulation",
})
for _ in range(50):
result = agent.train()
if not np.isnan(result["episode_reward_mean"]):
return # simulation ok
time.sleep(0.1)
assert False, "did not see any simulation results"
def testAgentInputList(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
env="CartPole-v0",
config={
"input": glob.glob(self.test_dir + "/*.json"),
"input_evaluation": None,
"sample_batch_size": 99,
})
result = agent.train()
self.assertEqual(result["timesteps_total"], 250) # read from input
self.assertTrue(np.isnan(result["episode_reward_mean"]))
def testAgentInputDict(self):
self.writeOutputs(self.test_dir)
agent = PGAgent(
env="CartPole-v0",
config={
"input": {
self.test_dir: 0.1,
"sampler": 0.9,
},
"train_batch_size": 2000,
"input_evaluation": None,
})
result = agent.train()
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
class JsonIOTest(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.test_dir)
def testWriteSimple(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
writer = JsonWriter(
ioctx, self.test_dir, max_file_size=1000, compress_columns=["obs"])
self.assertEqual(len(os.listdir(self.test_dir)), 0)
writer.write(SAMPLES)
writer.write(SAMPLES)
self.assertEqual(len(os.listdir(self.test_dir)), 1)
def testWriteFileURI(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
writer = JsonWriter(
ioctx,
"file:" + self.test_dir,
max_file_size=1000,
compress_columns=["obs"])
self.assertEqual(len(os.listdir(self.test_dir)), 0)
writer.write(SAMPLES)
writer.write(SAMPLES)
self.assertEqual(len(os.listdir(self.test_dir)), 1)
def testWritePaginate(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
writer = JsonWriter(
ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
self.assertEqual(len(os.listdir(self.test_dir)), 0)
for _ in range(100):
writer.write(SAMPLES)
self.assertEqual(len(os.listdir(self.test_dir)), 12)
def testReadWrite(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
writer = JsonWriter(
ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
for i in range(100):
writer.write(make_sample_batch(i))
reader = JsonReader(ioctx, self.test_dir + "/*.json")
seen_a = set()
seen_o = set()
for i in range(1000):
batch = reader.next()
seen_a.add(batch["actions"][0])
seen_o.add(batch["obs"][0])
self.assertGreater(len(seen_a), 90)
self.assertLess(len(seen_a), 101)
self.assertGreater(len(seen_o), 90)
self.assertLess(len(seen_o), 101)
def testSkipsOverEmptyLinesAndFiles(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
open(self.test_dir + "/empty", "w").close()
with open(self.test_dir + "/f1", "w") as f:
f.write("\n")
f.write("\n")
f.write(_to_json(make_sample_batch(0), []))
with open(self.test_dir + "/f2", "w") as f:
f.write(_to_json(make_sample_batch(1), []))
f.write("\n")
reader = JsonReader(ioctx, [
self.test_dir + "/empty",
self.test_dir + "/f1",
"file:" + self.test_dir + "/f2",
])
seen_a = set()
for i in range(100):
batch = reader.next()
seen_a.add(batch["actions"][0])
self.assertEqual(len(seen_a), 2)
def testSkipsOverCorruptedLines(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
with open(self.test_dir + "/f1", "w") as f:
f.write(_to_json(make_sample_batch(0), []))
f.write("\n")
f.write(_to_json(make_sample_batch(1), []))
f.write("\n")
f.write(_to_json(make_sample_batch(2), []))
f.write("\n")
f.write(_to_json(make_sample_batch(3), []))
f.write("\n")
f.write("{..corrupted_json_record")
reader = JsonReader(ioctx, [
self.test_dir + "/f1",
])
seen_a = set()
for i in range(10):
batch = reader.next()
seen_a.add(batch["actions"][0])
self.assertEqual(len(seen_a), 4)
def testAbortOnAllEmptyInputs(self):
ioctx = IOContext(self.test_dir, {}, 0, None)
open(self.test_dir + "/empty", "w").close()
reader = JsonReader(ioctx, [
self.test_dir + "/empty",
])
self.assertRaises(ValueError, lambda: reader.next())
with open(self.test_dir + "/empty1", "w") as f:
for _ in range(100):
f.write("\n")
with open(self.test_dir + "/empty2", "w") as f:
for _ in range(100):
f.write("\n")
reader = JsonReader(ioctx, [
self.test_dir + "/empty1",
self.test_dir + "/empty2",
])
self.assertRaises(ValueError, lambda: reader.next())
if __name__ == "__main__":
ray.init(num_cpus=1)
unittest.main(verbosity=2)

View file

@ -215,6 +215,7 @@ class NestedSpacesTest(unittest.TestCase):
config={
"num_workers": 0,
"sample_batch_size": 5,
"train_batch_size": 5,
"model": {
"custom_model": "composite",
"use_lstm": test_lstm,
@ -243,6 +244,7 @@ class NestedSpacesTest(unittest.TestCase):
config={
"num_workers": 0,
"sample_batch_size": 5,
"train_batch_size": 5,
"model": {
"custom_model": "composite2",
},
@ -302,6 +304,7 @@ class NestedSpacesTest(unittest.TestCase):
config={
"num_workers": 0,
"sample_batch_size": 5,
"train_batch_size": 5,
"multiagent": {
"policy_graphs": {
"tuple_policy": (

View file

@ -186,6 +186,7 @@ class TestPolicyEvaluator(unittest.TestCase):
env="CartPole-v0", config={
"num_workers": 0,
"sample_batch_size": 50,
"train_batch_size": 50,
"callbacks": {
"on_episode_start": lambda x: counts.update({"start": 1}),
"on_episode_step": lambda x: counts.update({"step": 1}),

View file

@ -7,6 +7,7 @@ import time
import base64
import numpy as np
import pyarrow
from six import string_types
logger = logging.getLogger(__name__)
@ -26,7 +27,7 @@ def pack(data):
data = lz4.frame.compress(data)
# TODO(ekl) we shouldn't need to base64 encode this data, but this
# seems to not survive a transfer through the object store if we don't.
data = base64.b64encode(data)
data = base64.b64encode(data).decode("ascii")
return data
@ -45,7 +46,7 @@ def unpack(data):
def unpack_if_needed(data):
if isinstance(data, bytes):
if isinstance(data, bytes) or isinstance(data, string_types):
data = unpack(data)
return data
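
The effect of these two changes is that packed data is now a plain ASCII string that survives the JSON round trip used by JsonWriter/JsonReader, and unpack_if_needed() accepts that string form on the read path. A small round-trip sketch, assuming lz4 is installed so compression is actually applied:

import numpy as np
from ray.rllib.utils.compression import pack, unpack_if_needed

obs = np.arange(12).reshape(3, 4)
blob = pack(obs)                   # base64-encoded str, JSON-serializable
restored = unpack_if_needed(blob)  # str input is now unpacked, not passed through
assert np.array_equal(obs, restored)
assert np.array_equal(unpack_if_needed(obs), obs)  # unpacked data passes through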

View file

@ -248,6 +248,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/test/test_local.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/test/test_io.py
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/rllib/test/test_checkpoint_restore.py