ray/rllib/agents/trainer.py
2019-09-21 11:06:34 -07:00

841 lines
34 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import copy
import logging
import os
import pickle
import six
import time
import tempfile
import ray
from ray.exceptions import RayError
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import try_import_tf
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
from ray.tune.trial import ExportFormat
from ray.tune.resources import Resources
from ray.tune.logger import UnifiedLogger
from ray.tune.result import DEFAULT_RESULTS_DIR
logger = logging.getLogger(__name__)
tf = try_import_tf()

# Upper bound on how many times train() retries after a worker failure
# (the loop runs 1 + MAX_WORKER_FAILURE_RETRIES attempts). Failing that many
# times in a row suggests a persistent cluster problem, so we stop retrying.
MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# __sphinx_doc_begin__
# Default configuration shared by all trainers. Algorithm-specific defaults
# are layered on top of this dict via with_common_config().
COMMON_CONFIG = {
    # === Debugging ===
    # Whether to write episode stats and videos to the agent log dir
    "monitor": False,
    # Set the ray.rllib.* log level for the agent process and its workers.
    # Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also
    # periodically print out summaries of relevant internal dataflow (this is
    # also printed out once at startup at the INFO level).
    "log_level": "INFO",
    # Callbacks that will be run during various phases of training. These all
    # take a single "info" dict as an argument. For episode callbacks, custom
    # metrics can be attached to the episode by updating the episode object's
    # custom metrics dict (see examples/custom_metrics_and_callbacks.py). You
    # may also mutate the passed in batch data in your callback.
    "callbacks": {
        "on_episode_start": None,     # arg: {"env": .., "episode": ...}
        "on_episode_step": None,      # arg: {"env": .., "episode": ...}
        "on_episode_end": None,       # arg: {"env": .., "episode": ...}
        "on_sample_end": None,        # arg: {"samples": .., "worker": ...}
        "on_train_result": None,      # arg: {"trainer": ..., "result": ...}
        "on_postprocess_traj": None,  # arg: {
                                      #   "agent_id": ..., "episode": ...,
                                      #   "pre_batch": (before processing),
                                      #   "post_batch": (after processing),
                                      #   "all_pre_batches": (other agent ids),
                                      # }
    },
    # Whether to attempt to continue training if a worker crashes.
    "ignore_worker_failures": False,
    # Log system resource metrics to results.
    "log_sys_usage": True,
    # Enable TF eager execution (TF policies only).
    "eager": False,
    # Enable tracing in eager mode. This greatly improves performance, but
    # makes it slightly harder to debug since Python code won't be evaluated
    # after the initial eager pass.
    "eager_tracing": False,
    # Disable eager execution on workers (but allow it on the driver). This
    # only has an effect if eager is enabled.
    "no_eager_on_workers": False,

    # === Policy ===
    # Arguments to pass to model. See models/catalog.py for a full list of the
    # available model options.
    "model": MODEL_DEFAULTS,
    # Arguments to pass to the policy optimizer. These vary by optimizer.
    "optimizer": {},

    # === Environment ===
    # Discount factor of the MDP
    "gamma": 0.99,
    # Number of steps after which the episode is forced to terminate. Defaults
    # to `env.spec.max_episode_steps` (if present) for Gym envs.
    "horizon": None,
    # Calculate rewards but don't reset the environment when the horizon is
    # hit. This allows value estimation and RNN state to span across logical
    # episodes denoted by horizon. This only has an effect if horizon != inf.
    "soft_horizon": False,
    # Don't set 'done' at the end of the episode. Note that you still need to
    # set this if soft_horizon=True, unless your env is actually running
    # forever without returning done=True.
    "no_done_at_end": False,
    # Arguments to pass to the env creator
    "env_config": {},
    # Environment name can also be passed via config
    "env": None,
    # Whether to clip rewards prior to experience postprocessing. Setting to
    # None means clip for Atari only.
    "clip_rewards": None,
    # Whether to np.clip() actions to the action space low/high range spec.
    "clip_actions": True,
    # Whether to use rllib or deepmind preprocessors by default
    "preprocessor_pref": "deepmind",
    # The default learning rate
    "lr": 0.0001,

    # === Evaluation ===
    # Evaluate every `evaluation_interval` training iterations.
    # The evaluation stats will be reported under the "evaluation" metric key.
    # Note that evaluation is currently not parallelized, and that for Ape-X
    # metrics are already only reported for the lowest epsilon workers.
    "evaluation_interval": None,
    # Number of episodes to run per evaluation period.
    "evaluation_num_episodes": 10,
    # Extra arguments to pass to evaluation workers.
    # Typical usage is to pass extra args to evaluation env creator
    # and to disable exploration by computing deterministic actions
    # TODO(kismuz): implement determ. actions and include relevant keys hints
    "evaluation_config": {},

    # === Resources ===
    # Number of actors used for parallelism
    "num_workers": 2,
    # Number of GPUs to allocate to the trainer process. Note that not all
    # algorithms can take advantage of trainer GPUs. This can be fractional
    # (e.g., 0.3 GPUs).
    "num_gpus": 0,
    # Number of CPUs to allocate per worker.
    "num_cpus_per_worker": 1,
    # Number of GPUs to allocate per worker. This can be fractional.
    "num_gpus_per_worker": 0,
    # Any custom resources to allocate per worker.
    "custom_resources_per_worker": {},
    # Number of CPUs to allocate for the trainer. Note: this only takes effect
    # when running in Tune.
    "num_cpus_for_driver": 1,

    # === Memory quota ===
    # You can set these memory quotas to tell Ray to reserve memory for your
    # training run. This guarantees predictable execution, but the tradeoff is
    # if your workload exceeds the memory quota it will fail.
    # Heap memory to reserve for the trainer process (0 for unlimited). This
    # can be large if you are using large train batches, replay buffers, etc.
    "memory": 0,
    # Object store memory to reserve for the trainer process. Being large
    # enough to fit a few copies of the model weights should be sufficient,
    # since models are typically quite small.
    # NOTE(review): the default below is 0; like `memory`, 0 presumably
    # means no reservation (TODO confirm).
    "object_store_memory": 0,
    # Heap memory to reserve for each worker. Should generally be small unless
    # your environment is very heavyweight.
    "memory_per_worker": 0,
    # Object store memory to reserve for each worker. This only needs to be
    # large enough to fit a few sample batches at a time, which almost never
    # needs more than ~200MB.
    # NOTE(review): the default below is 0; like `memory`, 0 presumably
    # means no reservation (TODO confirm).
    "object_store_memory_per_worker": 0,

    # === Execution ===
    # Number of environments to evaluate vectorwise per worker.
    "num_envs_per_worker": 1,
    # Default sample batch size (unroll length). Batches of this size are
    # collected from workers until train_batch_size is met. When using
    # multiple envs per worker, this is multiplied by num_envs_per_worker.
    "sample_batch_size": 200,
    # Training batch size, if applicable. Should be >= sample_batch_size.
    # Sample batches will be concatenated together to this size for training.
    "train_batch_size": 200,
    # Whether to rollout "complete_episodes" or "truncate_episodes"
    "batch_mode": "truncate_episodes",
    # Use a background thread for sampling (slightly off-policy, usually not
    # advisable to turn on unless your env specifically requires it)
    "sample_async": False,
    # Element-wise observation filter, either "NoFilter" or "MeanStdFilter"
    "observation_filter": "NoFilter",
    # Whether to synchronize the statistics of remote filters.
    "synchronize_filters": True,
    # Configure TF for single-process operation by default
    "tf_session_args": {
        # note: overridden by `local_tf_session_args` on the local worker
        "intra_op_parallelism_threads": 2,
        "inter_op_parallelism_threads": 2,
        "gpu_options": {
            "allow_growth": True,
        },
        "log_device_placement": False,
        "device_count": {
            "CPU": 1
        },
        "allow_soft_placement": True,  # required by PPO multi-gpu
    },
    # Override the following tf session args on the local worker
    "local_tf_session_args": {
        # Allow a higher level of parallelism by default, but not unlimited
        # since that can cause crashes with many concurrent drivers.
        "intra_op_parallelism_threads": 8,
        "inter_op_parallelism_threads": 8,
    },
    # Whether to LZ4 compress individual observations
    "compress_observations": False,
    # Wait for metric batches for at most this many seconds. Those that
    # have not returned in time will be collected in the next iteration.
    "collect_metrics_timeout": 180,
    # Smooth metrics over this many episodes.
    "metrics_smoothing_episodes": 100,
    # If using num_envs_per_worker > 1, whether to create those new envs in
    # remote processes instead of in the same worker. This adds overheads, but
    # can make sense if your envs can take much time to step / reset
    # (e.g., for StarCraft). Use this cautiously; overheads are significant.
    "remote_worker_envs": False,
    # Timeout that remote workers are waiting when polling environments.
    # 0 (continue when at least one env is ready) is a reasonable default,
    # but optimal value could be obtained by measuring your environment
    # step / reset and model inference perf.
    "remote_env_batch_wait_ms": 0,
    # Minimum time per iteration
    "min_iter_time_s": 0,
    # Minimum env steps to optimize for per train call. This value does
    # not affect learning, only the length of iterations.
    "timesteps_per_iteration": 0,
    # This argument, in conjunction with worker_index, sets the random seed of
    # each worker, so that identically configured trials will have identical
    # results. This makes experiments reproducible.
    "seed": None,

    # === Offline Datasets ===
    # Specify how to generate experiences:
    #  - "sampler": generate experiences via online simulation (default)
    #  - a local directory or file glob expression (e.g., "/tmp/*.json")
    #  - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
    #    "s3://bucket/2.json"])
    #  - a dict with string keys and sampling probabilities as values (e.g.,
    #    {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
    #  - a function that returns a rllib.offline.InputReader
    "input": "sampler",
    # Specify how to evaluate the current policy. This only has an effect when
    # reading offline experiences. Must be a list (checked by
    # Trainer._validate_config). Available options:
    #  - "wis": the weighted step-wise importance sampling estimator.
    #  - "is": the step-wise importance sampling estimator.
    #  - "simulation": run the environment in the background, but use
    #    this data for evaluation only and not for learning.
    "input_evaluation": ["is", "wis"],
    # Whether to run postprocess_trajectory() on the trajectory fragments from
    # offline inputs. Note that postprocessing will be done using the *current*
    # policy, not the *behaviour* policy, which is typically undesirable for
    # on-policy algorithms.
    "postprocess_inputs": False,
    # If positive, input batches will be shuffled via a sliding window buffer
    # of this number of batches. Use this if the input data is not in random
    # enough order. Input is delayed until the shuffle buffer is filled.
    "shuffle_buffer_size": 0,
    # Specify where experiences should be saved:
    #  - None: don't save any experiences
    #  - "logdir" to save to the agent log dir
    #  - a path/URI to save to a custom output directory (e.g., "s3://bucket/")
    #  - a function that returns a rllib.offline.OutputWriter
    "output": None,
    # What sample batch columns to LZ4 compress in the output data.
    "output_compress_columns": ["obs", "new_obs"],
    # Max output file size (in bytes) before rolling over to a new file.
    "output_max_file_size": 64 * 1024 * 1024,

    # === Multiagent ===
    "multiagent": {
        # Map from policy ids to tuples of (policy_cls, obs_space,
        # act_space, config). See rollout_worker.py for more info.
        "policies": {},
        # Function mapping agent ids to policy ids.
        "policy_mapping_fn": None,
        # Optional whitelist of policies to train, or None for all policies.
        "policies_to_train": None,
    },
}
# __sphinx_doc_end__
# yapf: enable
@DeveloperAPI
def with_common_config(extra_config):
    """Overlay ``extra_config`` on a deep copy of ``COMMON_CONFIG``.

    This is the usual way for an algorithm to declare its default config:
    only the keys it adds or changes need to be listed.

    Args:
        extra_config (dict): Top-level keys that override or extend the
            common defaults (replaced wholesale, not merged recursively).

    Returns:
        dict: A new config dict; ``COMMON_CONFIG`` itself is left untouched.
    """
    return with_base_config(base_config=COMMON_CONFIG,
                            extra_config=extra_config)
