[rllib] Envs for vectorized execution, async execution, and policy serving (#2170)
## What do these changes do?
**Vectorized envs**: Users can either implement `VectorEnv` or set `num_envs=N` to auto-vectorize gym envs (this vectorizes only the action computation part); a usage sketch follows the benchmark below.
```
# CartPole-v0 on single core with 64x64 MLP:
# vector_width=1:
Actions per second 2720.1284458322966
# vector_width=8:
Actions per second 13773.035334888269
# vector_width=64:
Actions per second 37903.20472563333
```
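As a rough usage sketch (the agent class and exact config surface here are assumptions for illustration; only the `num_envs` setting itself comes from this change):
```
# Hedged sketch: auto-vectorizing a gym env purely via the agent config.
import ray
from ray.rllib.agents.pg import PGAgent  # assumed import path; any agent should work

ray.init()
agent = PGAgent(env="CartPole-v0", config={"num_envs": 8})  # num_envs as described above
print(agent.train())
```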
**Async envs**: The more general form of `VectorEnv` is `AsyncVectorEnv`, which allows agents to execute out of lockstep. We use this as an adapter to support `ServingEnv`. Since we can convert any other form of env to `AsyncVectorEnv`, utils.sampler has been rewritten to run against this interface.
**Policy serving**: This provides an env that is not stepped. Instead, the env executes in its own thread, querying the policy for actions via `self.get_action(obs)` and reporting results via `self.log_returns(rewards)`. We also support logging of off-policy actions via `self.log_action(obs, action)`. This is a more convenient API for some use cases, and it also provides parallelizable support for policy serving (for example, if you start an HTTP server in the env) and ingestion of offline logs (if the env reads from serving logs).
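For concreteness, a minimal serving-style env might look like the following sketch. Only the `get_action`, `log_returns`, and `log_action` calls come from the API described above; the `ServingEnv` base class usage, the `run()` entry point, and the lack of explicit episode bookkeeping are assumptions and may differ from the real class:
```
# Hedged sketch of a serving-style env: it runs in its own thread and is
# never step()'ed by RLlib.
import gym

class CartPoleServing(ServingEnv):  # assumed base class from this change
    def run(self):
        env = gym.make("CartPole-v0")
        obs = env.reset()
        while True:
            action = self.get_action(obs)           # query the policy
            obs, reward, done, _ = env.step(action)
            self.log_returns(reward)                # report rewards back
            if done:
                obs = env.reset()
```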
Any of these types of envs can be passed to RLlib agents. RLlib handles conversions internally in CommonPolicyEvaluator, for example:
```
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.ServingEnv => rllib.AsyncVectorEnv
```
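The conversion entry point used by the sampler code below is `BaseEnv.to_base_env()` (the code uses `BaseEnv` as the name for this unified async interface). A minimal sketch, with `gym.make` purely for illustration:
```
import gym
from ray.rllib.env.base_env import BaseEnv

base_env = BaseEnv.to_base_env(gym.make("CartPole-v0"))
# poll() returns per-env, per-agent dicts of observations, rewards, dones,
# infos, and any off-policy actions logged by the env.
obs, rewards, dones, infos, off_policy_actions = base_env.poll()
```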

from collections import defaultdict, namedtuple
import logging
import numpy as np
import six.moves.queue as queue
import threading
import time

from ray.util.debug import log_once
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.rollout_metrics import RolloutMetrics
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder
from ray.rllib.policy.policy import clip_action
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.tuple_actions import TupleActions
from ray.rllib.utils.tf_run_builder import TFRunBuilder

logger = logging.getLogger(__name__)


PolicyEvalData = namedtuple("PolicyEvalData", [
    "env_id", "agent_id", "obs", "info", "rnn_state", "prev_action",
    "prev_reward"
])


class PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self):
        self.iters = 0
        self.env_wait_time = 0.0
        self.processing_time = 0.0
        self.inference_time = 0.0

    def get(self):
        return {
            "mean_env_wait_ms": self.env_wait_time * 1000 / self.iters,
            "mean_processing_ms": self.processing_time * 1000 / self.iters,
            "mean_inference_ms": self.inference_time * 1000 / self.iters
        }


class SamplerInput(InputReader):
    """Reads input experiences from an existing sampler."""

    @override(InputReader)
    def next(self):
        batches = [self.get_data()]
        batches.extend(self.get_extra_batches())
        if len(batches) > 1:
            return batches[0].concat_samples(batches)
        else:
            return batches[0]


class SyncSampler(SamplerInput):
    def __init__(self,
                 env,
                 policies,
                 policy_mapping_fn,
                 preprocessors,
                 obs_filters,
                 clip_rewards,
                 unroll_length,
                 callbacks,
                 horizon=None,
                 pack=False,
                 tf_sess=None,
                 clip_actions=True,
                 soft_horizon=False,
                 no_done_at_end=False):
        self.base_env = BaseEnv.to_base_env(env)
        self.unroll_length = unroll_length
        self.horizon = horizon
        self.policies = policies
        self.policy_mapping_fn = policy_mapping_fn
        self.preprocessors = preprocessors
        self.obs_filters = obs_filters
        self.extra_batches = queue.Queue()
        self.perf_stats = PerfStats()
        self.rollout_provider = _env_runner(
            self.base_env, self.extra_batches.put, self.policies,
            self.policy_mapping_fn, self.unroll_length, self.horizon,
            self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
            pack, callbacks, tf_sess, self.perf_stats, soft_horizon,
            no_done_at_end)
        self.metrics_queue = queue.Queue()

    def get_data(self):
        while True:
            item = next(self.rollout_provider)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                return item

    def get_metrics(self):
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    def get_extra_batches(self):
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


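# AsyncSampler runs the same _env_runner generator as SyncSampler above, but
# in a background daemon thread: rollouts are pushed onto a bounded queue and
# get_data() pops them, so env interaction can overlap with the caller's work.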
class AsyncSampler(threading.Thread, SamplerInput):
    def __init__(self,
                 env,
                 policies,
                 policy_mapping_fn,
                 preprocessors,
                 obs_filters,
                 clip_rewards,
                 unroll_length,
                 callbacks,
                 horizon=None,
                 pack=False,
                 tf_sess=None,
                 clip_actions=True,
                 blackhole_outputs=False,
                 soft_horizon=False,
                 no_done_at_end=False):
        for _, f in obs_filters.items():
            assert getattr(f, "is_concurrent", False), \
                "Observation Filter must support concurrent updates."
        self.base_env = BaseEnv.to_base_env(env)
        threading.Thread.__init__(self)
        self.queue = queue.Queue(5)
        self.extra_batches = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.unroll_length = unroll_length
        self.horizon = horizon
        self.policies = policies
        self.policy_mapping_fn = policy_mapping_fn
        self.preprocessors = preprocessors
        self.obs_filters = obs_filters
        self.clip_rewards = clip_rewards
        self.daemon = True
        self.pack = pack
        self.tf_sess = tf_sess
        self.callbacks = callbacks
        self.clip_actions = clip_actions
        self.blackhole_outputs = blackhole_outputs
        self.soft_horizon = soft_horizon
        self.no_done_at_end = no_done_at_end
        self.perf_stats = PerfStats()
        self.shutdown = False

    def run(self):
        try:
            self._run()
        except BaseException as e:
            self.queue.put(e)
            raise e

    def _run(self):
        if self.blackhole_outputs:
            queue_putter = (lambda x: None)
            extra_batches_putter = (lambda x: None)
        else:
            queue_putter = self.queue.put
            extra_batches_putter = (
                lambda x: self.extra_batches.put(x, timeout=600.0))
        rollout_provider = _env_runner(
            self.base_env, extra_batches_putter, self.policies,
            self.policy_mapping_fn, self.unroll_length, self.horizon,
            self.preprocessors, self.obs_filters, self.clip_rewards,
            self.clip_actions, self.pack, self.callbacks, self.tf_sess,
            self.perf_stats, self.soft_horizon, self.no_done_at_end)
        while not self.shutdown:
            # The timeout variable exists because apparently, if one worker
            # dies, the other workers won't die with it, unless the timeout is
            # set to some large number. This is an empirical observation.
            item = next(rollout_provider)
            if isinstance(item, RolloutMetrics):
                self.metrics_queue.put(item)
            else:
                queue_putter(item)

    def get_data(self):
        if not self.is_alive():
            raise RuntimeError("Sampling thread has died")
        rollout = self.queue.get(timeout=600.0)

        # Propagate errors
        if isinstance(rollout, BaseException):
            raise rollout

        return rollout

    def get_metrics(self):
        completed = []
        while True:
            try:
                completed.append(self.metrics_queue.get_nowait()._replace(
                    perf_stats=self.perf_stats.get()))
            except queue.Empty:
                break
        return completed

    def get_extra_batches(self):
        extra = []
        while True:
            try:
                extra.append(self.extra_batches.get_nowait())
            except queue.Empty:
                break
        return extra


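# The generator below is shared by SyncSampler and AsyncSampler. Each
# iteration polls the BaseEnv for ready observations, turns them into
# per-policy evaluation batches, computes actions in batch per policy, and
# sends the actions back to the env, accumulating timings in PerfStats.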
def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
                unroll_length, horizon, preprocessors, obs_filters,
                clip_rewards, clip_actions, pack, callbacks, tf_sess,
                perf_stats, soft_horizon, no_done_at_end):
    """This implements the common experience collection logic.

    Args:
        base_env (BaseEnv): env implementing BaseEnv.
        extra_batch_callback (fn): function to send extra batch data to.
        policies (dict): Map of policy ids to Policy instances.
        policy_mapping_fn (func): Function that maps agent ids to policy ids.
            This is called when an agent first enters the environment. The
            agent is then "bound" to the returned policy for the episode.
        unroll_length (int): Number of episode steps before `SampleBatch` is
            yielded. Set to infinity to yield complete episodes.
        horizon (int): Horizon of the episode.
        preprocessors (dict): Map of policy id to preprocessor for the
            observations prior to filtering.
        obs_filters (dict): Map of policy id to filter used to process
            observations for the policy.
        clip_rewards (bool): Whether to clip rewards before postprocessing.
        pack (bool): Whether to pack multiple episodes into each batch. This
            guarantees batches will be exactly `unroll_length` in size.
        clip_actions (bool): Whether to clip actions to the space range.
        callbacks (dict): User callbacks to run on episode events.
        tf_sess (Session|None): Optional tensorflow session to use for batching
            TF policy evaluations.
        perf_stats (PerfStats): Record perf stats into this object.
        soft_horizon (bool): Calculate rewards but don't reset the
            environment when the horizon is hit.
        no_done_at_end (bool): Ignore the done=True at the end of the episode
            and instead record done=False.

    Yields:
        rollout (SampleBatch): Object containing state, action, reward,
            terminal condition, and other fields as dictated by `policy`.
    """

    # Try to get Env's max_episode_steps prop. If it doesn't exist, catch
    # error and continue.
    max_episode_steps = None
    try:
        max_episode_steps = base_env.get_unwrapped()[0].spec.max_episode_steps
    except Exception:
        pass

    # Trainer has a given `horizon` setting.
    if horizon:
        # `horizon` is larger than env's limit -> Error and explain how
        # to increase Env's own episode limit.
        if max_episode_steps and horizon > max_episode_steps:
            raise ValueError(
                "Your `horizon` setting ({}) is larger than the Env's own "
                "timestep limit ({})! Try to increase the Env's limit via "
                "setting its `spec.max_episode_steps` property.".format(
                    horizon, max_episode_steps))
    # Otherwise, set Trainer's horizon to env's max-steps.
    elif max_episode_steps:
        horizon = max_episode_steps
        logger.debug(
            "No episode horizon specified, setting it to Env's limit ({}).".
            format(max_episode_steps))
    else:
        horizon = float("inf")
        logger.debug("No episode horizon specified, assuming inf.")

    # Pool of batch builders, which can be shared across episodes to pack
    # trajectory data.
    batch_builder_pool = []

    def get_batch_builder():
        if batch_builder_pool:
            return batch_builder_pool.pop()
        else:
            return MultiAgentSampleBatchBuilder(
                policies, clip_rewards, callbacks.get("on_postprocess_traj"))

    def new_episode():
        episode = MultiAgentEpisode(policies, policy_mapping_fn,
                                    get_batch_builder, extra_batch_callback)
        if callbacks.get("on_episode_start"):
            callbacks["on_episode_start"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })
        return episode

    active_episodes = defaultdict(new_episode)

    while True:
        perf_stats.iters += 1
        t0 = time.time()
        # Get observations from all ready agents
        unfiltered_obs, rewards, dones, infos, off_policy_actions = \
            base_env.poll()
        perf_stats.env_wait_time += time.time() - t0

        if log_once("env_returns"):
            logger.info("Raw obs from env: {}".format(
                summarize(unfiltered_obs)))
            logger.info("Info return from env: {}".format(summarize(infos)))

        # Process observations and prepare for policy evaluation
        t1 = time.time()
        active_envs, to_eval, outputs = _process_observations(
            base_env, policies, batch_builder_pool, active_episodes,
            unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
            preprocessors, obs_filters, unroll_length, pack, callbacks,
            soft_horizon, no_done_at_end)
        perf_stats.processing_time += time.time() - t1
        for o in outputs:
            yield o

        # Do batched policy eval
        t2 = time.time()
        eval_results = _do_policy_eval(tf_sess, to_eval, policies,
                                       active_episodes)
        perf_stats.inference_time += time.time() - t2

        # Process results and update episode state
        t3 = time.time()
        actions_to_send = _process_policy_eval_results(
            to_eval, eval_results, active_episodes, active_envs,
            off_policy_actions, policies, clip_actions)
        perf_stats.processing_time += time.time() - t3

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        base_env.send_actions(actions_to_send)
        perf_stats.env_wait_time += time.time() - t4


def _process_observations(base_env, policies, batch_builder_pool,
                          active_episodes, unfiltered_obs, rewards, dones,
                          infos, off_policy_actions, horizon, preprocessors,
                          obs_filters, unroll_length, pack, callbacks,
                          soft_horizon, no_done_at_end):
    """Record new data from the environment and prepare for policy evaluation.

    Returns:
        active_envs: set of non-terminated env ids
        to_eval: map of policy_id to list of agent PolicyEvalData
        outputs: list of metrics and samples to return from the sampler
    """

    active_envs = set()
    to_eval = defaultdict(list)
    outputs = []
    large_batch_threshold = max(1000, unroll_length * 10) if \
        unroll_length != float("inf") else 5000

    # For each environment
    for env_id, agent_obs in unfiltered_obs.items():
        new_episode = env_id not in active_episodes
        episode = active_episodes[env_id]
        if not new_episode:
            episode.length += 1
            episode.batch_builder.count += 1
            episode._add_agent_rewards(rewards[env_id])

        if (episode.batch_builder.total() > large_batch_threshold
                and log_once("large_batch_warning")):
            logger.warning(
                "More than {} observations for {} env steps ".format(
                    episode.batch_builder.total(),
                    episode.batch_builder.count) + "are buffered in "
                "the sampler. If this is more than you expected, check that "
                "you set a horizon on your environment correctly and that it "
                "terminates at some point. "
                "Note: In multi-agent environments, `sample_batch_size` sets "
                "the batch size based on environment steps, not the steps of "
                "individual agents, which can result in unexpectedly large "
                "batches. Also, you may be in evaluation waiting for your Env "
                "to terminate (batch_mode=`complete_episodes`). Make sure it "
                "does at some point.")

        # Check episode termination conditions
        if dones[env_id]["__all__"] or episode.length >= horizon:
            hit_horizon = (episode.length >= horizon
                           and not dones[env_id]["__all__"])
            all_done = True
            atari_metrics = _fetch_atari_metrics(base_env)
            if atari_metrics is not None:
                for m in atari_metrics:
                    outputs.append(
                        m._replace(custom_metrics=episode.custom_metrics))
            else:
                outputs.append(
                    RolloutMetrics(episode.length, episode.total_reward,
                                   dict(episode.agent_rewards),
                                   episode.custom_metrics, {},
                                   episode.hist_data))
        else:
            hit_horizon = False
            all_done = False
            active_envs.add(env_id)

        # For each agent in the environment.
        for agent_id, raw_obs in agent_obs.items():
            policy_id = episode.policy_for(agent_id)
            prep_obs = _get_or_raise(preprocessors,
                                     policy_id).transform(raw_obs)
            if log_once("prep_obs"):
                logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))

            filtered_obs = _get_or_raise(obs_filters, policy_id)(prep_obs)
            if log_once("filtered_obs"):
                logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

            agent_done = bool(all_done or dones[env_id].get(agent_id))
            if not agent_done:
                to_eval[policy_id].append(
                    PolicyEvalData(env_id, agent_id, filtered_obs,
                                   infos[env_id].get(agent_id, {}),
                                   episode.rnn_state_for(agent_id),
                                   episode.last_action_for(agent_id),
                                   rewards[env_id][agent_id] or 0.0))

            last_observation = episode.last_observation_for(agent_id)
            episode._set_last_observation(agent_id, filtered_obs)
            episode._set_last_raw_obs(agent_id, raw_obs)
            episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))

            # Record transition info if applicable
            if (last_observation is not None and infos[env_id].get(
                    agent_id, {}).get("training_enabled", True)):
                episode.batch_builder.add_values(
                    agent_id,
                    policy_id,
                    t=episode.length - 1,
                    eps_id=episode.episode_id,
                    agent_index=episode._agent_index(agent_id),
                    obs=last_observation,
                    actions=episode.last_action_for(agent_id),
                    rewards=rewards[env_id][agent_id],
                    prev_actions=episode.prev_action_for(agent_id),
                    prev_rewards=episode.prev_reward_for(agent_id),
                    dones=(False if (no_done_at_end
                                     or (hit_horizon and soft_horizon)) else
                           agent_done),
                    infos=infos[env_id].get(agent_id, {}),
                    new_obs=filtered_obs,
                    **episode.last_pi_info_for(agent_id))

        # Invoke the step callback after the step is logged to the episode
        if callbacks.get("on_episode_step"):
            callbacks["on_episode_step"]({"env": base_env, "episode": episode})

        # Cut the batch if we're not packing multiple episodes into one,
        # or if we've exceeded the requested batch size.
        if episode.batch_builder.has_pending_agent_data():
            if dones[env_id]["__all__"] and not no_done_at_end:
                episode.batch_builder.check_missing_dones()
            if (all_done and not pack) or \
                    episode.batch_builder.count >= unroll_length:
                outputs.append(episode.batch_builder.build_and_reset(episode))
            elif all_done:
                # Make sure postprocessor stays within one episode
                episode.batch_builder.postprocess_batch_so_far(episode)

        if all_done:
            # Handle episode termination
            batch_builder_pool.append(episode.batch_builder)
            if callbacks.get("on_episode_end"):
                callbacks["on_episode_end"]({
                    "env": base_env,
                    "policy": policies,
                    "episode": episode
                })
            if hit_horizon and soft_horizon:
                episode.soft_reset()
                resetted_obs = agent_obs
            else:
                del active_episodes[env_id]
                resetted_obs = base_env.try_reset(env_id)
            if resetted_obs is None:
                # Reset not supported, drop this env from the ready list
                if horizon != float("inf"):
                    raise ValueError(
                        "Setting episode horizon requires reset() support "
                        "from the environment.")
            elif resetted_obs != ASYNC_RESET_RETURN:
                # Creates a new episode if this is not async return
                # If reset is async, we will get its result in some future poll
                episode = active_episodes[env_id]
                for agent_id, raw_obs in resetted_obs.items():
                    policy_id = episode.policy_for(agent_id)
                    policy = _get_or_raise(policies, policy_id)
                    prep_obs = _get_or_raise(preprocessors,
                                             policy_id).transform(raw_obs)
                    filtered_obs = _get_or_raise(obs_filters,
                                                 policy_id)(prep_obs)
                    episode._set_last_observation(agent_id, filtered_obs)
                    to_eval[policy_id].append(
                        PolicyEvalData(
                            env_id, agent_id, filtered_obs,
                            episode.last_info_for(agent_id) or {},
                            episode.rnn_state_for(agent_id),
                            np.zeros_like(
                                _flatten_action(policy.action_space.sample())),
                            0.0))

    return active_envs, to_eval, outputs


def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
    """Call compute actions on observation batches to get next actions.

    Returns:
        eval_results: dict of policy to compute_action() outputs.
    """

    eval_results = {}

    if tf_sess:
        builder = TFRunBuilder(tf_sess, "policy_eval")
        pending_fetches = {}
    else:
        builder = None

    if log_once("compute_actions_input"):
        logger.info("Inputs to compute_actions():\n\n{}\n".format(
            summarize(to_eval)))

    for policy_id, eval_data in to_eval.items():
        rnn_in = [t.rnn_state for t in eval_data]
        policy = _get_or_raise(policies, policy_id)
        if builder and (policy.compute_actions.__code__ is
                        TFPolicy.compute_actions.__code__):
            rnn_in_cols = _to_column_format(rnn_in)
            # TODO(ekl): how can we make info batch available to TF code?
            # TODO(sven): Return dict from _build_compute_actions.
            # it's becoming more and more unclear otherwise, what's where in
            # the return tuple.
            pending_fetches[policy_id] = policy._build_compute_actions(
                builder,
                obs_batch=[t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                timestep=policy.global_timestep)
        else:
            # TODO(sven): Does this work for LSTM torch?
            rnn_in_cols = [
                np.stack([row[i] for row in rnn_in])
                for i in range(len(rnn_in[0]))
            ]
            eval_results[policy_id] = policy.compute_actions(
                [t.obs for t in eval_data],
                state_batches=rnn_in_cols,
                prev_action_batch=[t.prev_action for t in eval_data],
                prev_reward_batch=[t.prev_reward for t in eval_data],
                info_batch=[t.info for t in eval_data],
                episodes=[active_episodes[t.env_id] for t in eval_data],
                timestep=policy.global_timestep)
    if builder:
        for k, v in pending_fetches.items():
            eval_results[k] = builder.get(v)

    if log_once("compute_actions_result"):
        logger.info("Outputs of compute_actions():\n\n{}\n".format(
            summarize(eval_results)))

    return eval_results


def _process_policy_eval_results(to_eval, eval_results, active_episodes,
                                 active_envs, off_policy_actions, policies,
                                 clip_actions):
    """Process the output of policy neural network evaluation.

    Records policy evaluation results into the given episode objects and
    returns replies to send back to agents in the env.

    Returns:
        actions_to_send: nested dict of env id -> agent id -> agent replies.
    """

    actions_to_send = defaultdict(dict)
    for env_id in active_envs:
        actions_to_send[env_id] = {}  # at minimum send empty dict

    for policy_id, eval_data in to_eval.items():
        rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
        actions, rnn_out_cols, pi_info_cols = eval_results[policy_id][:3]
        if len(rnn_in_cols) != len(rnn_out_cols):
            raise ValueError("Length of RNN in did not match RNN out, got: "
                             "{} vs {}".format(rnn_in_cols, rnn_out_cols))
        # Add RNN state info
        for f_i, column in enumerate(rnn_in_cols):
            pi_info_cols["state_in_{}".format(f_i)] = column
        for f_i, column in enumerate(rnn_out_cols):
            pi_info_cols["state_out_{}".format(f_i)] = column
        # Save output rows
        actions = _unbatch_tuple_actions(actions)
        policy = _get_or_raise(policies, policy_id)
        for i, action in enumerate(actions):
            env_id = eval_data[i].env_id
            agent_id = eval_data[i].agent_id
            if clip_actions:
                actions_to_send[env_id][agent_id] = clip_action(
                    action, policy.action_space)
            else:
                actions_to_send[env_id][agent_id] = action
            episode = active_episodes[env_id]
            episode._set_rnn_state(agent_id, [c[i] for c in rnn_out_cols])
            episode._set_last_pi_info(
                agent_id, {k: v[i]
                           for k, v in pi_info_cols.items()})
            if env_id in off_policy_actions and \
                    agent_id in off_policy_actions[env_id]:
                episode._set_last_action(agent_id,
                                         off_policy_actions[env_id][agent_id])
            else:
                episode._set_last_action(agent_id, action)

    return actions_to_send


def _fetch_atari_metrics(base_env):
    """Atari games have multiple logical episodes, one per life.

    However, for metrics reporting we count full episodes, all lives included.
    """
    unwrapped = base_env.get_unwrapped()
    if not unwrapped:
        return None
    atari_out = []
    for u in unwrapped:
        monitor = get_wrapper_by_cls(u, MonitorEnv)
        if not monitor:
            return None
        for eps_rew, eps_len in monitor.next_episode_results():
            atari_out.append(RolloutMetrics(eps_len, eps_rew))
    return atari_out


def _unbatch_tuple_actions(action_batch):
    # convert list of batches -> batch of lists
    if isinstance(action_batch, TupleActions):
        out = []
        for j in range(len(action_batch.batches[0])):
            out.append([
                action_batch.batches[i][j]
                for i in range(len(action_batch.batches))
            ])
        return out
    return action_batch


def _to_column_format(rnn_state_rows):
    num_cols = len(rnn_state_rows[0])
    return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]


def _get_or_raise(mapping, policy_id):
    """Returns a Policy object under key `policy_id` in `mapping`.

    Throws an error if `policy_id` cannot be found.

    Returns:
        Policy: The found Policy object.
    """
    if policy_id not in mapping:
        raise ValueError(
            "Could not find policy for agent: agent policy id `{}` not "
            "in policy map keys {}.".format(policy_id, mapping.keys()))
    return mapping[policy_id]