[wingman -> rllib] Remote and entangled environments (#3968)

* added all our environment changes

* fixed merge request comments and remote env

* fixed remote check

* moved remote_worker_envs to correct config section

* lint

* auto wrap impl

* fix

* fixed the tests
Author: bjg2
Date: 2019-02-13 19:08:26 +01:00
Committed by: Eric Liang
Parent: b3f72e8a75
Commit: 0e37ac6d1d
6 changed files with 170 additions and 23 deletions


@@ -129,6 +129,10 @@ COMMON_CONFIG = {
     "compress_observations": False,
     # Drop metric batches from unresponsive workers after this many seconds
     "collect_metrics_timeout": 180,
+    # If using num_envs_per_worker > 1, whether to create those new envs in
+    # remote processes instead of in the same worker. This adds overheads, but
+    # can make sense if your envs are very CPU intensive (e.g., for StarCraft).
+    "remote_worker_envs": False,
 
     # === Offline Datasets ===
     # __sphinx_doc_input_begin__
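
For context, the sketch below shows one way the new flag could be enabled from user code. Only the `remote_worker_envs` and `num_envs_per_worker` keys come from this diff; the `PPOAgent` class, the CartPole env, and the other values are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch (assumes ray, rllib, and gym are installed).
import ray
from ray.rllib.agents.ppo import PPOAgent

ray.init()
agent = PPOAgent(
    env="CartPole-v1",
    config={
        "num_workers": 1,
        "num_envs_per_worker": 2,    # vectorization must be enabled ...
        "remote_worker_envs": True,  # ... for envs to be placed in remote actors
    })
print(agent.train()["episode_reward_mean"])
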
@@ -463,7 +467,9 @@ class Agent(Trainable):
                         "tf_session_args": self.
                         config["local_evaluator_tf_session_args"]
                     }),
-                extra_config or {}))
+                extra_config or {}),
+            remote_worker_envs=False,
+        )
 
     @DeveloperAPI
     def make_remote_evaluators(self, env_creator, policy_graph, count):
@@ -476,9 +482,16 @@ class Agent(Trainable):
         }
         cls = PolicyEvaluator.as_remote(**remote_args).remote
         return [
-            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
-                                 self.config) for i in range(count)
+            self._make_evaluator(
+                cls,
+                env_creator,
+                policy_graph,
+                i + 1,
+                self.config,
+                remote_worker_envs=self.config["remote_worker_envs"])
+            for i in range(count)
         ]
 
     @DeveloperAPI
@@ -544,8 +557,13 @@
             raise ValueError(
                 "`input_evaluation` should not be set when input=sampler")
 
-    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
-                        config):
+    def _make_evaluator(self,
+                        cls,
+                        env_creator,
+                        policy_graph,
+                        worker_index,
+                        config,
+                        remote_worker_envs=False):
         def session_creator():
             logger.debug("Creating TF session {}".format(
                 config["tf_session_args"]))
@@ -573,10 +591,10 @@
                 compress_columns=config["output_compress_columns"]))
         else:
             output_creator = (lambda ioctx: JsonWriter(
-                config["output"],
-                ioctx,
-                max_file_size=config["output_max_file_size"],
-                compress_columns=config["output_compress_columns"]))
+                config["output"],
+                ioctx,
+                max_file_size=config["output_max_file_size"],
+                compress_columns=config["output_compress_columns"]))
 
         return cls(
             env_creator,
@@ -605,7 +623,8 @@
             callbacks=config["callbacks"],
             input_creator=input_creator,
             input_evaluation_method=config["input_evaluation"],
-            output_creator=output_creator)
+            output_creator=output_creator,
+            remote_worker_envs=remote_worker_envs)
 
     @override(Trainable)
     def _export_model(self, export_formats, export_dir):


@@ -66,10 +66,18 @@ class BaseEnv(object):
     """
 
     @staticmethod
-    def to_base_env(env, make_env=None, num_envs=1):
+    def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
         """Wraps any env type as needed to expose the async interface."""
+        if remote_envs and num_envs == 1:
+            raise ValueError(
+                "Remote envs only make sense to use if num_envs > 1 "
+                "(i.e. vectorization is enabled).")
+
         if not isinstance(env, BaseEnv):
             if isinstance(env, MultiAgentEnv):
+                if remote_envs:
+                    raise NotImplementedError(
+                        "Remote multiagent environments are not implemented")
                 env = _MultiAgentEnvToBaseEnv(
                     make_env=make_env, existing_envs=[env], num_envs=num_envs)
             elif isinstance(env, ExternalEnv):
@@ -81,7 +89,12 @@
                 env = _VectorEnvToBaseEnv(env)
             else:
                 env = VectorEnv.wrap(
-                    make_env=make_env, existing_envs=[env], num_envs=num_envs)
+                    make_env=make_env,
+                    existing_envs=[env],
+                    num_envs=num_envs,
+                    remote_envs=remote_envs,
+                    action_space=env.action_space,
+                    observation_space=env.observation_space)
                 env = _VectorEnvToBaseEnv(env)
         assert isinstance(env, BaseEnv)
         return env
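
To illustrate the new guard only: a sketch of calling the updated to_base_env without vectorization, which the diff rejects up front. The module path and the gym env are assumptions, not part of the commit.

# Illustrative sketch of the validation added to BaseEnv.to_base_env.
import gym
from ray.rllib.env.base_env import BaseEnv

env = gym.make("CartPole-v0")
try:
    # remote_envs without num_envs > 1 is rejected before anything is created
    BaseEnv.to_base_env(
        env,
        make_env=lambda i: gym.make("CartPole-v0"),
        num_envs=1,
        remote_envs=True)
except ValueError as e:
    print(e)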


@@ -20,13 +20,23 @@ class EnvContext(dict):
             uniquely identifies the worker the env is created in.
         vector_index (int): When there are multiple envs per worker, this
             uniquely identifies the env index within the worker.
+        remote (bool): Whether environment should be remote or not.
     """
 
-    def __init__(self, env_config, worker_index, vector_index=0):
+    def __init__(self, env_config, worker_index, vector_index=0, remote=False):
         dict.__init__(self, env_config)
         self.worker_index = worker_index
         self.vector_index = vector_index
+        self.remote = remote
 
-    def with_vector_index(self, vector_index):
+    def copy_with_overrides(self,
+                            env_config=None,
+                            worker_index=None,
+                            vector_index=None,
+                            remote=None):
         return EnvContext(
-            self, worker_index=self.worker_index, vector_index=vector_index)
+            env_config if env_config is not None else self,
+            worker_index if worker_index is not None else self.worker_index,
+            vector_index if vector_index is not None else self.vector_index,
+            remote if remote is not None else self.remote,
+        )
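
A minimal sketch of the renamed helper (copy_with_overrides replaces with_vector_index); the module path and config values are assumptions used only for illustration.

from ray.rllib.env.env_context import EnvContext

parent = EnvContext({"difficulty": 3}, worker_index=1)
child = parent.copy_with_overrides(vector_index=2, remote=True)

assert child["difficulty"] == 3   # env_config dict contents are carried over
assert child.worker_index == 1    # inherited from the parent context
assert child.vector_index == 2    # overridden
assert child.remote is True       # overridden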


@@ -2,8 +2,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import logging
+
+import ray
 from ray.rllib.utils.annotations import override, PublicAPI
 
+logger = logging.getLogger(__name__)
+
 
 @PublicAPI
 class VectorEnv(object):
@@ -18,8 +23,17 @@
     """
 
     @staticmethod
-    def wrap(make_env=None, existing_envs=None, num_envs=1):
-        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
+    def wrap(make_env=None,
+             existing_envs=None,
+             num_envs=1,
+             remote_envs=False,
+             action_space=None,
+             observation_space=None):
+        if remote_envs:
+            return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
+                                           observation_space)
+        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
+                                 action_space, observation_space)
 
     @PublicAPI
     def vector_reset(self):
@@ -70,14 +84,20 @@
         num_envs (int): Desired num gym envs to keep total.
     """
 
-    def __init__(self, make_env, existing_envs, num_envs):
+    def __init__(self,
+                 make_env,
+                 existing_envs,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
         self.make_env = make_env
         self.envs = existing_envs
         self.num_envs = num_envs
         while len(self.envs) < self.num_envs:
             self.envs.append(self.make_env(len(self.envs)))
-        self.action_space = self.envs[0].action_space
-        self.observation_space = self.envs[0].observation_space
+        self.action_space = action_space or self.envs[0].action_space
+        self.observation_space = observation_space or \
+            self.envs[0].observation_space
 
     @override(VectorEnv)
     def vector_reset(self):
@@ -101,3 +121,71 @@
     @override(VectorEnv)
     def get_unwrapped(self):
         return self.envs
+
+
+@ray.remote(num_cpus=0)
+class _RemoteEnv(object):
+    """Wrapper class for making a gym env a remote actor."""
+
+    def __init__(self, make_env, i):
+        self.env = make_env(i)
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, action):
+        return self.env.step(action)
+
+
+class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
+    """Internal wrapper for gym envs to implement VectorEnv as remote workers.
+    """
+
+    def __init__(self,
+                 make_env,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
+        self.make_local_env = make_env
+        self.num_envs = num_envs
+        self.initialized = False
+
+        self.action_space = action_space
+        self.observation_space = observation_space
+
+    def _initialize_if_needed(self):
+        if self.initialized:
+            return
+        self.initialized = True
+
+        def make_remote_env(i):
+            logger.info("Launching env {} in remote actor".format(i))
+            return _RemoteEnv.remote(self.make_local_env, i)
+
+        _VectorizedGymEnv.__init__(self, make_remote_env, [], self.num_envs,
+                                   self.action_space, self.observation_space)
+
+        for env in self.envs:
+            assert isinstance(env, ray.actor.ActorHandle), env
+
+    @override(_VectorizedGymEnv)
+    def vector_reset(self):
+        self._initialize_if_needed()
+        return ray.get([env.reset.remote() for env in self.envs])
+
+    @override(_VectorizedGymEnv)
+    def reset_at(self, index):
+        return ray.get(self.envs[index].reset.remote())
+
+    @override(_VectorizedGymEnv)
+    def vector_step(self, actions):
+        step_outs = ray.get(
+            [env.step.remote(act) for env, act in zip(self.envs, actions)])
+        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
+        for obs, rew, done, info in step_outs:
+            obs_batch.append(obs)
+            rew_batch.append(rew)
+            done_batch.append(done)
+            info_batch.append(info)
+        return obs_batch, rew_batch, done_batch, info_batch
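
A usage sketch for the remote vectorization path added above, assuming a running Ray instance and gym. The spaces are passed explicitly because the wrapper cannot read them off envs that live inside remote actors; everything else here is illustrative.

import gym
import ray
from ray.rllib.env.vector_env import VectorEnv

ray.init()
proto = gym.make("CartPole-v0")
vec_env = VectorEnv.wrap(
    make_env=lambda i: gym.make("CartPole-v0"),
    num_envs=2,
    remote_envs=True,
    action_space=proto.action_space,
    observation_space=proto.observation_space)

obs = vec_env.vector_reset()  # remote actors are launched lazily on first reset
actions = [vec_env.action_space.sample() for _ in range(2)]
obs, rewards, dones, infos = vec_env.vector_step(actions)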


@@ -117,7 +117,8 @@ class PolicyEvaluator(EvaluatorInterface):
                  callbacks=None,
                  input_creator=lambda ioctx: ioctx.default_sampler_input(),
                  input_evaluation_method=None,
-                 output_creator=lambda ioctx: NoopOutput()):
+                 output_creator=lambda ioctx: NoopOutput(),
+                 remote_worker_envs=False):
         """Initialize a policy evaluator.
 
         Arguments:
@@ -192,6 +193,10 @@ class PolicyEvaluator(EvaluatorInterface):
                 use this data for evaluation only and never for learning.
             output_creator (func): Function that returns an OutputWriter object
                 for saving generated experiences.
+            remote_worker_envs (bool): If using num_envs > 1, whether to create
+                those new envs in remote processes instead of in the current
+                process. This adds overheads, but can make sense if your envs
+                are very CPU intensive (e.g., for StarCraft).
         """
 
         if log_level:
@@ -250,7 +255,9 @@
 
             def make_env(vector_index):
                 return wrap(
-                    env_creator(env_context.with_vector_index(vector_index)))
+                    env_creator(
+                        env_context.copy_with_overrides(
+                            vector_index=vector_index, remote=remote_worker_envs)))
 
         self.tf_sess = None
         policy_dict = _validate_and_canonicalize(policy_graph, self.env)
@@ -293,7 +300,10 @@
 
         # Always use vector env for consistency even if num_envs = 1
         self.async_env = BaseEnv.to_base_env(
-            self.env, make_env=make_env, num_envs=num_envs)
+            self.env,
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_worker_envs)
         self.num_envs = num_envs
 
         if self.batch_mode == "truncate_episodes":
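
For completeness, a sketch of passing the flag straight to a PolicyEvaluator. PGPolicyGraph and CartPole are used only as concrete stand-ins for a policy graph and env; they are assumptions, not part of this commit.

import gym
import ray
from ray.rllib.evaluation import PolicyEvaluator
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph

ray.init()
evaluator = PolicyEvaluator(
    env_creator=lambda ctx: gym.make("CartPole-v0"),
    policy_graph=PGPolicyGraph,
    num_envs=2,
    remote_worker_envs=True)
batch = evaluator.sample()  # each underlying env steps in its own remote actor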


@@ -71,6 +71,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}'
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v1 \
+    --run PPO \
+    --stop '{"training_iteration": 2}' \
+    --config '{"remote_worker_envs": true, "num_envs_per_worker": 2, "num_workers": 1, "train_batch_size": 100, "sgd_minibatch_size": 50}'
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
     --env Pendulum-v0 \