Mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
[rllib] Basic Offline Data IO API (#3473)
This commit is contained in:
parent
cc8f7db246
commit
32473cf22e
18 changed files with 801 additions and 48 deletions
@@ -5,7 +5,7 @@ FROM ray-project/deploy
# This updates numpy to 1.14 and mutes errors from other libraries
RUN conda install -y numpy
RUN apt-get install -y zlib1g-dev
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout
RUN pip install gym[atari] opencv-python==3.2.0.8 tensorflow lz4 keras pytest-timeout smart_open
RUN pip install -U h5py  # Mutes FutureWarnings
RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git
RUN conda install pytorch-cpu torchvision-cpu -c pytorch

@@ -10,8 +10,10 @@ import pickle
import six
import tempfile
import tensorflow as tf
from types import FunctionType

import ray
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
@@ -122,11 +124,45 @@ COMMON_CONFIG = {
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    },
    # Whether to LZ4 compress observations
    # Whether to LZ4 compress individual observations
    "compress_observations": False,
    # Drop metric batches from unresponsive workers after this many seconds
    "collect_metrics_timeout": 180,

    # === Offline Data Input / Output ===
    # Specify how to generate experiences:
    #  - "sampler": generate experiences via online simulation (default)
    #  - a local directory or file glob expression (e.g., "/tmp/*.json")
    #  - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
    #    "s3://bucket/2.json"])
    #  - a dict with string keys and sampling probabilities as values (e.g.,
    #    {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
    #  - a function that returns a rllib.offline.InputReader
    "input": "sampler",
    # Specify how to evaluate the current policy. This only makes sense to set
    # when the input is not already generating simulation data:
    #  - None: don't evaluate the policy. The episode reward and other
    #    metrics will be NaN if using offline data.
    #  - "simulation": run the environment in the background, but use
    #    this data for evaluation only and not for learning.
    #  - "counterfactual": use counterfactual policy evaluation to estimate
    #    performance (this option is not implemented yet).
    "input_evaluation": None,
    # Specify where experiences should be saved:
    #  - None: don't save any experiences
    #  - "logdir" to save to the agent log dir
    #  - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
    #  - a function that returns a rllib.offline.OutputWriter
    "output": None,
    # What sample batch columns to LZ4 compress in the output data.
    "output_compress_columns": ["obs", "new_obs"],
    # Max output file size before rolling over to a new file.
    "output_max_file_size": 64 * 1024 * 1024,
    # Whether to run postprocess_trajectory() on the trajectory fragments from
    # offline inputs. Whether this makes sense is algorithm-specific.
    # TODO(ekl) implement this and multi-agent batch handling
    # "postprocess_inputs": False,

    # === Multiagent ===
    "multiagent": {
        # Map from policy ids to tuples of (policy_graph_cls, obs_space,
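To make the offline options above concrete, here is a minimal sketch (not part of this diff, mirroring the new test_io.py tests) of one agent that records its experiences and a second one that trains from them; the /tmp/cartpole-out directory is illustrative only:

import ray
from ray.rllib.agents.pg import PGAgent

ray.init()

# Phase 1: run online and save experiences as JSON files under an output dir.
recording_agent = PGAgent(
    env="CartPole-v0",
    config={
        "output": "/tmp/cartpole-out",
    })
recording_agent.train()

# Phase 2: train purely from the saved files (episode reward will be NaN),
# or mix saved data with fresh simulation via a probability dict.
offline_agent = PGAgent(
    env="CartPole-v0",
    config={
        "input": "/tmp/cartpole-out",
        "input_evaluation": None,
    })
mixed_agent = PGAgent(
    env="CartPole-v0",
    config={
        "input": {"sampler": 0.9, "/tmp/cartpole-out/*.json": 0.1},
        "input_evaluation": None,
    })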
@@ -179,7 +215,6 @@ class Agent(Trainable):
        """

        config = config or {}
        Agent._validate_config(config)

        # Vars to synchronize to evaluators on each train call
        self.global_vars = {"timestep": 0}
@@ -267,6 +302,7 @@ class Agent(Trainable):
            self._allow_unknown_configs,
            self._allow_unknown_subkeys)
        self.config = merged_config
        Agent._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

@@ -391,6 +427,32 @@ class Agent(Trainable):
                self.config) for i in range(count)
        ]

    @classmethod
    def resource_help(cls, config):
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers` and other configs. See the "
                "DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
        if (config["input"] == "sampler"
                and config["input_evaluation"] is not None):
            raise ValueError(
                "`input_evaluation` should not be set when input=sampler")

    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
                        config):
        def session_creator():
@@ -399,6 +461,32 @@ class Agent(Trainable):
            return tf.Session(
                config=tf.ConfigProto(**config["tf_session_args"]))

        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (lambda ioctx: MixedInput(ioctx, config["input"]))
        else:
            input_creator = (lambda ioctx: JsonReader(ioctx, config["input"]))

        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx,
                ioctx.log_dir,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                ioctx,
                config["output"],
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        return cls(
            env_creator,
            self.config["multiagent"]["policy_graphs"] or policy_graph,
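Because the dispatch above accepts plain callables (the FunctionType branches), the config can also supply creator functions directly instead of paths. A hedged sketch (the paths and column list here are illustrative, not from the diff); each callable is invoked once per evaluator with its IOContext:

from ray.rllib.offline import JsonReader, JsonWriter

config = {
    # A lambda satisfies the FunctionType check in _make_evaluator.
    "input": lambda ioctx: JsonReader(ioctx, "/tmp/expert-data/*.json"),
    "output": lambda ioctx: JsonWriter(
        ioctx, "/tmp/replay-copy",
        compress_columns=["obs", "new_obs"]),
    "input_evaluation": None,
}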
@@ -421,30 +509,12 @@ class Agent(Trainable):
            policy_config=config,
            worker_index=worker_index,
            monitor_path=self.logdir if config["monitor"] else None,
            log_dir=self.logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"])

    @classmethod
    def resource_help(cls, config):
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers` and other configs. See the "
                "DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation_method=config["input_evaluation"],
            output_creator=output_creator)

    def __getstate__(self):
        state = {}

@@ -36,9 +36,11 @@ class PGAgent(Agent):
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        self.optimizer = SyncSamplesOptimizer(self.local_evaluator,
                                              self.remote_evaluators,
                                              self.config["optimizer"])
        optimizer_config = dict(
            self.config["optimizer"],
            **{"train_batch_size": self.config["train_batch_size"]})
        self.optimizer = SyncSamplesOptimizer(
            self.local_evaluator, self.remote_evaluators, optimizer_config)

    @override(Agent)
    def _train(self):

@@ -18,6 +18,7 @@ from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
@@ -108,8 +109,12 @@ class PolicyEvaluator(EvaluatorInterface):
                 policy_config=None,
                 worker_index=0,
                 monitor_path=None,
                 log_dir=None,
                 log_level=None,
                 callbacks=None):
                 callbacks=None,
                 input_creator=lambda ioctx: ioctx.default_sampler_input(),
                 input_evaluation_method=None,
                 output_creator=lambda ioctx: NoopOutput()):
        """Initialize a policy evaluator.

        Arguments:
@@ -170,8 +175,22 @@ class PolicyEvaluator(EvaluatorInterface):
                through EnvContext so that envs can be configured per worker.
            monitor_path (str): Write out episode stats and videos to this
                directory if specified.
            log_dir (str): Directory where logs can be placed.
            log_level (str): Set the root log level on creation.
            callbacks (dict): Dict of custom debug callbacks.
            input_creator (func): Function that returns an InputReader object
                for loading previously generated experiences.
            input_evaluation_method (str): How to evaluate the current policy.
                This only applies when the input is reading offline data.
                Options are:
                  - None: don't evaluate the policy. The episode reward and
                    other metrics will be NaN.
                  - "simulation": run the environment in the background, but
                    use this data for evaluation only and never for learning.
                  - "counterfactual": use counterfactual policy evaluation to
                    estimate performance.
            output_creator (func): Function that returns an OutputWriter object
                for saving generated experiences.
        """

        if log_level:
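A minimal sketch (not part of this diff) of wiring these new arguments up on a standalone evaluator; the CartPole env, PGPolicyGraph import path, and /tmp directories are assumptions for the era of this PR, purely for illustration:

import gym

from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.offline import JsonReader

evaluator = PolicyEvaluator(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy_graph=PGPolicyGraph,
    # Replay previously saved experiences instead of sampling the env.
    input_creator=lambda ioctx: JsonReader(ioctx, "/tmp/cartpole-out/*.json"),
    log_dir="/tmp/cartpole-logs")
batch = evaluator.sample()  # a SampleBatch drawn from the JSON files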
@@ -279,6 +298,20 @@ class PolicyEvaluator(EvaluatorInterface):
        else:
            raise ValueError("Unsupported batch mode: {}".format(
                self.batch_mode))

        if input_evaluation_method == "simulation":
            logger.warn(
                "Requested 'simulation' input evaluation method: "
                "will discard all sampler outputs and keep only metrics.")
            sample_async = True
        elif input_evaluation_method == "counterfactual":
            raise NotImplementedError
        elif input_evaluation_method is None:
            pass
        else:
            raise ValueError("Unknown evaluation method: {}".format(
                input_evaluation_method))

        if sample_async:
            self.sampler = AsyncSampler(
                self.async_env,
@@ -292,7 +325,8 @@ class PolicyEvaluator(EvaluatorInterface):
                horizon=episode_horizon,
                pack=pack_episodes,
                tf_sess=self.tf_sess,
                clip_actions=clip_actions)
                clip_actions=clip_actions,
                blackhole_outputs=input_evaluation_method == "simulation")
            self.sampler.start()
        else:
            self.sampler = SyncSampler(
@@ -309,6 +343,12 @@ class PolicyEvaluator(EvaluatorInterface):
                tf_sess=self.tf_sess,
                clip_actions=clip_actions)

        self.io_context = IOContext(log_dir, policy_config, worker_index, self)
        self.input_reader = input_creator(self.io_context)
        assert isinstance(self.input_reader, InputReader), self.input_reader
        self.output_writer = output_creator(self.io_context)
        assert isinstance(self.output_writer, OutputWriter), self.output_writer

        logger.debug("Created evaluator with env {} ({}), policies {}".format(
            self.async_env, self.env, self.policy_map))

@@ -320,7 +360,7 @@ class PolicyEvaluator(EvaluatorInterface):
            SampleBatch|MultiAgentBatch from evaluating the current policies.
        """

        batches = [self.sampler.get_data()]
        batches = [self.input_reader.next()]
        steps_so_far = batches[0].count

        # In truncate_episodes mode, never pull more than 1 batch per env.
@@ -332,10 +372,9 @@ class PolicyEvaluator(EvaluatorInterface):

        while steps_so_far < self.sample_batch_size and len(
                batches) < max_batches:
            batch = self.sampler.get_data()
            batch = self.input_reader.next()
            steps_so_far += batch.count
            batches.append(batch)
        batches.extend(self.sampler.get_extra_batches())
        batch = batches[0].concat_samples(batches)

        if self.callbacks.get("on_sample_end"):
@@ -353,6 +392,7 @@ class PolicyEvaluator(EvaluatorInterface):
            batch["obs"] = [pack(o) for o in batch["obs"]]
            batch["new_obs"] = [pack(o) for o in batch["new_obs"]]

        self.output_writer.write(batch)
        return batch

    @ray.method(num_return_vals=2)
@@ -531,6 +571,10 @@ class PolicyEvaluator(EvaluatorInterface):
            policy_map[name] = cls(obs_space, act_space, merged_conf)
        return policy_map, preprocessors

    def __del__(self):
        if isinstance(self.sampler, AsyncSampler):
            self.sampler.shutdown = True


def _validate_and_canonicalize(policy_graph, env):
    if isinstance(policy_graph, dict):

@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import collections
import numpy as np

@@ -233,8 +234,9 @@ class SampleBatch(object):
        self.data = dict(*args, **kwargs)
        lengths = []
        for k, v in self.data.copy().items():
            assert type(k) == str, self
            assert isinstance(k, six.string_types), self
            lengths.append(len(v))
            self.data[k] = np.array(v, copy=False)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, "data columns must be same length"

@@ -112,7 +112,8 @@ class AsyncSampler(threading.Thread):
                 horizon=None,
                 pack=False,
                 tf_sess=None,
                 clip_actions=True):
                 clip_actions=True,
                 blackhole_outputs=False):
        for _, f in obs_filters.items():
            assert getattr(f, "is_concurrent", False), \
                "Observation Filter must support concurrent updates."

@@ -133,6 +134,8 @@ class AsyncSampler(threading.Thread):
        self.tf_sess = tf_sess
        self.callbacks = callbacks
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.shutdown = False

    def run(self):
        try:
@@ -142,12 +145,19 @@ class AsyncSampler(threading.Thread):
            raise e

    def _run(self):
        if self.blackhole_outputs:
            queue_putter = (lambda x: None)
            extra_batches_putter = (lambda x: None)
        else:
            queue_putter = self.queue.put
            extra_batches_putter = (
                lambda x: self.extra_batches.put(x, timeout=600.0))
        rollout_provider = _env_runner(
            self.async_vector_env, self.extra_batches.put, self.policies,
            self.async_vector_env, extra_batches_putter, self.policies,
            self.policy_mapping_fn, self.unroll_length, self.horizon,
            self.preprocessors, self.obs_filters, self.clip_rewards,
            self.clip_actions, self.pack, self.callbacks, self.tf_sess)
        while True:
        while not self.shutdown:
            # The timeout variable exists because apparently, if one worker
            # dies, the other workers won't die with it, unless the timeout is
            # set to some large number. This is an empirical observation.
@@ -155,7 +165,7 @@ class AsyncSampler(threading.Thread):
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                self.queue.put(item, timeout=600.0)
                queue_putter(item)

    def get_data(self):
        rollout = self.queue.get(timeout=600.0)
@@ -246,7 +256,7 @@ def _env_runner(async_vector_env,
        horizon = (
            async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
    except Exception:
        logger.warn("no episode horizon specified, assuming inf")
        logger.debug("no episode horizon specified, assuming inf")
    if not horizon:
        horizon = float("inf")

@@ -332,12 +342,12 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is not intentional, check that the "
                "the `horizon` config is set correctly, or consider setting "
                "`batch_mode` to 'truncate_episodes'. Note that in "
                "multi-agent environments, `sample_batch_size` sets the "
                "batch size based on environment steps, not the steps of "
                "individual agents.")
                "the sampler. If this is more than you expected, check that "
                "you set a horizon on your environment correctly. Note "
                "that in multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches.")

            # Check episode termination conditions
            if dones[env_id]["__all__"] or episode.length >= horizon:

20  python/ray/rllib/offline/__init__.py  Normal file
@@ -0,0 +1,20 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.offline.output_writer import OutputWriter, NoopOutput
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.mixed_input import MixedInput

__all__ = [
    "IOContext",
    "JsonReader",
    "JsonWriter",
    "NoopOutput",
    "OutputWriter",
    "InputReader",
    "MixedInput",
]
30  python/ray/rllib/offline/input_reader.py  Normal file
@@ -0,0 +1,30 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import override


class InputReader(object):
    """Input object for loading experiences in policy evaluation."""

    def next(self):
        """Return the next batch of experiences read."""

        raise NotImplementedError


class SamplerInput(InputReader):
    """Reads input experiences from an existing sampler."""

    def __init__(self, sampler):
        self.sampler = sampler

    @override(InputReader)
    def next(self):
        batches = [self.sampler.get_data()]
        batches.extend(self.sampler.get_extra_batches())
        if len(batches) > 1:
            return batches[0].concat_samples(batches)
        else:
            return batches[0]
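A custom input source only needs to implement next(). A minimal sketch (not part of this diff) that replays one fixed batch forever; the class name CyclicReader and its column values are illustrative only:

import numpy as np

from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.offline.input_reader import InputReader


class CyclicReader(InputReader):
    """Toy reader that returns the same experiences on every call."""

    def __init__(self, ioctx):
        self.ioctx = ioctx
        self.batch = SampleBatch({
            "obs": np.array([0, 1, 2]),
            "actions": np.array([0, 0, 1]),
        })

    def next(self):
        return self.batch

It could then be plugged in through the new config, e.g. "input": lambda ioctx: CyclicReader(ioctx).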
28  python/ray/rllib/offline/io_context.py  Normal file
@@ -0,0 +1,28 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.offline.input_reader import SamplerInput


class IOContext(object):
    """Attributes to pass to input / output class constructors.

    RLlib auto-sets these attributes when constructing input / output classes.

    Attributes:
        log_dir (str): Default logging directory.
        config (dict): Configuration of the agent.
        worker_index (int): When there are multiple workers created, this
            uniquely identifies the current worker.
        evaluator (PolicyEvaluator): policy evaluator object reference.
    """

    def __init__(self, log_dir, config, worker_index, evaluator):
        self.log_dir = log_dir
        self.config = config
        self.worker_index = worker_index
        self.evaluator = evaluator

    def default_sampler_input(self):
        return SamplerInput(self.evaluator.sampler)
126  python/ray/rllib/offline/json_reader.py  Normal file
@@ -0,0 +1,126 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import json
import logging
import os
import random
import six
from six.moves.urllib.parse import urlparse

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import unpack_if_needed

logger = logging.getLogger(__name__)


class JsonReader(InputReader):
    """Reader object that loads experiences from JSON file chunks.

    The input files will be read in random order."""

    def __init__(self, ioctx, inputs):
        """Initialize a JsonReader.

        Arguments:
            ioctx (IOContext): current IO context object.
            inputs (str|list): either a glob expression for files, e.g.,
                "/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
                ["s3://bucket/file.json", "s3://bucket/file2.json"].
        """

        self.ioctx = ioctx
        if isinstance(inputs, six.string_types):
            if os.path.isdir(inputs):
                inputs = os.path.join(inputs, "*.json")
                logger.warn(
                    "Treating input directory as glob pattern: {}".format(
                        inputs))
            if urlparse(inputs).scheme:
                raise ValueError(
                    "Don't know how to glob over `{}`, ".format(inputs) +
                    "please specify a list of files to read instead.")
            else:
                self.files = glob.glob(inputs)
        elif type(inputs) is list:
            self.files = inputs
        else:
            raise ValueError(
                "type of inputs must be list or str, not {}".format(inputs))
        if self.files:
            logger.info("Found {} input files.".format(len(self.files)))
        else:
            raise ValueError("No files found matching {}".format(inputs))
        self.cur_file = None

    @override(InputReader)
    def next(self):
        batch = self._try_parse(self._next_line())
        tries = 0
        while not batch and tries < 100:
            tries += 1
            logger.debug("Skipping empty line in {}".format(self.cur_file))
            batch = self._try_parse(self._next_line())
        if not batch:
            raise ValueError(
                "Failed to read valid experience batch from file: {}".format(
                    self.cur_file))
        return batch

    def _try_parse(self, line):
        line = line.strip()
        if not line:
            return None
        try:
            return _from_json(line)
        except Exception:
            logger.exception("Ignoring corrupt json record in {}: {}".format(
                self.cur_file, line))
            return None

    def _next_line(self):
        if not self.cur_file:
            self.cur_file = self._next_file()
        line = self.cur_file.readline()
        tries = 0
        while not line and tries < 100:
            tries += 1
            if hasattr(self.cur_file, "close"):  # legacy smart_open impls
                self.cur_file.close()
            self.cur_file = self._next_file()
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file {}".format(self.cur_file))
        if not line:
            raise ValueError("Failed to read next line from files: {}".format(
                self.files))
        return line

    def _next_file(self):
        path = random.choice(self.files)
        if urlparse(path).scheme:
            if smart_open is None:
                raise ValueError(
                    "You must install the `smart_open` module to read "
                    "from URIs like {}".format(path))
            return smart_open(path, "r")
        else:
            return open(path, "r")


def _from_json(batch):
    if isinstance(batch, bytes):  # smart_open S3 doesn't respect "r"
        batch = batch.decode("utf-8")
    data = json.loads(batch)
    for k, v in data.items():
        data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)]
    return SampleBatch(data)
108  python/ray/rllib/offline/json_writer.py  Normal file
@@ -0,0 +1,108 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import json
import logging
import numpy as np
import os
from six.moves.urllib.parse import urlparse
import time

try:
    from smart_open import smart_open
except ImportError:
    smart_open = None

from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack

logger = logging.getLogger(__name__)


class JsonWriter(OutputWriter):
    """Writer object that saves experiences in JSON file chunks."""

    def __init__(self,
                 ioctx,
                 path,
                 max_file_size=64 * 1024 * 1024,
                 compress_columns=frozenset(["obs", "new_obs"])):
        """Initialize a JsonWriter.

        Arguments:
            ioctx (IOContext): current IO context object.
            path (str): a path/URI of the output directory to save files in.
            max_file_size (int): max size of single files before rolling over.
            compress_columns (list): list of sample batch columns to compress.
        """

        self.ioctx = ioctx
        self.path = path
        self.max_file_size = max_file_size
        self.compress_columns = compress_columns
        if urlparse(path).scheme:
            self.path_is_uri = True
        else:
            # Try to create local dirs if they don't exist
            try:
                os.makedirs(path)
            except OSError:
                pass  # already exists
            assert os.path.exists(path), "Failed to create {}".format(path)
            self.path_is_uri = False
        self.file_index = 0
        self.bytes_written = 0
        self.cur_file = None

    @override(OutputWriter)
    def write(self, sample_batch):
        start = time.time()
        data = _to_json(sample_batch, self.compress_columns)
        f = self._get_file()
        f.write(data)
        f.write("\n")
        if hasattr(f, "flush"):  # legacy smart_open impls
            f.flush()
        self.bytes_written += len(data)
        logger.debug("Wrote {} bytes to {} in {}s".format(
            len(data), f,
            time.time() - start))

    def _get_file(self):
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path, "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index))
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path))
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file


def _to_jsonable(v, compress):
    if compress:
        return str(pack(v))
    elif isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json(batch, compress_columns):
    return json.dumps({
        k: _to_jsonable(v, compress=k in compress_columns)
        for k, v in batch.data.items()
    })
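A short round-trip sketch (not part of this diff, mirroring testWriteSimple and testReadWrite below) tying the writer and reader together; the /tmp/rllib-offline-demo directory is illustrative only:

import numpy as np

from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.offline import IOContext, JsonReader, JsonWriter

ioctx = IOContext("/tmp/rllib-offline-demo", {}, 0, None)
writer = JsonWriter(ioctx, "/tmp/rllib-offline-demo")
writer.write(SampleBatch({
    "obs": np.array([0, 1, 2]),
    "actions": np.array([1, 1, 0]),
}))
reader = JsonReader(ioctx, "/tmp/rllib-offline-demo/*.json")
print(reader.next()["actions"])  # batches are re-read in random file order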
45  python/ray/rllib/offline/mixed_input.py  Normal file
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override


class MixedInput(InputReader):
    """Mixes input from a number of other input sources.

    Examples:
        >>> MixedInput(ioctx, {
            "sampler": 0.4,
            "/tmp/experiences/*.json": 0.4,
            "s3://bucket/expert.json": 0.2,
        })
    """

    def __init__(self, ioctx, dist):
        """Initialize a MixedInput.

        Arguments:
            ioctx (IOContext): current IO context object.
            dist (dict): dict mapping JSONReader paths or "sampler" to
                probabilities. The probabilities must sum to 1.0.
        """
        if sum(dist.values()) != 1.0:
            raise ValueError("Values must sum to 1.0: {}".format(dist))
        self.choices = []
        self.p = []
        for k, v in dist.items():
            if k == "sampler":
                self.choices.append(ioctx.default_sampler_input())
            else:
                self.choices.append(JsonReader(ioctx, k))
            self.p.append(v)

    @override(InputReader)
    def next(self):
        source = np.random.choice(self.choices, p=self.p)
        return source.next()
25  python/ray/rllib/offline/output_writer.py  Normal file
@@ -0,0 +1,25 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import override


class OutputWriter(object):
    """Writer object for saving experiences from policy evaluation."""

    def write(self, sample_batch):
        """Save a batch of experiences.

        Arguments:
            sample_batch: SampleBatch or MultiAgentBatch to save.
        """
        raise NotImplementedError


class NoopOutput(OutputWriter):
    """Output writer that discards its outputs."""

    @override(OutputWriter)
    def write(self, sample_batch):
        pass
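Extending the output side works the same way as the input side: implement write(). A hedged sketch (not in this diff) that just tallies how many timesteps were saved; the name CountingOutput is illustrative:

from ray.rllib.offline.output_writer import OutputWriter


class CountingOutput(OutputWriter):
    """Toy writer that counts experiences instead of persisting them."""

    def __init__(self, ioctx):
        self.ioctx = ioctx
        self.steps_written = 0

    def write(self, sample_batch):
        self.steps_written += sample_batch.count

It would be wired up via "output": lambda ioctx: CountingOutput(ioctx) in the agent config.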
235  python/ray/rllib/test/test_io.py  Normal file
@@ -0,0 +1,235 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import numpy as np
import os
import shutil
import tempfile
import time
import unittest

import ray
from ray.rllib.agents.pg import PGAgent
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
from ray.rllib.offline.json_writer import _to_json

SAMPLES = SampleBatch({
    "actions": np.array([1, 2, 3]),
    "obs": np.array([4, 5, 6])
})


def make_sample_batch(i):
    return SampleBatch({
        "actions": np.array([i, i, i]),
        "obs": np.array([i, i, i])
    })


class AgentIOTest(unittest.TestCase):
    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def writeOutputs(self, output):
        agent = PGAgent(
            env="CartPole-v0",
            config={
                "output": output,
                "sample_batch_size": 250,
            })
        agent.train()
        return agent

    def testAgentOutputOk(self):
        self.writeOutputs(self.test_dir)
        self.assertEqual(len(os.listdir(self.test_dir)), 1)
        ioctx = IOContext(self.test_dir, {}, 0, None)
        reader = JsonReader(ioctx, self.test_dir + "/*.json")
        reader.next()

    def testAgentOutputLogdir(self):
        agent = self.writeOutputs("logdir")
        self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 1)

    def testAgentInputDir(self):
        self.writeOutputs(self.test_dir)
        agent = PGAgent(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
                "input_evaluation": None,
            })
        result = agent.train()
        self.assertEqual(result["timesteps_total"], 250)  # read from input
        self.assertTrue(np.isnan(result["episode_reward_mean"]))

    def testAgentInputEvalSim(self):
        self.writeOutputs(self.test_dir)
        agent = PGAgent(
            env="CartPole-v0",
            config={
                "input": self.test_dir,
                "input_evaluation": "simulation",
            })
        for _ in range(50):
            result = agent.train()
            if not np.isnan(result["episode_reward_mean"]):
                return  # simulation ok
            time.sleep(0.1)
        assert False, "did not see any simulation results"

    def testAgentInputList(self):
        self.writeOutputs(self.test_dir)
        agent = PGAgent(
            env="CartPole-v0",
            config={
                "input": glob.glob(self.test_dir + "/*.json"),
                "input_evaluation": None,
                "sample_batch_size": 99,
            })
        result = agent.train()
        self.assertEqual(result["timesteps_total"], 250)  # read from input
        self.assertTrue(np.isnan(result["episode_reward_mean"]))

    def testAgentInputDict(self):
        self.writeOutputs(self.test_dir)
        agent = PGAgent(
            env="CartPole-v0",
            config={
                "input": {
                    self.test_dir: 0.1,
                    "sampler": 0.9,
                },
                "train_batch_size": 2000,
                "input_evaluation": None,
            })
        result = agent.train()
        self.assertTrue(not np.isnan(result["episode_reward_mean"]))


class JsonIOTest(unittest.TestCase):
    def setUp(self):
        self.test_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def testWriteSimple(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        writer = JsonWriter(
            ioctx, self.test_dir, max_file_size=1000, compress_columns=["obs"])
        self.assertEqual(len(os.listdir(self.test_dir)), 0)
        writer.write(SAMPLES)
        writer.write(SAMPLES)
        self.assertEqual(len(os.listdir(self.test_dir)), 1)

    def testWriteFileURI(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        writer = JsonWriter(
            ioctx,
            "file:" + self.test_dir,
            max_file_size=1000,
            compress_columns=["obs"])
        self.assertEqual(len(os.listdir(self.test_dir)), 0)
        writer.write(SAMPLES)
        writer.write(SAMPLES)
        self.assertEqual(len(os.listdir(self.test_dir)), 1)

    def testWritePaginate(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        writer = JsonWriter(
            ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
        self.assertEqual(len(os.listdir(self.test_dir)), 0)
        for _ in range(100):
            writer.write(SAMPLES)
        self.assertEqual(len(os.listdir(self.test_dir)), 12)

    def testReadWrite(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        writer = JsonWriter(
            ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
        for i in range(100):
            writer.write(make_sample_batch(i))
        reader = JsonReader(ioctx, self.test_dir + "/*.json")
        seen_a = set()
        seen_o = set()
        for i in range(1000):
            batch = reader.next()
            seen_a.add(batch["actions"][0])
            seen_o.add(batch["obs"][0])
        self.assertGreater(len(seen_a), 90)
        self.assertLess(len(seen_a), 101)
        self.assertGreater(len(seen_o), 90)
        self.assertLess(len(seen_o), 101)

    def testSkipsOverEmptyLinesAndFiles(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        open(self.test_dir + "/empty", "w").close()
        with open(self.test_dir + "/f1", "w") as f:
            f.write("\n")
            f.write("\n")
            f.write(_to_json(make_sample_batch(0), []))
        with open(self.test_dir + "/f2", "w") as f:
            f.write(_to_json(make_sample_batch(1), []))
            f.write("\n")
        reader = JsonReader(ioctx, [
            self.test_dir + "/empty",
            self.test_dir + "/f1",
            "file:" + self.test_dir + "/f2",
        ])
        seen_a = set()
        for i in range(100):
            batch = reader.next()
            seen_a.add(batch["actions"][0])
        self.assertEqual(len(seen_a), 2)

    def testSkipsOverCorruptedLines(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        with open(self.test_dir + "/f1", "w") as f:
            f.write(_to_json(make_sample_batch(0), []))
            f.write("\n")
            f.write(_to_json(make_sample_batch(1), []))
            f.write("\n")
            f.write(_to_json(make_sample_batch(2), []))
            f.write("\n")
            f.write(_to_json(make_sample_batch(3), []))
            f.write("\n")
            f.write("{..corrupted_json_record")
        reader = JsonReader(ioctx, [
            self.test_dir + "/f1",
        ])
        seen_a = set()
        for i in range(10):
            batch = reader.next()
            seen_a.add(batch["actions"][0])
        self.assertEqual(len(seen_a), 4)

    def testAbortOnAllEmptyInputs(self):
        ioctx = IOContext(self.test_dir, {}, 0, None)
        open(self.test_dir + "/empty", "w").close()
        reader = JsonReader(ioctx, [
            self.test_dir + "/empty",
        ])
        self.assertRaises(ValueError, lambda: reader.next())
        with open(self.test_dir + "/empty1", "w") as f:
            for _ in range(100):
                f.write("\n")
        with open(self.test_dir + "/empty2", "w") as f:
            for _ in range(100):
                f.write("\n")
        reader = JsonReader(ioctx, [
            self.test_dir + "/empty1",
            self.test_dir + "/empty2",
        ])
        self.assertRaises(ValueError, lambda: reader.next())


if __name__ == "__main__":
    ray.init(num_cpus=1)
    unittest.main(verbosity=2)

@@ -215,6 +215,7 @@ class NestedSpacesTest(unittest.TestCase):
            config={
                "num_workers": 0,
                "sample_batch_size": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite",
                    "use_lstm": test_lstm,

@@ -243,6 +244,7 @@ class NestedSpacesTest(unittest.TestCase):
            config={
                "num_workers": 0,
                "sample_batch_size": 5,
                "train_batch_size": 5,
                "model": {
                    "custom_model": "composite2",
                },

@@ -302,6 +304,7 @@ class NestedSpacesTest(unittest.TestCase):
            config={
                "num_workers": 0,
                "sample_batch_size": 5,
                "train_batch_size": 5,
                "multiagent": {
                    "policy_graphs": {
                        "tuple_policy": (

@@ -186,6 +186,7 @@ class TestPolicyEvaluator(unittest.TestCase):
            env="CartPole-v0", config={
                "num_workers": 0,
                "sample_batch_size": 50,
                "train_batch_size": 50,
                "callbacks": {
                    "on_episode_start": lambda x: counts.update({"start": 1}),
                    "on_episode_step": lambda x: counts.update({"step": 1}),

@@ -7,6 +7,7 @@ import time
import base64
import numpy as np
import pyarrow
from six import string_types

logger = logging.getLogger(__name__)

@@ -26,7 +27,7 @@ def pack(data):
    data = lz4.frame.compress(data)
    # TODO(ekl) we shouldn't need to base64 encode this data, but this
    # seems to not survive a transfer through the object store if we don't.
    data = base64.b64encode(data)
    data = base64.b64encode(data).decode("ascii")
    return data

@@ -45,7 +46,7 @@ def unpack(data):

def unpack_if_needed(data):
    if isinstance(data, bytes):
    if isinstance(data, bytes) or isinstance(data, string_types):
        data = unpack(data)
    return data

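The unpack_if_needed change matters for the JSON reader: packed columns come back from json.loads as str rather than bytes, so strings must also be routed through unpack. A small round-trip sketch (assuming lz4 and pyarrow are installed), not part of the diff:

import numpy as np

from ray.rllib.utils.compression import pack, unpack_if_needed

original = np.arange(6).reshape(2, 3)
packed = pack(original)              # base64-encoded, LZ4-compressed str
restored = unpack_if_needed(packed)  # detects the str and unpacks it
assert np.array_equal(original, restored)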
@@ -248,6 +248,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/rllib/test/test_local.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/rllib/test/test_io.py

docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/rllib/test/test_checkpoint_restore.py