ray/rllib/algorithms/algorithm_config.py

import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
import gym
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
from ray.rllib.utils.typing import (
AlgorithmConfigDict,
EnvConfigDict,
EnvType,
PartialAlgorithmConfigDict,
ResultDict,
)
from ray.tune.logger import Logger
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm import Algorithm
class AlgorithmConfig:
"""A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.
Example:
>>> from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
>>> # Construct a generic config object, specifying values within different
>>> # sub-categories, e.g. "training".
>>> config = (
...     AlgorithmConfig()
...     .training(gamma=0.9, lr=0.01)
...     .environment(env="CartPole-v1")
...     .resources(num_gpus=0)
...     .rollouts(num_rollout_workers=4)
...     .callbacks(MemoryTrackingCallbacks)
... )
>>> # A config object can be used to construct the respective Algorithm.
>>> rllib_algo = config.build()
Example:
>>> from ray import tune
>>> # In combination with a tune.grid_search:
>>> config = AlgorithmConfig()
>>> config.training(lr=tune.grid_search([0.01, 0.001]))
>>> # Use `to_dict()` method to get the legacy plain python config dict
>>> # for usage with `tune.run()`.
>>> tune.run("[registered trainer class]", config=config.to_dict())
"""
def __init__(self, algo_class=None):
# Define all settings and their default values.
# Define the default RLlib Algorithm class that this AlgorithmConfig will be
# applied to.
self.algo_class = algo_class
# `self.python_environment()`
self.extra_python_environs_for_driver = {}
self.extra_python_environs_for_worker = {}
# `self.resources()`
self.num_gpus = 0
self.num_cpus_per_worker = 1
self.num_gpus_per_worker = 0
self._fake_gpus = False
self.num_cpus_for_local_worker = 1
self.custom_resources_per_worker = {}
self.placement_strategy = "PACK"
# `self.framework()`
self.framework_str = "tf"
self.eager_tracing = False
self.eager_max_retraces = 20
self.tf_session_args = {
# note: overridden by `local_tf_session_args`
"intra_op_parallelism_threads": 2,
"inter_op_parallelism_threads": 2,
"gpu_options": {
"allow_growth": True,
},
"log_device_placement": False,
"device_count": {"CPU": 1},
# Required by multi-GPU (num_gpus > 1).
"allow_soft_placement": True,
}
self.local_tf_session_args = {
# Allow a higher level of parallelism by default, but not unlimited
# since that can cause crashes with many concurrent drivers.
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
}
# `self.environment()`
self.env = None
self.env_config = {}
self.observation_space = None
self.action_space = None
self.env_task_fn = None
self.render_env = False
self.clip_rewards = None
self.normalize_actions = True
self.clip_actions = False
self.disable_env_checking = False
# `self.rollouts()`
self.num_workers = 2
self.num_envs_per_worker = 1
self.sample_collector = SimpleListCollector
self.create_env_on_local_worker = False
self.sample_async = False
self.enable_connectors = False
self.rollout_fragment_length = 200
self.batch_mode = "truncate_episodes"
self.remote_worker_envs = False
self.remote_env_batch_wait_ms = 0
self.validate_workers_after_construction = True
self.ignore_worker_failures = False
self.recreate_failed_workers = False
self.restart_failed_sub_environments = False
self.num_consecutive_worker_failures_tolerance = 100
self.horizon = None
self.soft_horizon = False
self.no_done_at_end = False
self.preprocessor_pref = "deepmind"
self.observation_filter = "NoFilter"
self.synchronize_filters = True
self.compress_observations = False
self.enable_tf1_exec_eagerly = False
self.sampler_perf_stats_ema_coef = None
# `self.training()`
self.gamma = 0.99
self.lr = 0.001
self.train_batch_size = 32
self.model = copy.deepcopy(MODEL_DEFAULTS)
self.optimizer = {}
# `self.callbacks()`
self.callbacks_class = DefaultCallbacks
# `self.explore()`
self.explore = True
self.exploration_config = {
# The Exploration class to use. In the simplest case, this is the name
# (str) of any class present in the `rllib.utils.exploration` package.
# You can also provide the python class directly or the full location
# of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
# EpsilonGreedy").
"type": "StochasticSampling",
# Add constructor kwargs here (if any).
}
# `self.multi_agent()`
self.policies = {}
self.policy_map_capacity = 100
self.policy_map_cache = None
self.policy_mapping_fn = None
self.policies_to_train = None
self.observation_fn = None
self.replay_mode = "independent"
self.count_steps_by = "env_steps"
# `self.offline_data()`
self.input_ = "sampler"
self.input_config = {}
self.actions_in_input_normalized = False
self.postprocess_inputs = False
self.shuffle_buffer_size = 0
self.output = None
self.output_config = {}
self.output_compress_columns = ["obs", "new_obs"]
self.output_max_file_size = 64 * 1024 * 1024
# `self.evaluation()`
self.evaluation_interval = None
self.evaluation_duration = 10
self.evaluation_duration_unit = "episodes"
self.evaluation_sample_timeout_s = 180.0
self.evaluation_parallel_to_training = False
self.evaluation_config = {}
self.off_policy_estimation_methods = {}
self.evaluation_num_workers = 0
self.custom_evaluation_function = None
self.always_attach_evaluation_results = False
self.enable_async_evaluation = False
# TODO: Set this flag still in the config or - much better - in the
# RolloutWorker as a property.
self.in_evaluation = False
self.sync_filters_on_rollout_workers_timeout_s = 60.0
# `self.reporting()`
self.keep_per_episode_custom_metrics = False
self.metrics_episode_collection_timeout_s = 60.0
self.metrics_num_episodes_for_smoothing = 100
self.min_time_s_per_iteration = None
self.min_train_timesteps_per_iteration = 0
self.min_sample_timesteps_per_iteration = 0
# `self.debugging()`
self.logger_creator = None
self.logger_config = None
self.log_level = "WARN"
self.log_sys_usage = True
self.fake_sampler = False
self.seed = None
# `self.experimental()`
self._tf_policy_handles_more_than_one_loss = False
self._disable_preprocessor_api = False
self._disable_action_flattening = False
self._disable_execution_plan_api = True
# TODO: Remove, once all deprecation_warning calls upon using these keys
# have been removed.
# === Deprecated keys ===
self.simple_optimizer = DEPRECATED_VALUE
self.monitor = DEPRECATED_VALUE
self.evaluation_num_episodes = DEPRECATED_VALUE
self.metrics_smoothing_episodes = DEPRECATED_VALUE
self.timesteps_per_iteration = DEPRECATED_VALUE
self.min_iter_time_s = DEPRECATED_VALUE
self.collect_metrics_timeout = DEPRECATED_VALUE
# The following values have moved because of the new ReplayBuffer API
self.buffer_size = DEPRECATED_VALUE
self.prioritized_replay = DEPRECATED_VALUE
self.learning_starts = DEPRECATED_VALUE
self.replay_batch_size = DEPRECATED_VALUE
# -1 = DEPRECATED_VALUE is a valid value for replay_sequence_length
self.replay_sequence_length = None
self.prioritized_replay_alpha = DEPRECATED_VALUE
self.prioritized_replay_beta = DEPRECATED_VALUE
self.prioritized_replay_eps = DEPRECATED_VALUE
self.min_time_s_per_reporting = DEPRECATED_VALUE
self.min_train_timesteps_per_reporting = DEPRECATED_VALUE
self.min_sample_timesteps_per_reporting = DEPRECATED_VALUE
self.input_evaluation = DEPRECATED_VALUE
def to_dict(self) -> AlgorithmConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.
Returns:
A complete AlgorithmConfigDict, usable in backward-compatible Tune/RLlib
use cases, e.g. w/ `tune.run()`.
"""
config = copy.deepcopy(vars(self))
config.pop("algo_class")
# Translate attributes that carry a trailing underscore (to avoid clashing with
# reserved Python keywords) back into their legacy config key names.
if "lambda_" in config:
assert hasattr(self, "lambda_")
config["lambda"] = getattr(self, "lambda_")
config.pop("lambda_")
if "input_" in config:
assert hasattr(self, "input_")
config["input"] = getattr(self, "input_")
config.pop("input_")
# Setup legacy multiagent sub-dict:
config["multiagent"] = {}
for k in [
"policies",
"policy_map_capacity",
"policy_map_cache",
"policy_mapping_fn",
"policies_to_train",
"observation_fn",
"replay_mode",
"count_steps_by",
]:
config["multiagent"][k] = config.pop(k)
# Switch out deprecated vs new config keys.
config["callbacks"] = config.pop("callbacks_class", DefaultCallbacks)
config["create_env_on_driver"] = config.pop("create_env_on_local_worker", 1)
config["custom_eval_function"] = config.pop("custom_evaluation_function", None)
config["framework"] = config.pop("framework_str", None)
config["num_cpus_for_driver"] = config.pop("num_cpus_for_local_worker", 1)
return config
def build(
self,
env: Optional[Union[str, EnvType]] = None,
logger_creator: Optional[Callable[[], Logger]] = None,
) -> "Algorithm":
"""Builds an Algorithm from the AlgorithmConfig.
Args:
env: Name of the environment to use (e.g. a gym-registered str),
a full class path (e.g.
"ray.rllib.examples.env.random_env.RandomEnv"), or an Env
class directly. Note that this arg can also be specified via
the "env" key in `config`.
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
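Example:
A minimal sketch; assumes a concrete config subclass (here PPOConfig) is used
so that `algo_class` is already set:
>>> from ray.rllib.algorithms.ppo import PPOConfig
>>> algo = PPOConfig().environment(env="CartPole-v1").build()
>>> print(algo.train())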
Returns:
A ray.rllib.algorithms.algorithm.Algorithm object.
"""
if env is not None:
self.env = env
if self.evaluation_config is not None:
self.evaluation_config["env"] = env
if logger_creator is not None:
self.logger_creator = logger_creator
return self.algo_class(
config=self.to_dict(),
env=self.env,
logger_creator=self.logger_creator,
)
def python_environment(
self,
*,
extra_python_environs_for_driver: Optional[dict] = None,
extra_python_environs_for_worker: Optional[dict] = None,
) -> "AlgorithmConfig":
"""Sets the config's python environment settings.
Args:
extra_python_environs_for_driver: Any extra python env vars to set in the
algorithm's process, e.g., {"OMP_NUM_THREADS": "16"}.
extra_python_environs_for_worker: Any extra python env vars to set for the
worker processes.
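Example:
A minimal sketch; the env var values below are illustrative only:
>>> config = AlgorithmConfig().python_environment(
...     extra_python_environs_for_driver={"OMP_NUM_THREADS": "16"},
...     extra_python_environs_for_worker={"OMP_NUM_THREADS": "4"},
... )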
Returns:
This updated AlgorithmConfig object.
"""
if extra_python_environs_for_driver is not None:
self.extra_python_environs_for_driver = extra_python_environs_for_driver
if extra_python_environs_for_worker is not None:
self.extra_python_environs_for_worker = extra_python_environs_for_worker
return self
def resources(
self,
*,
num_gpus: Optional[Union[float, int]] = None,
_fake_gpus: Optional[bool] = None,
num_cpus_per_worker: Optional[int] = None,
num_gpus_per_worker: Optional[Union[float, int]] = None,
num_cpus_for_local_worker: Optional[int] = None,
custom_resources_per_worker: Optional[dict] = None,
placement_strategy: Optional[str] = None,
) -> "AlgorithmConfig":
"""Specifies resources allocated for an Algorithm and its ray actors/workers.
Args:
num_gpus: Number of GPUs to allocate to the algorithm process.
Note that not all algorithms can take advantage of GPUs.
Support for multi-GPU is currently only available for
tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).
_fake_gpus: Set to True for debugging (multi-)GPU functionality on a
CPU machine. GPU towers will be simulated by graphs located on
CPUs in this case. Use `num_gpus` to test for different numbers of
fake GPUs.
num_cpus_per_worker: Number of CPUs to allocate per worker.
num_gpus_per_worker: Number of GPUs to allocate per worker. This can be
fractional. This is usually needed only if your env itself requires a
GPU (i.e., it is a GPU-intensive video game), or model inference is
unusually expensive.
custom_resources_per_worker: Any custom Ray resources to allocate per
worker.
num_cpus_for_local_worker: Number of CPUs to allocate for the algorithm.
Note: this only takes effect when running in Tune. Otherwise,
the algorithm runs in the main program (driver).
placement_strategy: The strategy for the placement group factory returned by
`Algorithm.default_resource_request()`. A PlacementGroup defines, which
devices (resources) should always be co-located on the same node.
For example, an Algorithm with 2 rollout workers, running with
num_gpus=1 will request a placement group with the bundles:
[{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle
is for the driver and the other 2 bundles are for the two workers.
These bundles can now be "placed" on the same or different
nodes depending on the value of `placement_strategy`:
"PACK": Packs bundles into as few nodes as possible.
"SPREAD": Places bundles across distinct nodes as even as possible.
"STRICT_PACK": Packs bundles into one node. The group is not allowed
to span multiple nodes.
"STRICT_SPREAD": Packs bundles across distinct nodes.
Returns:
This updated AlgorithmConfig object.
"""
if num_gpus is not None:
self.num_gpus = num_gpus
if _fake_gpus is not None:
self._fake_gpus = _fake_gpus
if num_cpus_per_worker is not None:
self.num_cpus_per_worker = num_cpus_per_worker
if num_gpus_per_worker is not None:
self.num_gpus_per_worker = num_gpus_per_worker
if num_cpus_for_local_worker is not None:
self.num_cpus_for_local_worker = num_cpus_for_local_worker
if custom_resources_per_worker is not None:
self.custom_resources_per_worker = custom_resources_per_worker
if placement_strategy is not None:
self.placement_strategy = placement_strategy
return self
def framework(
self,
framework: Optional[str] = None,
*,
eager_tracing: Optional[bool] = None,
eager_max_retraces: Optional[int] = None,
tf_session_args: Optional[Dict[str, Any]] = None,
local_tf_session_args: Optional[Dict[str, Any]] = None,
) -> "AlgorithmConfig":
"""Sets the config's DL framework settings.
Args:
framework: tf: TensorFlow (static-graph); tf2: TensorFlow 2.x
(eager or traced, if eager_tracing=True); torch: PyTorch
eager_tracing: Enable tracing in eager mode. This greatly improves
performance (speedup ~2x), but makes it slightly harder to debug
since Python code won't be evaluated after the initial eager pass.
Only possible if framework=tf2.
eager_max_retraces: Maximum number of tf.function re-traces before a
runtime error is raised. This is to prevent unnoticed retraces of
methods inside the `..._eager_traced` Policy, which could slow down
execution by a factor of 4, without the user noticing what the root
cause for this slowdown could be.
Only necessary for framework=[tf2|tfe].
Set to None to ignore the re-trace count and never throw an error.
tf_session_args: Configures TF for single-process operation by default.
local_tf_session_args: Override the `tf_session_args` above for the local
worker.
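Example:
A minimal sketch; "tf2" with eager tracing is just one possible choice:
>>> config = AlgorithmConfig().framework("tf2", eager_tracing=True)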
Returns:
This updated AlgorithmConfig object.
"""
if framework is not None:
self.framework_str = framework
if eager_tracing is not None:
self.eager_tracing = eager_tracing
if eager_max_retraces is not None:
self.eager_max_retraces = eager_max_retraces
if tf_session_args is not None:
self.tf_session_args = tf_session_args
if local_tf_session_args is not None:
self.local_tf_session_args = local_tf_session_args
return self
def environment(
self,
*,
env: Optional[Union[str, EnvType]] = None,
env_config: Optional[EnvConfigDict] = None,
observation_space: Optional[gym.spaces.Space] = None,
action_space: Optional[gym.spaces.Space] = None,
env_task_fn: Optional[Callable[[ResultDict, EnvType, EnvContext], Any]] = None,
render_env: Optional[bool] = None,
clip_rewards: Optional[Union[bool, float]] = None,
normalize_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
disable_env_checking: Optional[bool] = None,
) -> "AlgorithmConfig":
"""Sets the config's RL-environment settings.
Args:
env: The environment specifier. This can either be a tune-registered env,
via `tune.register_env([name], lambda env_ctx: [env object])`,
or a string specifier of an RLlib supported type. In the latter case,
RLlib will try to interpret the specifier as either an openAI gym env,
a PyBullet env, a ViZDoomGym env, or a fully qualified classpath to an
Env class, e.g. "ray.rllib.examples.env.random_env.RandomEnv".
env_config: Arguments dict passed to the env creator as an EnvContext
object (which is a dict plus the properties: num_workers, worker_index,
vector_index, and remote).
observation_space: The observation space for the Policies of this Algorithm.
action_space: The action space for the Policies of this Algorithm.
env_task_fn: A callable taking the last train results, the base env and the
env context as args and returning a new task to set the env to.
The env must be a `TaskSettableEnv` sub-class for this to work.
See `examples/curriculum_learning.py` for an example.
render_env: If True, try to render the environment on the local worker or on
worker 1 (if num_workers > 0). For vectorized envs, this usually means
that only the first sub-environment will be rendered.
In order for this to work, your env will have to implement the
`render()` method which either:
a) handles window generation and rendering itself (returning True) or
b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
clip_rewards: Whether to clip rewards during Policy's postprocessing.
None (default): Clip for Atari only (r=sign(r)).
True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0.
False: Never clip.
[float value]: Clip at -value and + value.
Tuple[value1, value2]: Clip at value1 and value2.
normalize_actions: If True, RLlib will learn entirely inside a normalized
action space (0.0 centered with small stddev; only affecting Box
components). We will unsquash actions (and clip, just in case) to the
bounds of the env's action space before sending actions back to the env.
clip_actions: If True, RLlib will clip actions according to the env's bounds
before sending them back to the env.
TODO: (sven) This option should be deprecated and always be False.
disable_env_checking: If True, disable the environment pre-checking module.
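Example:
A minimal sketch using a gym-registered env id; the `env_config` entry is
hypothetical and is simply forwarded to the env creator:
>>> config = AlgorithmConfig().environment(
...     env="CartPole-v1",
...     env_config={"my_setting": 1},  # hypothetical env creator kwargs
...     normalize_actions=True,
... )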
Returns:
This updated AlgorithmConfig object.
"""
if env is not None:
self.env = env
if env_config is not None:
self.env_config = env_config
if observation_space is not None:
self.observation_space = observation_space
if action_space is not None:
self.action_space = action_space
if env_task_fn is not None:
self.env_task_fn = env_task_fn
if render_env is not None:
self.render_env = render_env
if clip_rewards is not None:
self.clip_rewards = clip_rewards
if normalize_actions is not None:
self.normalize_actions = normalize_actions
if clip_actions is not None:
self.clip_actions = clip_actions
if disable_env_checking is not None:
self.disable_env_checking = disable_env_checking
return self
def rollouts(
self,
*,
num_rollout_workers: Optional[int] = None,
num_envs_per_worker: Optional[int] = None,
create_env_on_local_worker: Optional[bool] = None,
sample_collector: Optional[Type[SampleCollector]] = None,
sample_async: Optional[bool] = None,
enable_connectors: Optional[bool] = None,
rollout_fragment_length: Optional[int] = None,
batch_mode: Optional[str] = None,
remote_worker_envs: Optional[bool] = None,
remote_env_batch_wait_ms: Optional[float] = None,
validate_workers_after_construction: Optional[bool] = None,
ignore_worker_failures: Optional[bool] = None,
recreate_failed_workers: Optional[bool] = None,
restart_failed_sub_environments: Optional[bool] = None,
num_consecutive_worker_failures_tolerance: Optional[int] = None,
horizon: Optional[int] = None,
soft_horizon: Optional[bool] = None,
no_done_at_end: Optional[bool] = None,
preprocessor_pref: Optional[str] = None,
observation_filter: Optional[str] = None,
synchronize_filter: Optional[bool] = None,
compress_observations: Optional[bool] = None,
enable_tf1_exec_eagerly: Optional[bool] = None,
sampler_perf_stats_ema_coef: Optional[float] = None,
) -> "AlgorithmConfig":
"""Sets the rollout worker configuration.
Args:
num_rollout_workers: Number of rollout worker actors to create for
parallel sampling. Setting this to 0 will force rollouts to be done in
the local worker (driver process or the Algorithm's actor when using
Tune).
num_envs_per_worker: Number of environments to evaluate vector-wise per
worker. This enables model inference batching, which can improve
performance for inference bottlenecked workloads.
sample_collector: The SampleCollector class to be used to collect and
retrieve environment-, model-, and sampler data. Override the
SampleCollector base class to implement your own
collection/buffering/retrieval logic.
create_env_on_local_worker: When `num_workers` > 0, the driver
(local_worker; worker-idx=0) does not need an environment. This is
because it doesn't have to sample (done by remote_workers;
worker_indices > 0) nor evaluate (done by evaluation workers;
see below). Set this to True to create an environment on the local
worker anyway, e.g. for debugging purposes.
sample_async: Use a background thread for sampling (slightly off-policy,
usually not advisable to turn on unless your env specifically requires
it).
enable_connectors: Use connector based environment runner, so that all
preprocessing of obs and postprocessing of actions are done in agent
and action connectors.
rollout_fragment_length: Divide episodes into fragments of this many steps
each during rollouts. Sample batches of this size are collected from
rollout workers and combined into a larger batch of `train_batch_size`
for learning.
For example, given rollout_fragment_length=100 and
train_batch_size=1000:
1. RLlib collects 10 fragments of 100 steps each from rollout workers.
2. These fragments are concatenated and we perform an epoch of SGD.
When using multiple envs per worker, the fragment size is multiplied by
`num_envs_per_worker`. This is because we collect steps from
multiple envs in parallel. For example, if num_envs_per_worker=5, then
rollout workers will return experiences in chunks of 5*100 = 500 steps.
The dataflow here can vary per algorithm. For example, PPO further
divides the train batch into minibatches for multi-epoch SGD.
batch_mode: How to build per-Sampler (RolloutWorker) batches, which are then
usually concat'd to form the train batch. Note that "steps" below can
mean different things (either env- or agent-steps) and depends on the
`count_steps_by` (multiagent) setting below.
"truncate_episodes": Each produced batch (when calling
RolloutWorker.sample()) will contain exactly `rollout_fragment_length`
steps. This mode guarantees evenly sized batches, but increases
variance as the future return must now be estimated at truncation
boundaries.
"complete_episodes": Each unroll happens exactly over one episode, from
beginning to end. Data collection will not stop unless the episode
terminates or a configured horizon (hard or soft) is hit.
remote_worker_envs: If using num_envs_per_worker > 1, whether to create
those new envs in remote processes instead of in the same worker.
This adds overheads, but can make sense if your envs can take much
time to step / reset (e.g., for StarCraft). Use this cautiously;
overheads are significant.
remote_env_batch_wait_ms: Timeout that remote workers wait for when
polling environments. 0 (continue when at least one env is ready) is
a reasonable default, but the optimal value could be obtained by measuring
your environment step / reset and model inference perf.
validate_workers_after_construction: Whether to validate that each created
remote worker is healthy after its construction process.
ignore_worker_failures: Whether to attempt to continue training if a worker
crashes. The number of currently healthy workers is reported as the
"num_healthy_workers" metric.
recreate_failed_workers: Whether - upon a worker failure - RLlib will try to
recreate the lost worker as an identical copy of the failed one. The new
worker will only differ from the failed one in its
`self.recreated_worker=True` property value. It will have the same
`worker_index` as the original one. If True, the
`ignore_worker_failures` setting will be ignored.
restart_failed_sub_environments: If True and any sub-environment (within
a vectorized env) throws any error during env stepping, the
Sampler will try to restart the faulty sub-environment. This is done
without disturbing the other (still intact) sub-environment and without
the RolloutWorker crashing.
num_consecutive_worker_failures_tolerance: The number of consecutive times
a rollout worker (or evaluation worker) failure is tolerated before
finally crashing the Algorithm. Only useful if either
`ignore_worker_failures` or `recreate_failed_workers` is True.
Note that for `restart_failed_sub_environments` and sub-environment
failures, the worker itself is NOT affected and won't throw any errors
as the flawed sub-environment is silently restarted under the hood.
horizon: Number of steps after which the episode is forced to terminate.
Defaults to `env.spec.max_episode_steps` (if present) for Gym envs.
soft_horizon: Calculate rewards but don't reset the environment when the
horizon is hit. This allows value estimation and RNN state to span
across logical episodes denoted by horizon. This only has an effect
if horizon != inf.
no_done_at_end: Don't set 'done' at the end of the episode.
In combination with `soft_horizon`, this works as follows:
- no_done_at_end=False soft_horizon=False:
Reset env and add `done=True` at end of each episode.
- no_done_at_end=True soft_horizon=False:
Reset env, but do NOT add `done=True` at end of the episode.
- no_done_at_end=False soft_horizon=True:
Do NOT reset env at horizon, but add `done=True` at the horizon
(pretending the episode has terminated).
- no_done_at_end=True soft_horizon=True:
Do NOT reset env at horizon and do NOT add `done=True` at the horizon.
preprocessor_pref: Whether to use "rllib" or "deepmind" preprocessors by
default. Set to None for using no preprocessor. In this case, the
model will have to handle possibly complex observations from the
environment.
observation_filter: Element-wise observation filter, either "NoFilter"
or "MeanStdFilter".
synchronize_filter: Whether to synchronize the statistics of remote filters.
compress_observations: Whether to LZ4 compress individual observations
in the SampleBatches collected during rollouts.
enable_tf1_exec_eagerly: Explicitly tells the rollout worker to enable
TF eager execution. This is useful for example when framework is
"torch", but a TF2 policy needs to be restored for evaluation or
league-based purposes.
sampler_perf_stats_ema_coef: If specified, perf stats are in EMAs. This
is the coeff of how much new data points contribute to the averages.
Default is None, which uses simple global average instead.
The EMA update rule is: updated = (1 - ema_coef) * old + ema_coef * new
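Example:
A minimal sketch; the batch settings below are illustrative, not tuned:
>>> config = AlgorithmConfig().rollouts(
...     num_rollout_workers=4,
...     num_envs_per_worker=2,
...     rollout_fragment_length=100,
...     batch_mode="truncate_episodes",
... )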
Returns:
This updated AlgorithmConfig object.
"""
if num_rollout_workers is not None:
self.num_workers = num_rollout_workers
if num_envs_per_worker is not None:
self.num_envs_per_worker = num_envs_per_worker
if sample_collector is not None:
self.sample_collector = sample_collector
if create_env_on_local_worker is not None:
self.create_env_on_local_worker = create_env_on_local_worker
if sample_async is not None:
self.sample_async = sample_async
if enable_connectors is not None:
self.enable_connectors = enable_connectors
if rollout_fragment_length is not None:
self.rollout_fragment_length = rollout_fragment_length
if batch_mode is not None:
self.batch_mode = batch_mode
if remote_worker_envs is not None:
self.remote_worker_envs = remote_worker_envs
if remote_env_batch_wait_ms is not None:
self.remote_env_batch_wait_ms = remote_env_batch_wait_ms
if validate_workers_after_construction is not None:
self.validate_workers_after_construction = (
validate_workers_after_construction
)
if ignore_worker_failures is not None:
self.ignore_worker_failures = ignore_worker_failures
if recreate_failed_workers is not None:
self.recreate_failed_workers = recreate_failed_workers
if restart_failed_sub_environments is not None:
self.restart_failed_sub_environments = restart_failed_sub_environments
if num_consecutive_worker_failures_tolerance is not None:
self.num_consecutive_worker_failures_tolerance = (
num_consecutive_worker_failures_tolerance
)
if horizon is not None:
self.horizon = horizon
if soft_horizon is not None:
self.soft_horizon = soft_horizon
if no_done_at_end is not None:
self.no_done_at_end = no_done_at_end
if preprocessor_pref is not None:
self.preprocessor_pref = preprocessor_pref
if observation_filter is not None:
self.observation_filter = observation_filter
if synchronize_filter is not None:
self.synchronize_filters = synchronize_filter
if compress_observations is not None:
self.compress_observations = compress_observations
if enable_tf1_exec_eagerly is not None:
self.enable_tf1_exec_eagerly = enable_tf1_exec_eagerly
if sampler_perf_stats_ema_coef is not None:
self.sampler_perf_stats_ema_coef = sampler_perf_stats_ema_coef
return self
def training(
self,
gamma: Optional[float] = None,
lr: Optional[float] = None,
train_batch_size: Optional[int] = None,
model: Optional[dict] = None,
optimizer: Optional[dict] = None,
) -> "AlgorithmConfig":
"""Sets the training related configuration.
Args:
gamma: Float specifying the discount factor of the Markov Decision process.
lr: The default learning rate.
train_batch_size: Training batch size, if applicable.
model: Arguments passed into the policy model. See models/catalog.py for a
full list of the available model options.
optimizer: Arguments to pass to the policy optimizer.
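Example:
A minimal sketch; the hyperparameter values are illustrative only:
>>> config = AlgorithmConfig().training(
...     gamma=0.99,
...     lr=0.0003,
...     train_batch_size=4000,
...     model={"fcnet_hiddens": [64, 64]},
... )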
Returns:
This updated AlgorithmConfig object.
"""
if gamma is not None:
self.gamma = gamma
if lr is not None:
self.lr = lr
if train_batch_size is not None:
self.train_batch_size = train_batch_size
if model is not None:
self.model = model
if optimizer is not None:
self.optimizer = merge_dicts(self.optimizer, optimizer)
return self
def callbacks(self, callbacks_class) -> "AlgorithmConfig":
"""Sets the callbacks configuration.
Args:
callbacks_class: Callbacks class, whose methods will be run during
various phases of training and environment sample collection.
See the `DefaultCallbacks` class and
`examples/custom_metrics_and_callbacks.py` for more usage information.
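Example:
A minimal sketch using the built-in MemoryTrackingCallbacks class:
>>> from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
>>> config = AlgorithmConfig().callbacks(MemoryTrackingCallbacks)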
Returns:
This updated AlgorithmConfig object.
"""
self.callbacks_class = callbacks_class
return self
def exploration(
self,
*,
explore: Optional[bool] = None,
exploration_config: Optional[dict] = None,
) -> "AlgorithmConfig":
"""Sets the config's exploration settings.
Args:
explore: Default exploration behavior, iff `explore`=None is passed into
compute_action(s). Set to False for no exploration behavior (e.g.,
for evaluation).
exploration_config: A dict specifying the Exploration object's config.
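Example:
A minimal sketch switching to epsilon-greedy exploration; the epsilon value is
illustrative:
>>> config = AlgorithmConfig().exploration(
...     explore=True,
...     exploration_config={"type": "EpsilonGreedy", "initial_epsilon": 1.0},
... )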
Returns:
This updated AlgorithmConfig object.
"""
if explore is not None:
self.explore = explore
if exploration_config is not None:
# Override entire `exploration_config` if `type` key changes.
# Update, if `type` key remains the same or is not specified.
new_exploration_config = deep_update(
{"exploration_config": self.exploration_config},
{"exploration_config": exploration_config},
False,
["exploration_config"],
["exploration_config"],
)
self.exploration_config = new_exploration_config["exploration_config"]
return self
def evaluation(
self,
*,
evaluation_interval: Optional[int] = None,
evaluation_duration: Optional[Union[int, str]] = None,
evaluation_duration_unit: Optional[str] = None,
evaluation_sample_timeout_s: Optional[float] = None,
evaluation_parallel_to_training: Optional[bool] = None,
evaluation_config: Optional[
Union["AlgorithmConfig", PartialAlgorithmConfigDict]
] = None,
off_policy_estimation_methods: Optional[Dict] = None,
evaluation_num_workers: Optional[int] = None,
custom_evaluation_function: Optional[Callable] = None,
always_attach_evaluation_results: Optional[bool] = None,
enable_async_evaluation: Optional[bool] = None,
) -> "AlgorithmConfig":
"""Sets the config's evaluation settings.
Args:
evaluation_interval: Evaluate with every `evaluation_interval` training
iterations. The evaluation stats will be reported under the "evaluation"
metric key. Note that for Ape-X, metrics are already only reported for
the lowest-epsilon workers (i.e., the least random workers).
Set to None (or 0) for no evaluation.
evaluation_duration: Duration for which to run evaluation each
`evaluation_interval`. The unit for the duration can be set via
`evaluation_duration_unit` to either "episodes" (default) or
"timesteps". If using multiple evaluation workers
(evaluation_num_workers > 1), the load to run will be split amongst
these.
If the value is "auto":
- For `evaluation_parallel_to_training=True`: Will run as many
episodes/timesteps that fit into the (parallel) training step.
- For `evaluation_parallel_to_training=False`: Error.
evaluation_duration_unit: The unit, with which to count the evaluation
duration. Either "episodes" (default) or "timesteps".
evaluation_sample_timeout_s: The timeout (in seconds) for the ray.get call
to the remote evaluation worker(s) `sample()` method. After this time,
the user will receive a warning and instructions on how to fix the
issue. This could be either to make sure the episode ends, increasing
the timeout, or switching to `evaluation_duration_unit=timesteps`.
evaluation_parallel_to_training: Whether to run evaluation in parallel to
an Algorithm.train() call, using threading. Default=False.
E.g. evaluation_interval=2 -> For every other training iteration,
the Algorithm.train() and Algorithm.evaluate() calls run in parallel.
Note: This is experimental. Possible pitfalls could be race conditions
for weight synching at the beginning of the evaluation loop.
evaluation_config: Typical usage is to pass extra args to evaluation env
creator and to disable exploration by computing deterministic actions.
IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal
policy, even if this is a stochastic one. Setting "explore=False" here
will result in the evaluation workers not using this optimal policy!
off_policy_estimation_methods: Specify how to evaluate the current policy,
along with any optional config parameters. This only has an effect when
reading offline experiences ("input" is not "sampler").
Available keys:
{ope_method_name: {"type": ope_type, ...}} where `ope_method_name`
is a user-defined string to save the OPE results under, and
`ope_type` can be any subclass of OffPolicyEstimator, e.g.
ray.rllib.offline.estimators.is::ImportanceSampling
or your own custom subclass, or the full class path to the subclass.
You can also add additional config arguments to be passed to the
OffPolicyEstimator in the dict, e.g.
{"qreg_dr": {"type": DoublyRobust, "q_model_type": "qreg", "k": 5}}
evaluation_num_workers: Number of parallel workers to use for evaluation.
Note that this is set to zero by default, which means evaluation will
be run in the algorithm process (only if evaluation_interval is not
None). If you increase this, it will increase the Ray resource usage of
the algorithm since evaluation workers are created separately from
rollout workers (used to sample data for training).
custom_evaluation_function: Customize the evaluation method. This must be a
function of signature (algo: Algorithm, eval_workers: WorkerSet) ->
metrics: dict. See the Algorithm.evaluate() method to see the default
implementation. The Algorithm guarantees all eval workers have the
latest policy state before this function is called.
always_attach_evaluation_results: Make sure the latest available evaluation
results are always attached to a step result dict. This may be useful
if Tune or some other meta controller needs access to evaluation metrics
all the time.
enable_async_evaluation: If True, use an AsyncRequestsManager for
the evaluation workers and use this manager to send `sample()` requests
to the evaluation workers. This way, the Algorithm becomes more robust
against long running episodes and/or failing (and restarting) workers.
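Example:
A minimal sketch; evaluation runs every 2nd training iteration on 2 dedicated
workers with exploration switched off (values illustrative):
>>> config = AlgorithmConfig().evaluation(
...     evaluation_interval=2,
...     evaluation_duration=10,
...     evaluation_duration_unit="episodes",
...     evaluation_num_workers=2,
...     evaluation_config={"explore": False},
... )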
Returns:
This updated AlgorithmConfig object.
"""
if evaluation_interval is not None:
self.evaluation_interval = evaluation_interval
if evaluation_duration is not None:
self.evaluation_duration = evaluation_duration
if evaluation_duration_unit is not None:
self.evaluation_duration_unit = evaluation_duration_unit
if evaluation_sample_timeout_s is not None:
self.evaluation_sample_timeout_s = evaluation_sample_timeout_s
if evaluation_parallel_to_training is not None:
self.evaluation_parallel_to_training = evaluation_parallel_to_training
if evaluation_config is not None:
# Convert another AlgorithmConfig into dict.
if isinstance(evaluation_config, AlgorithmConfig):
self.evaluation_config = evaluation_config.to_dict()
else:
self.evaluation_config = evaluation_config
if off_policy_estimation_methods is not None:
self.off_policy_estimation_methods = off_policy_estimation_methods
if evaluation_num_workers is not None:
self.evaluation_num_workers = evaluation_num_workers
if custom_evaluation_function is not None:
self.custom_evaluation_function = custom_evaluation_function
if always_attach_evaluation_results is not None:
self.always_attach_evaluation_results = always_attach_evaluation_results
if enable_async_evaluation is not None:
self.enable_async_evaluation = enable_async_evaluation
return self
def offline_data(
self,
*,
input_=None,
input_config=None,
actions_in_input_normalized=None,
input_evaluation=None,
postprocess_inputs=None,
shuffle_buffer_size=None,
output=None,
output_config=None,
output_compress_columns=None,
output_max_file_size=None,
) -> "AlgorithmConfig":
"""Sets the config's offline data settings.
TODO(jungong, sven): we can potentially unify all input types
under input and input_config keys. E.g.
input: sampler
input_config {
env: CartPole-v0
}
or:
input: json_reader
input_config {
path: /tmp/
}
or:
input: dataset
input_config {
format: parquet
path: /tmp/
}
Args:
input_: Specify how to generate experiences:
- "sampler": Generate experiences via online (env) simulation (default).
- A local directory or file glob expression (e.g., "/tmp/*.json").
- A list of individual file paths/URIs (e.g., ["/tmp/1.json",
"s3://bucket/2.json"]).
- A dict with string keys and sampling probabilities as values (e.g.,
{"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
- A callable that takes an `IOContext` object as only arg and returns a
ray.rllib.offline.InputReader.
- A string key that indexes a callable with tune.registry.register_input
input_config: Arguments accessible from the IOContext for configuring custom
input.
actions_in_input_normalized: True, if the actions in a given offline "input"
are already normalized (between -1.0 and 1.0). This is usually the case
when the offline file has been generated by another RLlib algorithm
(e.g. PPO or SAC), while "normalize_actions" was set to True.
postprocess_inputs: Whether to run postprocess_trajectory() on the
trajectory fragments from offline inputs. Note that postprocessing will
be done using the *current* policy, not the *behavior* policy, which
is typically undesirable for on-policy algorithms.
shuffle_buffer_size: If positive, input batches will be shuffled via a
sliding window buffer of this number of batches. Use this if the input
data is not in random enough order. Input is delayed until the shuffle
buffer is filled.
output: Specify where experiences should be saved:
- None: don't save any experiences
- "logdir" to save to the agent log dir
- a path/URI to save to a custom output directory (e.g., "s3://bckt/")
- a function that returns a rllib.offline.OutputWriter
output_config: Arguments accessible from the IOContext for configuring
custom output.
output_compress_columns: What sample batch columns to LZ4 compress in the
output data.
output_max_file_size: Max output file size before rolling over to a
new file.
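Example:
A minimal sketch reading offline experiences from a hypothetical JSON glob and
writing newly collected samples to the agent's log dir:
>>> config = AlgorithmConfig().offline_data(
...     input_="/tmp/offline/*.json",  # hypothetical path
...     output="logdir",
... )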
Returns:
This updated AlgorithmConfig object.
"""
if input_ is not None:
self.input_ = input_
if input_config is not None:
self.input_config = input_config
if actions_in_input_normalized is not None:
self.actions_in_input_normalized = actions_in_input_normalized
if input_evaluation is not None:
deprecation_warning(
old="offline_data(input_evaluation={})".format(input_evaluation),
new="evaluation(off_policy_estimation_methods={})".format(
input_evaluation
),
error=True,
help="Running OPE during training is not recommended.",
)
if postprocess_inputs is not None:
self.postprocess_inputs = postprocess_inputs
if shuffle_buffer_size is not None:
self.shuffle_buffer_size = shuffle_buffer_size
if output is not None:
self.output = output
if output_config is not None:
self.output_config = output_config
if output_compress_columns is not None:
self.output_compress_columns = output_compress_columns
if output_max_file_size is not None:
self.output_max_file_size = output_max_file_size
return self
def multi_agent(
self,
*,
policies=None,
policy_map_capacity=None,
policy_map_cache=None,
policy_mapping_fn=None,
policies_to_train=None,
observation_fn=None,
replay_mode=None,
count_steps_by=None,
) -> "AlgorithmConfig":
"""Sets the config's multi-agent settings.
Args:
policies: Map of type MultiAgentPolicyConfigDict from policy ids to tuples
of (policy_cls, obs_space, act_space, config). This defines the
observation and action spaces of the policies and any extra config.
policy_map_capacity: Keep this many policies in the "policy_map" (before
writing least-recently used ones to disk/S3).
policy_map_cache: Where to store overflowing (least-recently used) policies?
Could be a directory (str) or an S3 location. None for using the
default output dir.
policy_mapping_fn: Function mapping agent ids to policy ids.
policies_to_train: Determines those policies that should be updated.
Options are:
- None, for all policies.
- An iterable of PolicyIDs that should be updated.
- A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch
and returning a bool (indicating whether the given policy is trainable
or not, given the particular batch). This allows you to have a policy
trained only on certain data (e.g. when playing against a certain
opponent).
observation_fn: Optional function that can be used to enhance the local
agent observations to include more state. See
rllib/evaluation/observation_function.py for more info.
replay_mode: When replay_mode=lockstep, RLlib will replay all the agent
transitions at a particular timestep together in a batch. This allows
the policy to implement differentiable shared computations between
agents it controls at that timestep. When replay_mode=independent,
transitions are replayed independently per policy.
count_steps_by: Which metric to use as the "batch size" when building a
MultiAgentBatch. The two supported values are:
"env_steps": Count each time the env is "stepped" (no matter how many
multi-agent actions are passed/how many multi-agent observations
have been returned in the previous step).
"agent_steps": Count each individual agent step as one step.
Returns:
This updated AlgorithmConfig object.
"""
if policies is not None:
self.policies = policies
if policy_map_capacity is not None:
self.policy_map_capacity = policy_map_capacity
if policy_map_cache is not None:
self.policy_map_cache = policy_map_cache
if policy_mapping_fn is not None:
self.policy_mapping_fn = policy_mapping_fn
if policies_to_train is not None:
self.policies_to_train = policies_to_train
if observation_fn is not None:
self.observation_fn = observation_fn
if replay_mode is not None:
self.replay_mode = replay_mode
if count_steps_by is not None:
self.count_steps_by = count_steps_by
return self
def reporting(
self,
*,
keep_per_episode_custom_metrics: Optional[bool] = None,
metrics_episode_collection_timeout_s: Optional[float] = None,
metrics_num_episodes_for_smoothing: Optional[int] = None,
min_time_s_per_iteration: Optional[int] = None,
min_train_timesteps_per_iteration: Optional[int] = None,
min_sample_timesteps_per_iteration: Optional[int] = None,
) -> "AlgorithmConfig":
"""Sets the config's reporting settings.
Args:
keep_per_episode_custom_metrics: Store raw custom metrics without
calculating max, min, or mean.
metrics_episode_collection_timeout_s: Wait for metric batches for at most
this many seconds. Those that have not returned in time will be
collected in the next train iteration.
metrics_num_episodes_for_smoothing: Smooth rollout metrics over this many
episodes, if possible.
In case rollouts (sample collection) just started, there may be fewer
than this many episodes in the buffer and we'll compute metrics
over this smaller number of available episodes.
In case there are more than this many episodes collected in a single
training iteration, use all of these episodes for metrics computation,
meaning don't ever cut any "excess" episodes.
min_time_s_per_iteration: Minimum time to accumulate within a single
`train()` call. This value does not affect learning,
only the number of times `Algorithm.training_step()` is called by
`Algorithm.train()`. If, after one such step attempt, the time taken
has not reached `min_time_s_per_iteration`, n more
`training_step()` calls are performed until the minimum time has been
consumed. Set to 0 or None for no minimum time.
min_train_timesteps_per_iteration: Minimum training timesteps to accumulate
within a single `train()` call. This value does not affect learning,
only the number of times `Algorithm.training_step()` is called by
`Algorithm.train()`. If, after one such step attempt, the training
timestep count has not been reached, n more
`training_step()` calls are performed until the minimum timesteps have
been executed. Set to 0 or None for no minimum timesteps.
min_sample_timesteps_per_iteration: Minimum env sampling timesteps to
accumulate within a single `train()` call. This value does not affect
learning, only the number of times `Algorithm.training_step()` is
called by `Algorithm.train()`. If, after one such step attempt, the env
sampling timestep count has not been reached, n more
`training_step()` calls are performed until the minimum timesteps have
been executed. Set to 0 or None for no minimum timesteps.
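Example:
A minimal sketch; the thresholds below are illustrative:
>>> config = AlgorithmConfig().reporting(
...     min_time_s_per_iteration=10,
...     metrics_num_episodes_for_smoothing=50,
... )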
Returns:
This updated AlgorithmConfig object.
"""
if keep_per_episode_custom_metrics is not None:
self.keep_per_episode_custom_metrics = keep_per_episode_custom_metrics
if metrics_episode_collection_timeout_s is not None:
self.metrics_episode_collection_timeout_s = (
metrics_episode_collection_timeout_s
)
if metrics_num_episodes_for_smoothing is not None:
self.metrics_num_episodes_for_smoothing = metrics_num_episodes_for_smoothing
if min_time_s_per_iteration is not None:
self.min_time_s_per_iteration = min_time_s_per_iteration
if min_train_timesteps_per_iteration is not None:
self.min_train_timesteps_per_iteration = min_train_timesteps_per_iteration
if min_sample_timesteps_per_iteration is not None:
self.min_sample_timesteps_per_iteration = min_sample_timesteps_per_iteration
return self
def debugging(
self,
*,
logger_creator: Optional[Callable[[], Logger]] = None,
logger_config: Optional[dict] = None,
log_level: Optional[str] = None,
log_sys_usage: Optional[bool] = None,
fake_sampler: Optional[bool] = None,
seed: Optional[int] = None,
) -> "AlgorithmConfig":
"""Sets the config's debugging settings.
Args:
logger_creator: Callable that creates a ray.tune.Logger
object. If unspecified, a default logger is created.
logger_config: Define logger-specific configuration to be used inside the
Logger. Default value None allows overwriting with nested dicts.
log_level: Set the ray.rllib.* log level for the agent process and its
workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level
will also periodically print out summaries of relevant internal dataflow
(this is also printed out once at startup at the INFO level). When using
the `rllib train` command, you can also use the `-v` and `-vv` flags as
shorthand for INFO and DEBUG.
log_sys_usage: Log system resource metrics to results. This requires
`psutil` to be installed for sys stats, and `gputil` for GPU metrics.
fake_sampler: Use fake (infinite speed) sampler. For testing only.
seed: This argument, in conjunction with worker_index, sets the random
seed of each worker, so that identically configured trials will have
identical results. This makes experiments reproducible.
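Example:
A minimal sketch; fixes the seed for reproducibility and raises the log level:
>>> config = AlgorithmConfig().debugging(log_level="INFO", seed=42)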
Returns:
This updated AlgorithmConfig object.
"""
if logger_creator is not None:
self.logger_creator = logger_creator
if logger_config is not None:
self.logger_config = logger_config
if log_level is not None:
self.log_level = log_level
if log_sys_usage is not None:
self.log_sys_usage = log_sys_usage
if fake_sampler is not None:
self.fake_sampler = fake_sampler
if seed is not None:
self.seed = seed
return self
def experimental(
self,
*,
_tf_policy_handles_more_than_one_loss=None,
_disable_preprocessor_api=None,
_disable_action_flattening=None,
_disable_execution_plan_api=None,
) -> "AlgorithmConfig":
"""Sets the config's experimental settings.
Args:
_tf_policy_handles_more_than_one_loss: Experimental flag.
If True, TFPolicy will handle more than one loss/optimizer.
Set this to True, if you would like to return more than
one loss term from your `loss_fn` and an equal number of optimizers
from your `optimizer_fn`. In the future, the default for this will be
True.
_disable_preprocessor_api: Experimental flag.
If True, no (observation) preprocessor will be created and
observations will arrive in model as they are returned by the env.
In the future, the default for this will be True.
_disable_action_flattening: Experimental flag.
If True, RLlib will no longer flatten the policy-computed actions into
a single tensor (for storage in SampleCollectors/output files/etc..),
but leave (possibly nested) actions as-is. Disabling flattening affects:
- SampleCollectors: Have to store possibly nested action structs.
- Models that have the previous action(s) as part of their input.
- Algorithms reading from offline files (incl. action information).
_disable_execution_plan_api: Experimental flag.
If True, the execution plan API will not be used. Instead,
the Algorithm's `training_step()` method will be called as-is on each
training iteration.
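Example:
A minimal sketch toggling two of the experimental flags (illustrative only):
>>> config = AlgorithmConfig().experimental(
...     _disable_preprocessor_api=True,
...     _tf_policy_handles_more_than_one_loss=True,
... )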
Returns:
This updated AlgorithmConfig object.
"""
if _tf_policy_handles_more_than_one_loss is not None:
self._tf_policy_handles_more_than_one_loss = (
_tf_policy_handles_more_than_one_loss
)
if _disable_preprocessor_api is not None:
self._disable_preprocessor_api = _disable_preprocessor_api
if _disable_action_flattening is not None:
self._disable_action_flattening = _disable_action_flattening
if _disable_execution_plan_api is not None:
self._disable_execution_plan_api = _disable_execution_plan_api
return self