def with_base_config(base_config, extra_config):
    """Build a fresh config from ``base_config`` overlaid with ``extra_config``.

    ``base_config`` is deep-copied first, so neither it nor any of its nested
    values are modified. The overlay is shallow: every top-level key of
    ``extra_config`` replaces the corresponding base entry wholesale (nested
    dicts are NOT merged), and the values taken from ``extra_config`` are not
    copied.

    Args:
        base_config (dict): Default configuration to start from.
        extra_config (dict): Top-level overrides / additions.

    Returns:
        dict: The deep-copied base with the overrides applied.
    """
    merged = copy.deepcopy(base_config)
    merged.update(extra_config)
    return merged
@PublicAPI
class Trainer(Trainable):
    """A trainer coordinates the optimization of one or more RL policies.

    All RLlib trainers extend this base class, e.g., the A3CTrainer implements
    the A3C algorithm for single and multi-agent training.

    Trainer objects retain internal model state between calls to train(), so
    you should create a new trainer instance for each training session.

    Attributes:
        env_creator (func): Function that creates a new training env.
        config (obj): Algorithm-specific configuration data.
        logdir (str): Directory in which training outputs should be placed.
    """

    # Passed to deep_update() in _setup(); presumably, when False, top-level
    # config keys unknown to the default config are rejected (TODO confirm).
    _allow_unknown_configs = False
    # Passed to deep_update() in _setup(); presumably these sub-dicts may
    # contain keys that are absent from the default config.
    _allow_unknown_subkeys = [
        "tf_session_args", "local_tf_session_args", "env_config", "model",
        "optimizer", "multiagent", "custom_resources_per_worker",
        "evaluation_config"
    ]

    @PublicAPI
    def __init__(self, config=None, env=None, logger_creator=None):
        """Initialize an RLLib trainer.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.

        Raises:
            ValueError: If the env is neither a string id nor a class.
        """
        config = config or {}

        if tf and config.get("eager"):
            tf.enable_eager_execution()
            logger.info("Executing eagerly, with eager_tracing={}".format(
                "True" if config.get("eager_tracing") else "False"))

        if tf and not tf.executing_eagerly():
            logger.info("Tip: set 'eager': true or the --eager flag to enable "
                        "TensorFlow eager execution")

        # Vars to synchronize to workers on each train call
        self.global_vars = {"timestep": 0}

        # Trainers allow env ids to be passed directly to the constructor.
        self._env_id = self._register_if_needed(env or config.get("env"))

        # Create a default logger creator if no logger_creator is specified
        if logger_creator is None:
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            logdir_prefix = "{}_{}_{}".format(self._name, self._env_id,
                                              timestr)

            def default_logger_creator(config):
                """Creates a Unified logger with a default logdir prefix
                containing the agent name and the env id
                """
                if not os.path.exists(DEFAULT_RESULTS_DIR):
                    os.makedirs(DEFAULT_RESULTS_DIR)
                logdir = tempfile.mkdtemp(
                    prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
                return UnifiedLogger(config, logdir, loggers=None)

            logger_creator = default_logger_creator

        Trainable.__init__(self, config, logger_creator)

    @classmethod
    @override(Trainable)
    def default_resource_request(cls, config):
        """Computes the Tune resource request for a trainer with `config`.

        The user config is shallow-merged over the class default config and
        validated. The driver gets `num_cpus_for_driver` CPUs, `num_gpus`
        GPUs and the trainer memory quotas; every worker quota is multiplied
        by `num_workers` and requested as "extra" resources.
        """
        cf = dict(cls._default_config, **config)
        Trainer._validate_config(cf)
        # TODO(ekl): add custom resources here once tune supports them
        return Resources(
            cpu=cf["num_cpus_for_driver"],
            gpu=cf["num_gpus"],
            memory=cf["memory"],
            object_store_memory=cf["object_store_memory"],
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"],
            extra_memory=cf["memory_per_worker"] * cf["num_workers"],
            extra_object_store_memory=cf["object_store_memory_per_worker"] *
            cf["num_workers"])

    @override(Trainable)
    @PublicAPI
    def train(self):
        """Overrides super.train to synchronize global vars.

        Besides delegating to ``Trainable.train``, this:
          * pushes ``global_vars`` (the optimizer's sampled step count) to the
            local and remote workers before training,
          * on a ``RayError``, tries up to ``MAX_WORKER_FAILURE_RETRIES``
            recoveries if ``ignore_worker_failures`` is set (else re-raises),
          * synchronizes observation filters across workers,
          * adds ``num_healthy_workers`` and, every ``evaluation_interval``
            iterations, the evaluation metrics to the result.

        Returns:
            dict: The training result.

        Raises:
            RuntimeError: If every attempt failed and no result was produced.
        """
        if self._has_policy_optimizer():
            self.global_vars["timestep"] = self.optimizer.num_steps_sampled
            self.optimizer.workers.local_worker().set_global_vars(
                self.global_vars)
            for w in self.optimizer.workers.remote_workers():
                w.set_global_vars.remote(self.global_vars)
            logger.debug("updated global vars: {}".format(self.global_vars))

        result = None
        for _ in range(1 + MAX_WORKER_FAILURE_RETRIES):
            try:
                result = Trainable.train(self)
            except RayError as e:
                if self.config["ignore_worker_failures"]:
                    logger.exception(
                        "Error in train call, attempting to recover")
                    self._try_recover()
                else:
                    logger.info(
                        "Worker crashed during call to train(). To attempt to "
                        "continue training without the failed worker, set "
                        "`'ignore_worker_failures': True`.")
                    raise e
            except Exception as e:
                time.sleep(0.5)  # allow logs messages to propagate
                raise e
            else:
                break
        if result is None:
            raise RuntimeError("Failed to recover from worker crash")

        if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                and hasattr(self, "workers")
                and isinstance(self.workers, WorkerSet)):
            FilterManager.synchronize(
                self.workers.local_worker().filters,
                self.workers.remote_workers(),
                update_remote=self.config["synchronize_filters"])
            logger.debug("synchronized filters: {}".format(
                self.workers.local_worker().filters))

        if self._has_policy_optimizer():
            result["num_healthy_workers"] = len(
                self.optimizer.workers.remote_workers())

        if self.config["evaluation_interval"]:
            if self._iteration % self.config["evaluation_interval"] == 0:
                evaluation_metrics = self._evaluate()
                assert isinstance(evaluation_metrics, dict), \
                    "_evaluate() needs to return a dict."
                result.update(evaluation_metrics)

        return result

    @override(Trainable)
    def _log_result(self, result):
        """Invokes the `on_train_result` callback (if any), then logs."""
        if self.config["callbacks"].get("on_train_result"):
            self.config["callbacks"]["on_train_result"]({
                "trainer": self,
                "result": result,
            })
        # log after the callback is invoked, so that the user has a chance
        # to mutate the result
        Trainable._log_result(self, result)

    @override(Trainable)
    def _setup(self, config):
        """Resolves the env creator, merges/validates config, builds workers.

        Note: ``config["env"]`` is overwritten in place with the registered
        env id when one is set.
        """
        env = self._env_id
        if env:
            config["env"] = env
            if _global_registry.contains(ENV_CREATOR, env):
                self.env_creator = _global_registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default
        merged_config = copy.deepcopy(self._default_config)
        merged_config = deep_update(merged_config, config,
                                    self._allow_unknown_configs,
                                    self._allow_unknown_subkeys)
        self.raw_user_config = config
        self.config = merged_config
        Trainer._validate_config(self.config)
        if self.config.get("log_level"):
            logging.getLogger("ray.rllib").setLevel(self.config["log_level"])

        # In TF graph mode, build this trainer inside its own fresh graph.
        # In eager mode (or without TF) no scope is needed. This avoids the
        # previous fake `open("/dev/null")` scope, which fails on platforms
        # that have no /dev/null (e.g. Windows).
        if tf and not tf.executing_eagerly():
            with tf.Graph().as_default():
                self._init(self.config, self.env_creator)
        else:
            self._init(self.config, self.env_creator)

        # Evaluation related
        if self.config.get("evaluation_interval"):
            # Update env_config with evaluation settings:
            extra_config = copy.deepcopy(self.config["evaluation_config"])
            # _evaluate() calls sample() once per requested episode, so
            # sample whole episodes with the smallest fragment length. (This
            # previously used "batch_steps", which is not a config key, so
            # the fragment length was never actually reduced.)
            extra_config.update({
                "batch_mode": "complete_episodes",
                "sample_batch_size": 1,
            })
            logger.debug(
                "using evaluation_config: {}".format(extra_config))
            self.evaluation_workers = self._make_workers(
                self.env_creator,
                self._policy,
                merge_dicts(self.config, extra_config),
                num_workers=0)
            self.evaluation_metrics = self._evaluate()

    @override(Trainable)
    def _stop(self):
        """Stops the worker set and the policy optimizer, if present."""
        if hasattr(self, "workers"):
            self.workers.stop()
        if hasattr(self, "optimizer"):
            self.optimizer.stop()

    @override(Trainable)
    def _save(self, checkpoint_dir):
        """Pickles ``__getstate__()`` to ``<checkpoint_dir>/checkpoint-<it>``.

        Returns:
            str: Path of the written checkpoint file.
        """
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        # Gather state before opening the file so a failure here does not
        # leave an empty checkpoint behind.
        state = self.__getstate__()
        with open(checkpoint_path, "wb") as f:
            pickle.dump(state, f)
        return checkpoint_path

    @override(Trainable)
    def _restore(self, checkpoint_path):
        """Loads a checkpoint written by ``_save`` and applies it.

        SECURITY: ``pickle.load`` can execute arbitrary code; only restore
        checkpoints that come from a trusted source.
        """
        with open(checkpoint_path, "rb") as f:
            extra_data = pickle.load(f)
        self.__setstate__(extra_data)

    @DeveloperAPI
    def _make_workers(self, env_creator, policy, config, num_workers):
        """Creates a WorkerSet that logs to ``self.logdir``.

        Args:
            env_creator (func): Env factory for the workers.
            policy: Policy (class or multiagent spec) to instantiate.
            config (dict): Worker config.
            num_workers (int): Number of workers passed to WorkerSet
                (evaluation uses 0; presumably local worker only).
        """
        return WorkerSet(
            env_creator,
            policy,
            config,
            num_workers=num_workers,
            logdir=self.logdir)

    @DeveloperAPI
    def _init(self, config, env_creator):
        """Subclasses should override this for custom initialization."""
        raise NotImplementedError

    @DeveloperAPI
    def _evaluate(self):
        """Evaluates current policy under `evaluation_config` settings.

        Note that this default implementation does not do anything beyond
        merging evaluation_config with the normal trainer config.

        Returns:
            dict: ``{"evaluation": metrics}`` where ``metrics`` is the result
            of ``collect_metrics()`` on the local evaluation worker.

        Raises:
            ValueError: If ``evaluation_config`` is empty.
        """
        if not self.config["evaluation_config"]:
            raise ValueError(
                "No evaluation_config specified. It doesn't make sense "
                "to enable evaluation without specifying any config "
                "overrides, since the results will be the "
                "same as reported during normal policy evaluation.")

        logger.info("Evaluating current policy for {} episodes".format(
            self.config["evaluation_num_episodes"]))
        self._before_evaluate()
        # Copy the training local worker's saved state into the evaluation
        # worker before sampling.
        self.evaluation_workers.local_worker().restore(
            self.workers.local_worker().save())
        for _ in range(self.config["evaluation_num_episodes"]):
            self.evaluation_workers.local_worker().sample()

        metrics = collect_metrics(self.evaluation_workers.local_worker())
        return {"evaluation": metrics}

    @DeveloperAPI
    def _before_evaluate(self):
        """Pre-evaluation callback."""
        pass

    @PublicAPI
    def compute_action(self,
                       observation,
                       state=None,
                       prev_action=None,
                       prev_reward=None,
                       info=None,
                       policy_id=DEFAULT_POLICY_ID,
                       full_fetch=False):
        """Computes an action for the specified policy.

        Note that you can also access the policy object through
        self.get_policy(policy_id) and call compute_actions() on it directly.

        Arguments:
            observation (obj): observation from the environment.
            state (list): RNN hidden state, if any. If state is not None,
                then all of compute_single_action(...) is returned
                (computed action, rnn state, logits dictionary).
                Otherwise compute_single_action(...)[0] is
                returned (computed action).
            prev_action (obj): previous action value, if any
            prev_reward (int): previous reward, if any
            info (dict): info object, if any
            policy_id (str): policy to query (only applies to multi-agent).
            full_fetch (bool): whether to return extra action fetch results.
                This is always set to true if RNN state is specified.

        Returns:
            Just the computed action if full_fetch=False, or the full output
            of policy.compute_actions() otherwise.
        """
        if state is None:
            state = []
        local_worker = self.workers.local_worker()
        preprocessed = local_worker.preprocessors[policy_id].transform(
            observation)
        # update=False presumably leaves the filter's running statistics
        # unchanged for inference-only calls.
        filtered_obs = local_worker.filters[policy_id](
            preprocessed, update=False)
        res = self.get_policy(policy_id).compute_single_action(
            filtered_obs,
            state,
            prev_action,
            prev_reward,
            info,
            clip_actions=self.config["clip_actions"])
        # With a non-empty RNN state the caller needs the new state back, so
        # the full output is returned regardless of `full_fetch`.
        if state or full_fetch:
            return res
        return res[0]  # backwards compatibility

    @property
    def _name(self):
        """Subclasses should override this to declare their name."""
        raise NotImplementedError

    @property
    def _default_config(self):
        """Subclasses should override this to declare their default config."""
        raise NotImplementedError

    @PublicAPI
    def get_policy(self, policy_id=DEFAULT_POLICY_ID):
        """Return policy for the specified id, or None.

        Arguments:
            policy_id (str): id of policy to return.
        """
        return self.workers.local_worker().get_policy(policy_id)

    @PublicAPI
    def get_weights(self, policies=None):
        """Return a dictionary of policy ids to weights.

        Arguments:
            policies (list): Optional list of policies to return weights for,
                or None for all policies.
        """
        return self.workers.local_worker().get_weights(policies)

    @PublicAPI
    def set_weights(self, weights):
        """Set policy weights by policy id.

        Only the local worker is updated here.

        Arguments:
            weights (dict): Map of policy ids to weights to set.
        """
        self.workers.local_worker().set_weights(weights)

    @DeveloperAPI
    def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
        """Export policy model with given policy_id to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            policy_id (string): Optional policy id to export.

        Example:
            >>> trainer = MyTrainer()
            >>> for _ in range(10):
            ...     trainer.train()
            >>> trainer.export_policy_model("/tmp/export_dir")
        """
        self.workers.local_worker().export_policy_model(export_dir, policy_id)

    @DeveloperAPI
    def export_policy_checkpoint(self,
                                 export_dir,
                                 filename_prefix="model",
                                 policy_id=DEFAULT_POLICY_ID):
        """Export tensorflow policy model checkpoint to local directory.

        Arguments:
            export_dir (string): Writable local directory.
            filename_prefix (string): file name prefix of checkpoint files.
            policy_id (string): Optional policy id to export.

        Example:
            >>> trainer = MyTrainer()
            >>> for _ in range(10):
            ...     trainer.train()
            >>> trainer.export_policy_checkpoint("/tmp/export_dir")
        """
        self.workers.local_worker().export_policy_checkpoint(
            export_dir, filename_prefix, policy_id)

    @DeveloperAPI
    def collect_metrics(self, selected_workers=None):
        """Collects metrics from the remote workers of this agent.

        This is the same data as returned by a call to train(). Uses the
        `collect_metrics_timeout` and `metrics_smoothing_episodes` configs.
        """
        return self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"],
            min_history=self.config["metrics_smoothing_episodes"],
            selected_workers=selected_workers)

    @classmethod
    def resource_help(cls, config):
        """Returns a hint on how to adjust the trainer's resource requests."""
        return ("\n\nYou can adjust the resource requests of RLlib agents by "
                "setting `num_workers`, `num_gpus`, and other configs. See "
                "the DEFAULT_CONFIG defined by each agent for more info.\n\n"
                "The config of this agent is: {}".format(config))

    @staticmethod
    def _validate_config(config):
        """Validates `config`, migrating deprecated keys in place.

        The deprecated ``multiagent.policy_graphs`` key is moved to
        ``multiagent.policies`` (mutating ``config["multiagent"]``).

        Raises:
            ValueError: If a removed GPU option (`gpu`, `gpu_fraction`,
                `use_gpu_for_workers`) is present, or if `input_evaluation`
                is not a list.
        """
        if "policy_graphs" in config["multiagent"]:
            logger.warning(
                "The `policy_graphs` config has been renamed to `policies`.")
            # Backwards compatibility
            config["multiagent"]["policies"] = config["multiagent"][
                "policy_graphs"]
            del config["multiagent"]["policy_graphs"]
        if "gpu" in config:
            raise ValueError(
                "The `gpu` config is deprecated, please use `num_gpus=0|1` "
                "instead.")
        if "gpu_fraction" in config:
            raise ValueError(
                "The `gpu_fraction` config is deprecated, please use "
                "`num_gpus=<fraction>` instead.")
        if "use_gpu_for_workers" in config:
            raise ValueError(
                "The `use_gpu_for_workers` config is deprecated, please use "
                "`num_gpus_per_worker=1` instead.")
        if not isinstance(config["input_evaluation"], list):
            raise ValueError(
                "`input_evaluation` must be a list of strings, got {}".format(
                    config["input_evaluation"]))

    def _try_recover(self):
        """Try to identify and blacklist any unhealthy workers.

        This method is called after an unexpected remote error is encountered
        from a worker. It issues check requests to all current workers and
        blacklists any that respond with error. If no healthy workers remain,
        an error is raised.

        Raises:
            NotImplementedError: If this trainer has no PolicyOptimizer.
            RuntimeError: If no healthy worker remains.
        """
        if not self._has_policy_optimizer():
            raise NotImplementedError(
                "Recovery is not supported for this algorithm")

        logger.info("Health checking all workers...")
        # Snapshot the worker list once so each check is paired with the
        # worker it was issued to.
        workers = self.optimizer.workers.remote_workers()
        checks = []
        for ev in workers:
            _, obj_id = ev.sample_with_count.remote()
            checks.append(obj_id)

        healthy_workers = []
        for i, (w, obj_id) in enumerate(zip(workers, checks)):
            try:
                ray_get_and_free(obj_id)
                healthy_workers.append(w)
                logger.info("Worker {} looks healthy".format(i + 1))
            except RayError:
                logger.exception("Blacklisting worker {}".format(i + 1))
                try:
                    w.__ray_terminate__.remote()
                except Exception:
                    logger.exception("Error terminating unhealthy worker")

        if len(healthy_workers) < 1:
            raise RuntimeError(
                "Not enough healthy workers remain to continue.")

        self.optimizer.reset(healthy_workers)

    def _has_policy_optimizer(self):
        """True iff ``self.optimizer`` exists and is a PolicyOptimizer."""
        return hasattr(self, "optimizer") and isinstance(
            self.optimizer, PolicyOptimizer)

    @override(Trainable)
    def _export_model(self, export_formats, export_dir):
        """Exports the default policy in the requested formats.

        Returns:
            dict: Map from each exported ExportFormat to its output path
            (``<export_dir>/<format>``).
        """
        ExportFormat.validate(export_formats)
        exported = {}
        if ExportFormat.CHECKPOINT in export_formats:
            path = os.path.join(export_dir, ExportFormat.CHECKPOINT)
            self.export_policy_checkpoint(path)
            exported[ExportFormat.CHECKPOINT] = path
        if ExportFormat.MODEL in export_formats:
            path = os.path.join(export_dir, ExportFormat.MODEL)
            self.export_policy_model(path)
            exported[ExportFormat.MODEL] = path
        return exported

    def __getstate__(self):
        """Returns a dict with optional "worker" and "optimizer" state."""
        state = {}
        if hasattr(self, "workers"):
            state["worker"] = self.workers.local_worker().save()
        if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
            state["optimizer"] = self.optimizer.save()
        return state

    def __setstate__(self, state):
        """Restores state produced by ``__getstate__``.

        The worker state is applied to the local worker and, via a single
        ``ray.put``, to every remote worker.
        """
        if "worker" in state:
            self.workers.local_worker().restore(state["worker"])
            remote_state = ray.put(state["worker"])
            for r in self.workers.remote_workers():
                r.restore.remote(remote_state)
        if "optimizer" in state:
            self.optimizer.restore(state["optimizer"])

    def _register_if_needed(self, env_object):
        """Returns an env id for `env_object`, registering classes.

        A string is returned unchanged; a class is registered under its
        ``__name__`` and that name is returned.

        Raises:
            ValueError: For anything else (including None).
        """
        if isinstance(env_object, six.string_types):
            return env_object
        elif isinstance(env_object, type):
            name = env_object.__name__
            register_env(name, lambda config: env_object(config))
            return name
        raise ValueError(
            "{} is an invalid env specification. ".format(env_object) +
            "You can specify a custom env as either a class "
            "(e.g., YourEnvCls) or a registered env id (e.g., \"your_env\").")