from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gym
import numpy as np
import random
import time
import unittest

from collections import Counter

import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.env.vector_env import VectorEnv
from ray.tune.registry import register_env
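

# Mock policy that returns random binary actions and computes advantages in
# postprocess_trajectory(); used as the policy for most tests below.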
class MockPolicy(Policy):
    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        **kwargs):
        return [random.choice([0, 1])] * len(obs_batch), [], {}

    def postprocess_trajectory(self,
                               batch,
                               other_agent_batches=None,
                               episode=None):
        assert episode is not None
        return compute_advantages(batch, 100.0, 0.9, use_gae=False)
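

# Policy whose compute_actions() always raises; useful for exercising worker
# error handling.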
class BadPolicy(Policy):
    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        **kwargs):
        raise Exception("intentional error")

    def postprocess_trajectory(self,
                               batch,
                               other_agent_batches=None,
                               episode=None):
        assert episode is not None
        return compute_advantages(batch, 100.0, 0.9, use_gae=False)
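

# Env that fails on both reset() and step(); testNoStepOnInit uses it to check
# that constructing a trainer does not interact with the env.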
class FailOnStepEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        raise ValueError("kaboom")

    def step(self, action):
        raise ValueError("kaboom")
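

# Trivial env: observation is always 0, reward is 1 per step, and the episode
# terminates after `episode_length` steps. The env config passed in by RLlib
# is stored for inspection by the vectorization tests.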
class MockEnv(gym.Env):
    def __init__(self, episode_length, config=None):
        self.episode_length = episode_length
        self.config = config
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return self.i

    def step(self, action):
        self.i += 1
        return 0, 1, self.i >= self.episode_length, {}
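

# Variant of MockEnv that returns the step count as observation and a reward
# of 100 per step; used by the reward clipping test.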
class MockEnv2(gym.Env):
    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(100)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return self.i

    def step(self, action):
        self.i += 1
        return self.i, 100, self.i >= self.episode_length, {}
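

# Hand-rolled VectorEnv wrapping several MockEnv instances; exercises the
# explicit VectorEnv API (vector_reset / reset_at / vector_step).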
class MockVectorEnv(VectorEnv):
    def __init__(self, episode_length, num_envs):
        self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)
        self.num_envs = num_envs

    def vector_reset(self):
        return [e.reset() for e in self.envs]

    def reset_at(self, index):
        return self.envs[index].reset()

    def vector_step(self, actions):
        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
        for i in range(len(self.envs)):
            obs, rew, done, info = self.envs[i].step(actions[i])
            obs_batch.append(obs)
            rew_batch.append(rew)
            done_batch.append(done)
            info_batch.append(info)
        return obs_batch, rew_batch, done_batch, info_batch

    def get_unwrapped(self):
        return self.envs
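

# Unit tests covering RolloutWorker sampling, batch ids, metrics collection,
# callbacks, horizons, and env vectorization.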
class TestRolloutWorker(unittest.TestCase):
    def testBasic(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
        batch = ev.sample()
        for key in [
                "obs", "actions", "rewards", "dones", "advantages",
                "prev_rewards", "prev_actions"
        ]:
            self.assertIn(key, batch)
            self.assertGreater(np.abs(np.mean(batch[key])), 0)

        def to_prev(vec):
            out = np.zeros_like(vec)
            for i, v in enumerate(vec):
                if i + 1 < len(out) and not batch["dones"][i]:
                    out[i + 1] = v
            return out.tolist()

        self.assertEqual(batch["prev_rewards"].tolist(),
                         to_prev(batch["rewards"]))
        self.assertEqual(batch["prev_actions"].tolist(),
                         to_prev(batch["actions"]))
        self.assertGreater(batch["advantages"][0], 1)

    def testBatchIds(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
        batch1 = ev.sample()
        batch2 = ev.sample()
        self.assertEqual(len(set(batch1["unroll_id"])), 1)
        self.assertEqual(len(set(batch2["unroll_id"])), 1)
        self.assertEqual(
            len(set(SampleBatch.concat(batch1, batch2)["unroll_id"])), 2)

    def testGlobalVarsUpdate(self):
        agent = A2CTrainer(
            env="CartPole-v0",
            config={
                "lr_schedule": [[0, 0.1], [400, 0.000001]],
            })
        result = agent.train()
        self.assertGreater(result["info"]["learner"]["cur_lr"], 0.01)
        result2 = agent.train()
        self.assertLess(result2["info"]["learner"]["cur_lr"], 0.0001)

    def testNoStepOnInit(self):
        register_env("fail", lambda _: FailOnStepEnv())
        pg = PGTrainer(env="fail", config={"num_workers": 1})
        self.assertRaises(Exception, lambda: pg.train())

    def testCallbacks(self):
        counts = Counter()
        pg = PGTrainer(
            env="CartPole-v0", config={
                "num_workers": 0,
                "sample_batch_size": 50,
                "train_batch_size": 50,
                "callbacks": {
                    "on_episode_start": lambda x: counts.update({"start": 1}),
                    "on_episode_step": lambda x: counts.update({"step": 1}),
                    "on_episode_end": lambda x: counts.update({"end": 1}),
                    "on_sample_end": lambda x: counts.update({"sample": 1}),
                },
            })
        pg.train()
        pg.train()
        pg.train()
        pg.train()
        self.assertEqual(counts["sample"], 4)
        self.assertGreater(counts["start"], 0)
        self.assertGreater(counts["end"], 0)
        self.assertGreater(counts["step"], 200)
        self.assertLess(counts["step"], 400)

    def testQueryEvaluators(self):
        register_env("test", lambda _: gym.make("CartPole-v0"))
        pg = PGTrainer(
            env="test",
            config={
                "num_workers": 2,
                "sample_batch_size": 5,
                "num_envs_per_worker": 2,
            })
        results = pg.workers.foreach_worker(lambda ev: ev.sample_batch_size)
        results2 = pg.workers.foreach_worker_with_index(
            lambda ev, i: (i, ev.sample_batch_size))
        results3 = pg.workers.foreach_worker(
            lambda ev: ev.foreach_env(lambda env: 1))
        self.assertEqual(results, [10, 10, 10])
        self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
        self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
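
    # With clip_rewards=True the sampled per-step rewards from MockEnv2 are
    # clipped to 1, while episode_reward_mean is still reported from the raw
    # returns (10 steps * 100 = 1000) in both cases.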
    def testRewardClipping(self):
        # clipping on
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy=MockPolicy,
            clip_rewards=True,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev.sample()["rewards"]), 1)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episode_reward_mean"], 1000)

        # clipping off
        ev2 = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy=MockPolicy,
            clip_rewards=False,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev2.sample()["rewards"]), 100)
        result2 = collect_metrics(ev2, [])
        self.assertEqual(result2["episode_reward_mean"], 1000)
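
    # An episode_horizon of 4 splits the 10-step sample into three logical
    # episodes; with soft_horizon=False each split emits a done, whereas the
    # soft-horizon variant below records only one "hard" done.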
    def testHardHorizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy=MockPolicy,
            batch_mode="complete_episodes",
            batch_steps=10,
            episode_horizon=4,
            soft_horizon=False)
        samples = ev.sample()
        # three logical episodes
        self.assertEqual(len(set(samples["eps_id"])), 3)
        # 3 done values
        self.assertEqual(sum(samples["dones"]), 3)

    def testSoftHorizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy=MockPolicy,
            batch_mode="complete_episodes",
            batch_steps=10,
            episode_horizon=4,
            soft_horizon=True)
        samples = ev.sample()
        # three logical episodes
        self.assertEqual(len(set(samples["eps_id"])), 3)
        # only 1 hard done value
        self.assertEqual(sum(samples["dones"]), 1)

    def testMetrics(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy=MockPolicy,
            batch_mode="complete_episodes")
        remote_ev = RolloutWorker.as_remote().remote(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy=MockPolicy,
            batch_mode="complete_episodes")
        ev.sample()
        ray.get(remote_ev.sample.remote())
        result = collect_metrics(ev, [remote_ev])
        self.assertEqual(result["episodes_this_iter"], 20)
        self.assertEqual(result["episode_reward_mean"], 10)

    def testAsync(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            sample_async=True,
            policy=MockPolicy)
        batch = ev.sample()
        for key in ["obs", "actions", "rewards", "dones", "advantages"]:
            self.assertIn(key, batch)
        self.assertGreater(batch["advantages"][0], 1)
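
    # num_envs=8 auto-vectorizes the single gym env, so with batch_steps=2
    # each sample() call returns 8 * 2 = 16 transitions.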
    def testAutoVectorization(self):
        ev = RolloutWorker(
            env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
            policy=MockPolicy,
            batch_mode="truncate_episodes",
            batch_steps=2,
            num_envs=8)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        indices = []
        for env in ev.async_env.vector_env.envs:
            self.assertEqual(env.unwrapped.config.worker_index, 0)
            indices.append(env.unwrapped.config.vector_index)
        self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])

    def testBatchesLargerWhenVectorized(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=8),
            policy=MockPolicy,
            batch_mode="truncate_episodes",
            batch_steps=4,
            num_envs=4)
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        batch = ev.sample()
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 4)
def testVectorEnvSupport(self):
|
2019-06-03 06:49:24 +08:00
|
|
|
ev = RolloutWorker(
|
2018-07-19 15:30:36 -07:00
|
|
|
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
|
2019-05-20 16:46:05 -07:00
|
|
|
policy=MockPolicy,
|
            batch_mode="truncate_episodes",
            batch_steps=10)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)

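    # With batch_mode="truncate_episodes", sample() should return exactly
    # batch_steps timesteps even if that cuts across an episode boundary.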
    def testTruncateEpisodes(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy=MockPolicy,
            batch_steps=15,
            batch_mode="truncate_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 15)

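    # With batch_mode="complete_episodes", sample() keeps stepping past
    # batch_steps (5) until the current episode finishes, so a MockEnv(10)
    # rollout yields a 10-step batch.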
    def testCompleteEpisodes(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy=MockPolicy,
            batch_steps=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 10)

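    # When batch_steps exceeds one episode, complete_episodes packs whole
    # episodes into the batch: two full 10-step episodes here, as the
    # repeated 0-9 timestep column confirms.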
    def testCompleteEpisodesPacking(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy=MockPolicy,
            batch_steps=15,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 20)
        self.assertEqual(
            batch["t"].tolist(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

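    # Async sampling with a ConcurrentMeanStdFilter: after one sample, the
    # flushed filter for the default policy should have nonzero running
    # stats and a nonzero buffer.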
    def testFilterSync(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)

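    # get_filters(flush_after=False) leaves filter state in place, so a
    # later snapshot has seen at least as many observations as an earlier
    # one.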
    def testGetFilters(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        self.sample_and_flush(ev)
        filters = ev.get_filters(flush_after=False)
        time.sleep(2)
        filters2 = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        obs_f2 = filters2[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
        self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)

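    # sync_filters() pushes new running stats to the worker (rs.n bumped
    # to 100 here) without touching the local buffer.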
    def testSyncFilter(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        obs_f = self.sample_and_flush(ev)

        # Current State
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertLessEqual(obs_f.buffer.n, 20)

        new_obsf = obs_f.copy()
        new_obsf.rs._n = 100
        ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f.rs.n, 100)
        self.assertLessEqual(obs_f.buffer.n, 20)

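    # Helper: sample asynchronously, fetch and flush the default policy's
    # observation filter, and check that it has processed data.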
    def sample_and_flush(self, ev):
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        return obs_f


if __name__ == "__main__":
    ray.init(num_cpus=5)
    unittest.main(verbosity=2)