ray/rllib/optimizers/policy_optimizer.py


import logging

from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes

logger = logging.getLogger(__name__)


@DeveloperAPI
class PolicyOptimizer:
    """Policy optimizers encapsulate distributed RL optimization strategies.

    Policy optimizers serve as the "control plane" of algorithms. For
    example, AsyncOptimizer is used for A3C, and LocalMultiGPUOptimizer is
    used for PPO. These optimizers are all pluggable, and it is possible to
    mix and match as needed. A minimal example subclass is sketched at the
    end of this module.

    Attributes:
        config (dict): The JSON configuration passed to this optimizer.
        workers (WorkerSet): The set of rollout workers to use.
        num_steps_trained (int): Number of timesteps trained on so far.
        num_steps_sampled (int): Number of timesteps sampled so far.
    """

    @DeveloperAPI
    def __init__(self, workers):
        """Create an optimizer instance.

        Args:
            workers (WorkerSet): The set of rollout workers to use.
        """
        self.workers = workers
        self.episode_history = []
        self.to_be_collected = []

        # Counters that should be updated by subclasses.
        self.num_steps_trained = 0
        self.num_steps_sampled = 0

    @DeveloperAPI
    def step(self):
        """Takes a logical optimization step.

        This should run for long enough to minimize call overheads (i.e., at
        least a couple of seconds), but be short enough to return control
        periodically to callers (i.e., at most a few tens of seconds).

        Returns:
            fetches (dict|None): Optional fetches from compute grads calls.
        """
        raise NotImplementedError

    @DeveloperAPI
    def stats(self):
        """Returns a dictionary of internal performance statistics."""
        return {
            "num_steps_trained": self.num_steps_trained,
            "num_steps_sampled": self.num_steps_sampled,
        }

    @DeveloperAPI
    def save(self):
        """Returns a serializable object representing the optimizer state."""
        return [self.num_steps_trained, self.num_steps_sampled]

    @DeveloperAPI
    def restore(self, data):
        """Restores optimizer state from the given data object."""
        self.num_steps_trained = data[0]
        self.num_steps_sampled = data[1]

    @DeveloperAPI
    def stop(self):
        """Release any resources used by this optimizer."""
        pass

    @DeveloperAPI
    def collect_metrics(self,
                        timeout_seconds,
                        min_history=100,
                        selected_workers=None):
        """Returns worker and optimizer stats.

        Arguments:
            timeout_seconds (int): Max wait time for a worker before
                dropping its results. This usually indicates a hung worker.
            min_history (int): Min history length to smooth results over.
            selected_workers (list): Override the list of remote workers
                to collect metrics from.

        Returns:
            res (dict): A training result dict from worker metrics with
                `info` replaced with stats from self.
        """
        episodes, self.to_be_collected = collect_episodes(
            self.workers.local_worker(),
            selected_workers or self.workers.remote_workers(),
            self.to_be_collected,
            timeout_seconds=timeout_seconds)
        orig_episodes = list(episodes)

        # Pad with historical episodes so that results are smoothed over at
        # least `min_history` episodes when few new ones are available.
        missing = min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= min_history

        # Only newly collected episodes are added to the rolling history.
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-min_history:]

        res = summarize_episodes(episodes, orig_episodes)
        res.update(info=self.stats())
        return res

    @DeveloperAPI
    def reset(self, remote_workers):
        """Called to change the set of remote workers being used."""
        self.workers.reset(remote_workers)

    @DeveloperAPI
    def foreach_worker(self, func):
        """Apply the given function to each worker instance."""
        return self.workers.foreach_worker(func)

    @DeveloperAPI
    def foreach_worker_with_index(self, func):
        """Apply the given function to each worker instance.

        The index will be passed as the second arg to the given function.
        """
        return self.workers.foreach_worker_with_index(func)
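

# A minimal sketch of how a concrete subclass can fill in ``step()``. It
# assumes the rollout workers expose ``sample()`` and the local worker
# exposes ``learn_on_batch()``, as the synchronous RLlib optimizers do; the
# class name below is illustrative only. A trainer would construct such an
# optimizer with a WorkerSet, call ``step()`` in a loop, and periodically
# call ``collect_metrics(timeout_seconds=...)`` to build result dicts.
import ray


class _ExampleSyncOptimizer(PolicyOptimizer):
    """Sketch: sample from all workers, then train the local policy."""

    def step(self):
        # Gather one batch per remote worker; fall back to sampling on the
        # local worker when running without remote workers.
        if self.workers.remote_workers():
            batches = ray.get(
                [w.sample.remote() for w in self.workers.remote_workers()])
        else:
            batches = [self.workers.local_worker().sample()]

        fetches = None
        for batch in batches:
            # Train on each batch and update the counters that stats(),
            # save(), and restore() report.
            fetches = self.workers.local_worker().learn_on_batch(batch)
            self.num_steps_sampled += batch.count
            self.num_steps_trained += batch.count
        return fetches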