from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import logging
import numpy as np
import queue
import threading
import time
from typing import Any, Callable, Dict, List, Iterable, Optional, Set, Tuple,\
    Type, TYPE_CHECKING, Union

from ray.util.debug import log_once
from ray.rllib.evaluation.collectors.sample_collector import \
    SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import \
    SimpleListCollector
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.wrappers.atari_wrappers import get_wrapper_by_cls, \
    MonitorEnv
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import clip_action, \
    unsquash_action, unbatch
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvObsType, EnvInfoDict, EnvID, MultiEnvDict, EnvActionType, \
    TensorStructType

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.evaluation.observation_function import ObservationFunction
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.utils import try_import_tf
    _, tf, _ = try_import_tf()
    from gym.envs.classic_control.rendering import SimpleImageViewer

logger = logging.getLogger(__name__)

PolicyEvalData = namedtuple("PolicyEvalData", [
    "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
    "prev_reward"
])

# A batch of RNN states with dimensions [state_index, batch, state_object].
StateBatch = List[List[Any]]

class NewEpisodeDefaultDict(defaultdict):
    def __missing__(self, env_id):
        if self.default_factory is None:
            raise KeyError(env_id)
        else:
            ret = self[env_id] = self.default_factory(env_id)
            return ret


class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self):
        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def get(self):
        # Mean multiplicator (1000 = sec -> ms).
        factor = 1000 / self.iters
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time *
            factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }


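# Illustrative example for `_PerfStats.get()` (not executed): after 100
# iterations with a total of 2.0s spent in policy inference, the returned dict
# would contain "mean_inference_ms": 2.0 * 1000 / 100 = 20.0, i.e. ~20ms of
# inference per `_env_runner` iteration. All other entries are converted the
# same way.
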
@DeveloperAPI
class SamplerInput(InputReader, metaclass=ABCMeta):
    """Reads input experiences from an existing sampler."""

    @override(InputReader)
    def next(self) -> SampleBatchType:
        batches = [self.get_data()]
        batches.extend(self.get_extra_batches())
        if len(batches) > 1:
            return batches[0].concat_samples(batches)
        else:
            return batches[0]

    @abstractmethod
    @DeveloperAPI
    def get_data(self) -> SampleBatchType:
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_metrics(self) -> List[RolloutMetrics]:
        raise NotImplementedError

    @abstractmethod
    @DeveloperAPI
    def get_extra_batches(self) -> List[SampleBatchType]:
        raise NotImplementedError


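# Minimal sketch of a custom SamplerInput (illustrative only, not used by
# RLlib): subclasses implement the three abstract methods above, and the
# inherited `next()` merges the main batch with any extra batches via
# `concat_samples()`.
#
#     class ListSampler(SamplerInput):
#         def __init__(self, batches):
#             self.batches = list(batches)  # pre-collected SampleBatches
#         def get_data(self):
#             return self.batches.pop(0)
#         def get_metrics(self):
#             return []
#         def get_extra_batches(self):
#             return []
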
@DeveloperAPI
class SyncSampler(SamplerInput):
    """Sync SamplerInput that collects experiences when `get_data()` is called.
    """

    def __init__(
            self,
            *,
            worker: "RolloutWorker",
            env: BaseEnv,
            clip_rewards: bool,
            rollout_fragment_length: int,
            count_steps_by: str = "env_steps",
            callbacks: "DefaultCallbacks",
            horizon: int = None,
            multiple_episodes_in_batch: bool = False,
            normalize_actions: bool = True,
            clip_actions: bool = False,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            observation_fn: "ObservationFunction" = None,
            sample_collector_class: Optional[Type[SampleCollector]] = None,
            render: bool = False,
            # Obsolete.
            policies=None,
            policy_mapping_fn=None,
            preprocessors=None,
            obs_filters=None,
            tf_sess=None,
    ):
        """Initializes a SyncSampler object.

        Args:
            worker (RolloutWorker): The RolloutWorker that will use this
                Sampler for sampling.
            env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
                float value for +/- value clipping. False for no clipping.
            rollout_fragment_length (int): The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by (str): Either "env_steps" or "agent_steps".
                Refers to the unit of `rollout_fragment_length`.
            callbacks (Callbacks): The Callbacks object to use when episode
                events happen during rollout.
            horizon (Optional[int]): Hard-reset the Env after this many
                timesteps.
            multiple_episodes_in_batch (bool): Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions (bool): Whether to normalize actions to the
                action space's bounds.
            clip_actions (bool): Whether to clip actions according to the
                given action_space's bounds.
            soft_horizon (bool): If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn (Optional[ObservationFunction]): Optional
                multi-agent observation func to use for preprocessing
                observations.
            sample_collector_class (Optional[Type[SampleCollector]]): An
                optional SampleCollector sub-class to use to collect, store,
                and retrieve environment-, model-, and sampler data.
            render (bool): Whether to try to render the environment after each
                step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        if log_once("deprecated_sync_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")

        self.base_env = BaseEnv.to_base_env(env)
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.extra_batches = queue.Queue()
        self.perf_stats = _PerfStats()
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            worker.policy_map,
            clip_rewards,
            callbacks,
            multiple_episodes_in_batch,
            rollout_fragment_length,
            count_steps_by=count_steps_by)
        self.render = render

        # Create the rollout generator to use for calls to `get_data()`.
        self.rollout_provider = _env_runner(
            worker, self.base_env, self.extra_batches.put,
            self.rollout_fragment_length, self.horizon, clip_rewards,
            normalize_actions, clip_actions, multiple_episodes_in_batch,
            callbacks, self.perf_stats, soft_horizon, no_done_at_end,
            observation_fn, self.sample_collector, self.render)
        self.metrics_queue = queue.Queue()

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        while True:
            item = next(self.rollout_provider)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                return item

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


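# Usage sketch for SyncSampler (illustrative; assumes an already-built
# RolloutWorker `worker` whose `async_env` and `callbacks` attributes exist):
#
#     sampler = SyncSampler(
#         worker=worker,
#         env=worker.async_env,
#         clip_rewards=False,
#         rollout_fragment_length=200,
#         callbacks=worker.callbacks,
#     )
#     batch = sampler.get_data()       # blocks until one fragment is ready
#     metrics = sampler.get_metrics()  # RolloutMetrics of finished episodes
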
@DeveloperAPI
class AsyncSampler(threading.Thread, SamplerInput):
    """Async SamplerInput that collects experiences in thread and queues them.

    Once started, experiences are continuously collected and put into a Queue,
    from where they can be dequeued by the caller of `get_data()`.
    """

    def __init__(
            self,
            *,
            worker: "RolloutWorker",
            env: BaseEnv,
            clip_rewards: bool,
            rollout_fragment_length: int,
            count_steps_by: str = "env_steps",
            callbacks: "DefaultCallbacks",
            horizon: int = None,
            multiple_episodes_in_batch: bool = False,
            normalize_actions: bool = True,
            clip_actions: bool = False,
            blackhole_outputs: bool = False,
            soft_horizon: bool = False,
            no_done_at_end: bool = False,
            observation_fn: "ObservationFunction" = None,
            sample_collector_class: Optional[Type[SampleCollector]] = None,
            render: bool = False,
            # Obsolete.
            policies=None,
            policy_mapping_fn=None,
            preprocessors=None,
            obs_filters=None,
            tf_sess=None,
    ):
        """Initializes an AsyncSampler object.

        Args:
            worker (RolloutWorker): The RolloutWorker that will use this
                Sampler for sampling.
            env (Env): Any Env object. Will be converted into an RLlib BaseEnv.
            clip_rewards (Union[bool, float]): True for +/-1.0 clipping, actual
                float value for +/- value clipping. False for no clipping.
            rollout_fragment_length (int): The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by (str): Either "env_steps" or "agent_steps".
                Refers to the unit of `rollout_fragment_length`.
            callbacks (Callbacks): The Callbacks object to use when episode
                events happen during rollout.
            horizon (Optional[int]): Hard-reset the Env after this many
                timesteps.
            multiple_episodes_in_batch (bool): Whether to pack multiple
                episodes into each batch. This guarantees batches will be
                exactly `rollout_fragment_length` in size.
            normalize_actions (bool): Whether to normalize actions to the
                action space's bounds.
            clip_actions (bool): Whether to clip actions according to the
                given action_space's bounds.
            blackhole_outputs (bool): Whether to collect samples, but then
                not further process or store them (throw away all samples).
            soft_horizon (bool): If True, calculate bootstrapped values as if
                episode had ended, but don't physically reset the environment
                when the horizon is hit.
            no_done_at_end (bool): Ignore the done=True at the end of the
                episode and instead record done=False.
            observation_fn (Optional[ObservationFunction]): Optional
                multi-agent observation func to use for preprocessing
                observations.
            sample_collector_class (Optional[Type[SampleCollector]]): An
                optional SampleCollector sub-class to use to collect, store,
                and retrieve environment-, model-, and sampler data.
            render (bool): Whether to try to render the environment after each
                step.
        """
        # All of the following arguments are deprecated. They will instead be
        # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
        if log_once("deprecated_async_sampler_args"):
            if policies is not None:
                deprecation_warning(old="policies")
            if policy_mapping_fn is not None:
                deprecation_warning(old="policy_mapping_fn")
            if preprocessors is not None:
                deprecation_warning(old="preprocessors")
            if obs_filters is not None:
                deprecation_warning(old="obs_filters")
            if tf_sess is not None:
                deprecation_warning(old="tf_sess")

        self.worker = worker

        for _, f in worker.filters.items():
            assert getattr(f, "is_concurrent", False), \
                "Observation Filter must support concurrent updates."

        self.base_env = BaseEnv.to_base_env(env)
        threading.Thread.__init__(self)
        self.queue = queue.Queue(5)
        self.extra_batches = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.rollout_fragment_length = rollout_fragment_length
        self.horizon = horizon
        self.clip_rewards = clip_rewards
        self.daemon = True
        self.multiple_episodes_in_batch = multiple_episodes_in_batch
        self.callbacks = callbacks
        self.normalize_actions = normalize_actions
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.soft_horizon = soft_horizon
        self.no_done_at_end = no_done_at_end
        self.perf_stats = _PerfStats()
        self.shutdown = False
        self.observation_fn = observation_fn
        self.render = render
        if not sample_collector_class:
            sample_collector_class = SimpleListCollector
        self.sample_collector = sample_collector_class(
            worker.policy_map,
            clip_rewards,
            callbacks,
            multiple_episodes_in_batch,
            rollout_fragment_length,
            count_steps_by=count_steps_by)

    @override(threading.Thread)
    def run(self):
        try:
            self._run()
        except BaseException as e:
            self.queue.put(e)
            raise e

    def _run(self):
        if self.blackhole_outputs:
            queue_putter = (lambda x: None)
            extra_batches_putter = (lambda x: None)
        else:
            queue_putter = self.queue.put
            extra_batches_putter = (
                lambda x: self.extra_batches.put(x, timeout=600.0))
        rollout_provider = _env_runner(
            self.worker, self.base_env, extra_batches_putter,
            self.rollout_fragment_length, self.horizon, self.clip_rewards,
            self.normalize_actions, self.clip_actions,
            self.multiple_episodes_in_batch, self.callbacks, self.perf_stats,
            self.soft_horizon, self.no_done_at_end, self.observation_fn,
            self.sample_collector, self.render)
        while not self.shutdown:
            # The timeout variable exists because apparently, if one worker
            # dies, the other workers won't die with it, unless the timeout is
            # set to some large number. This is an empirical observation.
            item = next(rollout_provider)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                queue_putter(item)

    @override(SamplerInput)
    def get_data(self) -> SampleBatchType:
        if not self.is_alive():
            raise RuntimeError("Sampling thread has died")
        rollout = self.queue.get(timeout=600.0)

        # Propagate errors.
        if isinstance(rollout, BaseException):
            raise rollout

        return rollout

    @override(SamplerInput)
    def get_metrics(self) -> List[RolloutMetrics]:
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    @override(SamplerInput)
    def get_extra_batches(self) -> List[SampleBatchType]:
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


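# Usage sketch for AsyncSampler (illustrative; same assumptions as in the
# SyncSampler example above). Being a daemon thread, it must be started before
# data can be pulled from its internal queue:
#
#     sampler = AsyncSampler(
#         worker=worker,
#         env=worker.async_env,
#         clip_rewards=False,
#         rollout_fragment_length=200,
#         callbacks=worker.callbacks,
#     )
#     sampler.start()              # begin collecting in the background
#     batch = sampler.get_data()   # waits (up to 600s) for the next fragment
#     sampler.shutdown = True      # ask the collection loop to stop
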
def _env_runner(
        worker: "RolloutWorker",
        base_env: BaseEnv,
        extra_batch_callback: Callable[[SampleBatchType], None],
        rollout_fragment_length: int,
        horizon: int,
        clip_rewards: bool,
        normalize_actions: bool,
        clip_actions: bool,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        perf_stats: _PerfStats,
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: Optional[SampleCollector] = None,
        render: bool = None,
) -> Iterable[SampleBatchType]:
    """This implements the common experience collection logic.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        rollout_fragment_length (int): Number of episode steps before
            `SampleBatch` is yielded. Set to infinity to yield complete
            episodes.
        horizon (int): Horizon of the episode.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        normalize_actions (bool): Whether to normalize actions to the action
            space's bounds.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        perf_stats (_PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (Optional[SampleCollector]): An optional
            SampleCollector object to use.
        render (bool): Whether to try to render the environment after each
            step.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """
    # May be populated for image rendering.
    simple_image_viewer: Optional["SimpleImageViewer"] = None

    # Try to get Env's `max_episode_steps` prop. If it doesn't exist, ignore
    # error and continue with max_episode_steps=None.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit.
        if max_episode_steps and horizon > max_episode_steps:
            # Try to override the env's own max-step setting with our horizon.
            # If this won't work, throw an error.
            try:
                base_env.get_unwrapped()[0].spec.max_episode_steps = horizon
                base_env.get_unwrapped()[0]._max_episode_steps = horizon
            except Exception:
                raise ValueError(
                    "Your `horizon` setting ({}) is larger than the Env's own "
                    "timestep limit ({}), which seems to be unsettable! Try "
                    "to increase the Env's built-in limit to be at least as "
                    "large as your wanted `horizon`.".format(
                        horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    # No horizon/max_episode_steps -> Episodes may be infinitely long.
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool: List[MultiAgentSampleBatchBuilder] = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return None

    def new_episode(env_id):
        episode = MultiAgentEpisode(
            worker.policy_map,
            worker.policy_mapping_fn,
            get_batch_builder,
            extra_batch_callback,
            env_id=env_id)
        # Call each policy's Exploration.on_episode_start method.
        # types: Policy
        for p in worker.policy_map.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=base_env,
                    episode=episode,
                    tf_sess=p.get_session())
        callbacks.on_episode_start(
            worker=worker,
            base_env=base_env,
            policies=worker.policy_map,
            episode=episode,
            env_index=env_id,
        )
        return episode

    active_episodes: Dict[str, MultiAgentEpisode] = \
        NewEpisodeDefaultDict(new_episode)

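    # Main sampling loop. Each iteration: (1) poll the BaseEnv for new
    # observations, (2) process them into per-policy evaluation requests
    # (yielding any completed metrics/batches on the way), (3) run batched
    # policy inference, (4) turn the results into per-env action dicts and
    # send them back to the env, and (5) optionally render. The time spent in
    # each phase is accumulated in `perf_stats`.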
    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[PolicyEvalData]],
        # List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = \
            _process_observations(
                worker=worker,
                base_env=base_env,
                active_episodes=active_episodes,
                unfiltered_obs=unfiltered_obs,
                rewards=rewards,
                dones=dones,
                infos=infos,
                horizon=horizon,
                multiple_episodes_in_batch=multiple_episodes_in_batch,
                callbacks=callbacks,
                soft_horizon=soft_horizon,
                no_done_at_end=no_done_at_end,
                observation_fn=observation_fn,
                sample_collector=sample_collector,
            )
        perf_stats.raw_obs_processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = _do_policy_eval(
            to_eval=to_eval,
            policies=worker.policy_map,
            policy_mapping_fn=worker.policy_mapping_fn,
            sample_collector=sample_collector,
            active_episodes=active_episodes,
        )
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
            _process_policy_eval_results(
                to_eval=to_eval,
                eval_results=eval_results,
                active_episodes=active_episodes,
                active_envs=active_envs,
                off_policy_actions=off_policy_actions,
                policies=worker.policy_map,
                normalize_actions=normalize_actions,
                clip_actions=clip_actions,
            )
        perf_stats.action_processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4

        # Try to render the env, if required.
        if render:
            t5 = time.time()
            # Render can either return an RGB image (uint8 [w x h x 3] numpy
            # array) or take care of rendering itself (returning True).
            rendered = base_env.try_render()
            # Rendering returned an image -> Display it in a SimpleImageViewer.
            if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
                # ImageViewer not defined yet, try to create one.
                if simple_image_viewer is None:
                    try:
                        from gym.envs.classic_control.rendering import \
                            SimpleImageViewer
                        simple_image_viewer = SimpleImageViewer()
                    except (ImportError, ModuleNotFoundError):
                        render = False  # disable rendering
                        logger.warning(
                            "Could not import gym.envs.classic_control."
                            "rendering! Try `pip install gym[all]`.")
                if simple_image_viewer:
                    simple_image_viewer.imshow(rendered)
            elif rendered not in [True, False, None]:
                raise ValueError(
                    "The env's ({}) `try_render()` method returned an"
                    " unsupported value! Make sure you either return a "
                    "uint8/w x h x 3 (RGB) image or handle rendering in a "
                    "window and then return `True`.".format(base_env))
            perf_stats.env_render_time += time.time() - t5


def _process_observations(
        *,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        active_episodes: Dict[str, MultiAgentEpisode],
        unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        rewards: Dict[EnvID, Dict[AgentID, float]],
        dones: Dict[EnvID, Dict[AgentID, bool]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        horizon: int,
        multiple_episodes_in_batch: bool,
        callbacks: "DefaultCallbacks",
        soft_horizon: bool,
        no_done_at_end: bool,
        observation_fn: "ObservationFunction",
        sample_collector: SampleCollector,
) -> Tuple[Set[EnvID], Dict[PolicyID, List[PolicyEvalData]], List[Union[
        RolloutMetrics, SampleBatchType]]]:
    """Record new data from the environment and prepare for policy evaluation.

    Args:
        worker (RolloutWorker): Reference to the current rollout worker.
        base_env (BaseEnv): Env implementing BaseEnv.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from
            episode ID to currently ongoing MultiAgentEpisode object.
        unfiltered_obs (dict): Doubly keyed dict of env-ids -> agent ids
            -> unfiltered observation tensor, returned by a `BaseEnv.poll()`
            call.
        rewards (dict): Doubly keyed dict of env-ids -> agent ids ->
            rewards tensor, returned by a `BaseEnv.poll()` call.
        dones (dict): Doubly keyed dict of env-ids -> agent ids ->
            boolean done flags, returned by a `BaseEnv.poll()` call.
        infos (dict): Doubly keyed dict of env-ids -> agent ids ->
            info dicts, returned by a `BaseEnv.poll()` call.
        horizon (int): Horizon of the episode.
        multiple_episodes_in_batch (bool): Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks (DefaultCallbacks): User callbacks to run on episode events.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.
        observation_fn (ObservationFunction): Optional multi-agent
            observation func to use for preprocessing observations.
        sample_collector (SampleCollector): The SampleCollector object
            used to store and retrieve environment samples.

    Returns:
        Tuple:
            - active_envs: Set of non-terminated env ids.
            - to_eval: Map of policy_id to list of agent PolicyEvalData.
            - outputs: List of metrics and samples to return from the sampler.
    """
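    # Shape sketch of the return values (illustrative, hypothetical policy and
    # agent ids): `to_eval` could look like
    #     {"pol0": [PolicyEvalData(env_id=0, agent_id="agent_0", ...)],
    #      "pol1": [PolicyEvalData(env_id=0, agent_id="agent_1", ...)]}
    # while `outputs` holds any RolloutMetrics and/or SampleBatches completed
    # during this `poll()` round.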
    # Output objects.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[PolicyEvalData]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # types: EnvID, Dict[AgentID, EnvObsType]
    for env_id, all_agents_obs in unfiltered_obs.items():
        is_new_episode: bool = env_id not in active_episodes
        episode: MultiAgentEpisode = active_episodes[env_id]

        if not is_new_episode:
            sample_collector.episode_step(episode)
            episode._add_agent_rewards(rewards[env_id])

        # Check episode termination conditions.
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_agents_done = True
            atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(
                base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data, episode.media))
            # Check whether we have to create a fake-last observation
            # for some agents (the environment is not required to do so if
            # dones[__all__]=True).
            for ag_id in episode.get_agents():
                if not episode.last_done_for(
                        ag_id) and ag_id not in all_agents_obs:
                    # Create a fake (all-0s) observation.
                    obs_sp = worker.policy_map[episode.policy_for(
                        ag_id)].observation_space
                    obs_sp = getattr(obs_sp, "original_space", obs_sp)
                    all_agents_obs[ag_id] = np.zeros_like(obs_sp.sample())
        else:
            hit_horizon = False
            all_agents_done = False
            active_envs.add(env_id)

        # Custom observation function is applied before preprocessing.
        if observation_fn:
            all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
                agent_obs=all_agents_obs,
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode)
            if not isinstance(all_agents_obs, dict):
                raise ValueError(
                    "observe() must return a dict of agent observations")

        # For each agent in the environment.
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in all_agents_obs.items():
            assert agent_id != "__all__"

            last_observation: EnvObsType = episode.last_observation_for(
                agent_id)
            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))

            # A new agent (initial obs) is already done -> Skip entirely.
            if last_observation is None and agent_done:
                continue

            policy_id: PolicyID = episode.policy_for(agent_id)

            prep_obs: EnvObsType = _get_or_raise(worker.preprocessors,
                                                 policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
            filtered_obs: EnvObsType = _get_or_raise(worker.filters,
                                                     policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_done(agent_id, agent_done)
            # Infos from the environment.
            agent_infos = infos[env_id].get(agent_id, {})
            episode._set_last_info(agent_id, agent_infos)

            # Record transition info if applicable.
            if last_observation is None:
                sample_collector.add_init_obs(episode, agent_id, env_id,
                                              policy_id, episode.length - 1,
                                              filtered_obs)
            else:
                # Add actions, rewards, next-obs to collectors.
                values_dict = {
                    "t": episode.length - 1,
                    "env_id": env_id,
                    "agent_index": episode._agent_index(agent_id),
                    # Action (slot 0) taken at timestep t.
                    "actions": episode.last_action_for(agent_id),
                    # Reward received after taking action a at timestep t.
                    "rewards": rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    "dones": (False if (no_done_at_end
                                        or (hit_horizon and soft_horizon)) else
                              agent_done),
                    # Next observation.
                    "new_obs": filtered_obs,
                }
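                # Illustrative example (made-up values): for a single
                # discrete-action agent, one such row could look like
                #     {"t": 3, "env_id": 0, "agent_index": 0, "actions": 1,
                #      "rewards": 1.0, "dones": False, "new_obs": <obs array>}
                # before the extra action fetches below are merged in.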
                # Add extra-action-fetches to collectors.
                pol = worker.policy_map[policy_id]
                for key, value in episode.last_pi_info_for(agent_id).items():
                    if key in pol.view_requirements:
                        values_dict[key] = value
                # Env infos for this agent.
                if "infos" in pol.view_requirements:
                    values_dict["infos"] = agent_infos
                sample_collector.add_action_reward_next_obs(
                    episode.episode_id, agent_id, env_id, policy_id,
                    agent_done, values_dict)

            if not agent_done:
                item = PolicyEvalData(
                    env_id, agent_id, filtered_obs, agent_infos, None
                    if last_observation is None else
                    episode.rnn_state_for(agent_id), None
                    if last_observation is None else
                    episode.last_action_for(agent_id), rewards[env_id].get(
                        agent_id, 0.0))
                to_eval[policy_id].append(item)
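
        # Note (illustrative, based on the call above): each PolicyEvalData
        # row carries (env_id, agent_id, obs, info, rnn_state, prev_action,
        # prev_reward); for a brand-new agent (no previous observation),
        # rnn_state and prev_action fall back to None.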
        # Invoke the `on_episode_step` callback after the step is logged
        # to the episode.
        # Exception: The very first env.poll() call causes the env to get
        # reset (no step taken yet, just a single starting observation
        # logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            callbacks.on_episode_step(
                worker=worker,
                base_env=base_env,
                episode=episode,
                env_index=env_id)

        # Episode is done for all agents (dones[__all__] == True)
        # or we hit the horizon.
        if all_agents_done:
            is_done = dones[env_id]["__all__"]
            check_dones = is_done and not no_done_at_end

            # If we are not allowed to pack the next episode into the same
            # SampleBatch (batch_mode=complete_episodes) -> Build the
            # MultiAgentBatch from a single episode and add it to "outputs".
            # Otherwise, just postprocess and continue collecting across
            # episodes.
            ma_sample_batch = sample_collector.postprocess_episode(
                episode,
                is_done=is_done or (hit_horizon and not soft_horizon),
                check_dones=check_dones,
                build=not multiple_episodes_in_batch)
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # Call each policy's Exploration.on_episode_end method.
            for p in worker.policy_map.values():
                if getattr(p, "exploration", None) is not None:
                    p.exploration.on_episode_end(
                        policy=p,
                        environment=base_env,
                        episode=episode,
                        tf_sess=p.get_session())
            # Call custom on_episode_end callback.
            callbacks.on_episode_end(
                worker=worker,
                base_env=base_env,
                policies=worker.policy_map,
                episode=episode,
                env_index=env_id,
            )
            # Horizon hit and we have a soft horizon (no hard env reset).
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs: Dict[AgentID, EnvObsType] = all_agents_obs
            else:
                del active_episodes[env_id]
                resetted_obs: Dict[AgentID, EnvObsType] = base_env.try_reset(
                    env_id)
            # Reset not supported, drop this env from the ready list.
            if resetted_obs is None:
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            # Create a new episode, if this is not an async reset return.
            # If the reset is async, we will get its result in some future
            # poll.
            elif resetted_obs != ASYNC_RESET_RETURN:
                new_episode: MultiAgentEpisode = active_episodes[env_id]
                if observation_fn:
                    resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
                        agent_obs=resetted_obs,
                        worker=worker,
                        base_env=base_env,
                        policies=worker.policy_map,
                        episode=new_episode)
                # types: AgentID, EnvObsType
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id: PolicyID = new_episode.policy_for(agent_id)
                    prep_obs: EnvObsType = _get_or_raise(
                        worker.preprocessors, policy_id).transform(raw_obs)
                    filtered_obs: EnvObsType = _get_or_raise(
                        worker.filters, policy_id)(prep_obs)
                    new_episode._set_last_observation(agent_id, filtered_obs)

                    # Add initial obs to buffer.
                    sample_collector.add_init_obs(
                        new_episode, agent_id, env_id, policy_id,
                        new_episode.length - 1, filtered_obs)

                    item = PolicyEvalData(
                        env_id, agent_id, filtered_obs,
                        episode.last_info_for(agent_id) or {},
                        episode.rnn_state_for(agent_id), None, 0.0)
                    to_eval[policy_id].append(item)

    # Try to build something.
    if multiple_episodes_in_batch:
        sample_batches = \
            sample_collector.try_build_truncated_episode_multi_agent_batch()
        if sample_batches:
            outputs.extend(sample_batches)

    return active_envs, to_eval, outputs
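

# Note (illustrative, not from the original source): `_process_observations`
# above returns `(active_envs, to_eval, outputs)`. `to_eval` maps each
# PolicyID to the PolicyEvalData rows gathered for it, e.g.
#     {"default_policy": [PolicyEvalData(0, "agent0", <obs>, {}, <rnn_state>,
#                                        <prev_action>, 0.0)]}
# and is consumed as-is by `_do_policy_eval` below.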


def _do_policy_eval(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        policies: Dict[PolicyID, Policy],
        policy_mapping_fn: Callable[[AgentID, "MultiAgentEpisode"], PolicyID],
        sample_collector,
        active_episodes: Dict[str, MultiAgentEpisode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
    """Call compute_actions on collected episode/model data to get next action.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy
            IDs to lists of PolicyEvalData objects (items in these lists will
            be the batch's items for the model forward pass).
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy
            obj.
        policy_mapping_fn (Callable[[AgentID, MultiAgentEpisode], PolicyID]):
            Callable used to re-map an agent to a policy in case the original
            policy ID is no longer present on this worker.
        sample_collector (SampleCollector): The SampleCollector object to use.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from env ID to
            its currently ongoing MultiAgentEpisode object.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results: Dict[PolicyID, TensorStructType] = {}

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        # In case the policy ID has been removed from this worker, we need to
        # re-assign the policy_id and re-lookup the Policy object to use.
        try:
            policy: Policy = _get_or_raise(policies, policy_id)
        except ValueError:
            policy_id = policy_mapping_fn(eval_data[0].agent_id,
                                          active_episodes[eval_data[0].env_id])
            policy: Policy = _get_or_raise(policies, policy_id)

        input_dict = sample_collector.get_inference_input_dict(policy_id)
        eval_results[policy_id] = \
            policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[active_episodes[t.env_id] for t in eval_data])

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results
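

# Note (illustrative, not from the original source): for each PolicyID, the
# results returned by `_do_policy_eval` above are the 3-tuple produced by
# `Policy.compute_actions_from_input_dict()`:
#     (action_batch, rnn_state_out_batches, extra_action_fetches_dict)
# e.g. (np.array([1, 0]), [], {"action_logp": np.array([-0.3, -1.2])}).
# `_process_policy_eval_results` below unpacks exactly these three slots.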


def _process_policy_eval_results(
        *,
        to_eval: Dict[PolicyID, List[PolicyEvalData]],
        eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch,
                                           dict]],
        active_episodes: Dict[str, MultiAgentEpisode],
        active_envs: Set[int],
        off_policy_actions: MultiEnvDict,
        policies: Dict[PolicyID, Policy],
        normalize_actions: bool,
        clip_actions: bool,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
    """Process the output of policy neural network evaluation.

    Records policy evaluation results into the given episode objects and
    returns replies to send back to agents in the env.

    Args:
        to_eval (Dict[PolicyID, List[PolicyEvalData]]): Mapping of policy IDs
            to lists of PolicyEvalData objects.
        eval_results (Dict[PolicyID, List]): Mapping of policy IDs to list of
            actions, rnn-out states, extra-action-fetches dicts.
        active_episodes (Dict[str, MultiAgentEpisode]): Mapping from env ID to
            its currently ongoing MultiAgentEpisode object.
        active_envs (Set[int]): Set of non-terminated env ids.
        off_policy_actions (dict): Doubly keyed dict of env-ids -> agent ids ->
            off-policy-action, returned by a `BaseEnv.poll()` call.
        policies (Dict[PolicyID, Policy]): Mapping from policy ID to Policy.
        normalize_actions (bool): Whether to normalize actions to the action
            space's bounds.
        clip_actions (bool): Whether to clip actions to the action space's
            bounds.

    Returns:
        actions_to_send: Nested dict of env id -> agent id -> actions to be
            sent to Env (np.ndarrays).
    """

    actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = \
        defaultdict(dict)

    # types: int
    for env_id in active_envs:
        actions_to_send[env_id] = {}  # at minimum send empty dict

    # types: PolicyID, List[PolicyEvalData]
    for policy_id, eval_data in to_eval.items():
        actions: TensorStructType = eval_results[policy_id][0]
        actions = convert_to_numpy(actions)

        rnn_out_cols: StateBatch = eval_results[policy_id][1]
        pi_info_cols: dict = eval_results[policy_id][2]

        # In case actions is a list (representing the 0th dim of a batch of
        # primitive actions), try converting it first.
        if isinstance(actions, list):
            actions = np.array(actions)

        # Store RNN state ins/outs and extra-action fetches to episode.
        for f_i, column in enumerate(rnn_out_cols):
            pi_info_cols["state_out_{}".format(f_i)] = column

        policy: Policy = _get_or_raise(policies, policy_id)
        # Split action-component batches into single action rows.
        actions: List[EnvActionType] = unbatch(actions)
        # types: int, EnvActionType
        for i, action in enumerate(actions):
            # Normalize, if necessary.
            if normalize_actions:
                action_to_send = unsquash_action(action,
                                                 policy.action_space_struct)
            # Clip, if necessary.
            elif clip_actions:
                action_to_send = clip_action(action,
                                             policy.action_space_struct)
            else:
                action_to_send = action

            env_id: int = eval_data[i].env_id
            agent_id: AgentID = eval_data[i].agent_id
            episode: MultiAgentEpisode = active_episodes[env_id]
            episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
            episode._set_last_pi_info(
                agent_id, {k: v[i]
                           for k, v in pi_info_cols.items()})
            if env_id in off_policy_actions and \
                    agent_id in off_policy_actions[env_id]:
                episode._set_last_action(agent_id,
                                         off_policy_actions[env_id][agent_id])
            else:
                episode._set_last_action(agent_id, action)

            assert agent_id not in actions_to_send[env_id]
            actions_to_send[env_id][agent_id] = action_to_send

    return actions_to_send
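

# Note (illustrative, made-up values): the nested dict returned by
# `_process_policy_eval_results` above maps env ids to per-agent actions,
# e.g. {0: {"agent0": 1}, 1: {}}. Envs without any computed action this
# iteration still get an (empty) entry, so the caller can always send a
# reply for every active (sub-)env.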


def _fetch_atari_metrics(base_env: BaseEnv) -> Optional[List[RolloutMetrics]]:
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    unwrapped = base_env.get_unwrapped()
    if not unwrapped:
        return None
    atari_out = []
    for u in unwrapped:
        monitor = get_wrapper_by_cls(u, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out


def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
    """Converts a list of per-row RNN states into a list of state columns."""
    num_cols = len(rnn_state_rows[0])
    return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
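
# Example (illustrative): `_to_column_format` transposes per-timestep state
# rows into per-slot columns, e.g.
#     _to_column_format([[h0, c0], [h1, c1]]) -> [[h0, h1], [c0, c1]]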


def _get_or_raise(mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]],
                  policy_id: PolicyID) -> Union[Policy, Preprocessor, Filter]:
    """Returns an object under key `policy_id` in `mapping`.

    Args:
        mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
            mapping dict from policy id (str) to actual object (Policy,
            Preprocessor, etc.).
        policy_id (str): The policy ID to lookup.

    Returns:
        Union[Policy, Preprocessor, Filter]: The found object.

    Raises:
        ValueError: If `policy_id` cannot be found in `mapping`.
    """
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy for agent: PolicyID `{}` not found "
            "in policy map, whose keys are `{}`.".format(
                policy_id, mapping.keys()))
    return mapping[policy_id]
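

# Usage sketch (illustrative): calls such as
# `_get_or_raise(worker.preprocessors, policy_id)` above rely on this helper
# to fail fast with a clear error message whenever a policy ID is missing
# from the respective per-policy mapping.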
|