Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[wingman -> rllib] Remote and entangled environments (#3968)
* added all our environment changes
* fixed merge request comments and remote env
* fixed remote check
* moved remote_worker_envs to correct config section
* lint
* auto wrap impl
* fix
* fixed the tests
commit 0e37ac6d1d (parent b3f72e8a75)
6 changed files with 170 additions and 23 deletions
@@ -129,6 +129,10 @@ COMMON_CONFIG = {
     "compress_observations": False,
     # Drop metric batches from unresponsive workers after this many seconds
     "collect_metrics_timeout": 180,
+    # If using num_envs_per_worker > 1, whether to create those new envs in
+    # remote processes instead of in the same worker. This adds overheads, but
+    # can make sense if your envs are very CPU intensive (e.g., for StarCraft).
+    "remote_worker_envs": False,

     # === Offline Datasets ===
     # __sphinx_doc_input_begin__
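
The new flag sits next to the other sampling options in COMMON_CONFIG, so it is enabled through the usual config-override path. A minimal sketch of a config that opts in (the values here are illustrative, not taken from this commit):

    config = {
        "num_workers": 1,
        # Vectorization must be on; remote envs are rejected when num_envs == 1.
        "num_envs_per_worker": 2,
        # Create each of those envs in its own remote process.
        "remote_worker_envs": True,
    }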
@@ -463,7 +467,9 @@ class Agent(Trainable):
                         "tf_session_args": self.
                         config["local_evaluator_tf_session_args"]
                     }),
-                extra_config or {}))
+                extra_config or {}),
+            remote_worker_envs=False,
+        )

     @DeveloperAPI
     def make_remote_evaluators(self, env_creator, policy_graph, count):
@@ -476,9 +482,16 @@ class Agent(Trainable):
         }

         cls = PolicyEvaluator.as_remote(**remote_args).remote

         return [
-            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
-                                 self.config) for i in range(count)
+            self._make_evaluator(
+                cls,
+                env_creator,
+                policy_graph,
+                i + 1,
+                self.config,
+                remote_worker_envs=self.config["remote_worker_envs"])
+            for i in range(count)
         ]

     @DeveloperAPI
@@ -544,8 +557,13 @@ class Agent(Trainable):
             raise ValueError(
                 "`input_evaluation` should not be set when input=sampler")

-    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
-                        config):
+    def _make_evaluator(self,
+                        cls,
+                        env_creator,
+                        policy_graph,
+                        worker_index,
+                        config,
+                        remote_worker_envs=False):
         def session_creator():
             logger.debug("Creating TF session {}".format(
                 config["tf_session_args"]))
@@ -573,10 +591,10 @@ class Agent(Trainable):
                 compress_columns=config["output_compress_columns"]))
         else:
             output_creator = (lambda ioctx: JsonWriter(
-                config["output"],
-                ioctx,
-                max_file_size=config["output_max_file_size"],
-                compress_columns=config["output_compress_columns"]))
+                config["output"],
+                ioctx,
+                max_file_size=config["output_max_file_size"],
+                compress_columns=config["output_compress_columns"]))

         return cls(
             env_creator,
@@ -605,7 +623,8 @@ class Agent(Trainable):
             callbacks=config["callbacks"],
             input_creator=input_creator,
             input_evaluation_method=config["input_evaluation"],
-            output_creator=output_creator)
+            output_creator=output_creator,
+            remote_worker_envs=remote_worker_envs)

     @override(Trainable)
     def _export_model(self, export_formats, export_dir):
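
Taken together, the Agent changes thread the flag from the user config down to each remote PolicyEvaluator: the local evaluator is pinned to remote_worker_envs=False, while make_remote_evaluators() forwards self.config["remote_worker_envs"]. A rough end-to-end sketch, assuming the PPO agent class of this release (ray.rllib.agents.ppo.PPOAgent):

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(
        env="CartPole-v0",
        config={
            "num_workers": 1,
            "num_envs_per_worker": 2,
            "remote_worker_envs": True,  # each worker's envs become remote actors
        })
    print(agent.train()["episode_reward_mean"])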
python/ray/rllib/env/base_env.py (17 changes)
@@ -66,10 +66,18 @@ class BaseEnv(object):
     """

     @staticmethod
-    def to_base_env(env, make_env=None, num_envs=1):
+    def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
         """Wraps any env type as needed to expose the async interface."""
+        if remote_envs and num_envs == 1:
+            raise ValueError(
+                "Remote envs only make sense to use if num_envs > 1 "
+                "(i.e. vectorization is enabled).")
         if not isinstance(env, BaseEnv):
             if isinstance(env, MultiAgentEnv):
+                if remote_envs:
+                    raise NotImplementedError(
+                        "Remote multiagent environments are not implemented")
+
                 env = _MultiAgentEnvToBaseEnv(
                     make_env=make_env, existing_envs=[env], num_envs=num_envs)
             elif isinstance(env, ExternalEnv):
@@ -81,7 +89,12 @@ class BaseEnv(object):
                 env = _VectorEnvToBaseEnv(env)
             else:
                 env = VectorEnv.wrap(
-                    make_env=make_env, existing_envs=[env], num_envs=num_envs)
+                    make_env=make_env,
+                    existing_envs=[env],
+                    num_envs=num_envs,
+                    remote_envs=remote_envs,
+                    action_space=env.action_space,
+                    observation_space=env.observation_space)
                 env = _VectorEnvToBaseEnv(env)
         assert isinstance(env, BaseEnv)
         return env
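
With the guards above, remote envs require vectorization (num_envs > 1) and are not yet implemented for multi-agent envs. A small sketch of calling the extended entry point directly, assuming a running Ray instance and a gym CartPole env:

    import gym
    import ray
    from ray.rllib.env.base_env import BaseEnv

    ray.init()

    def make_env(vector_index):
        return gym.make("CartPole-v0")

    env = make_env(0)
    # remote_envs=True with num_envs=1 would raise the ValueError shown above;
    # the remote actors themselves are only launched lazily on the first reset.
    base_env = BaseEnv.to_base_env(
        env, make_env=make_env, num_envs=2, remote_envs=True)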
python/ray/rllib/env/env_context.py (16 changes)
@@ -20,13 +20,23 @@ class EnvContext(dict):
             uniquely identifies the worker the env is created in.
         vector_index (int): When there are multiple envs per worker, this
             uniquely identifies the env index within the worker.
+        remote (bool): Whether environment should be remote or not.
     """

-    def __init__(self, env_config, worker_index, vector_index=0):
+    def __init__(self, env_config, worker_index, vector_index=0, remote=False):
         dict.__init__(self, env_config)
         self.worker_index = worker_index
         self.vector_index = vector_index
+        self.remote = remote

-    def with_vector_index(self, vector_index):
+    def copy_with_overrides(self,
+                            env_config=None,
+                            worker_index=None,
+                            vector_index=None,
+                            remote=None):
         return EnvContext(
-            self, worker_index=self.worker_index, vector_index=vector_index)
+            env_config if env_config is not None else self,
+            worker_index if worker_index is not None else self.worker_index,
+            vector_index if vector_index is not None else self.vector_index,
+            remote if remote is not None else self.remote,
+        )
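
copy_with_overrides() generalizes the old with_vector_index() helper: any subset of fields can be overridden while the rest, including the env_config dict contents, carry over. A quick illustration (the "difficulty" key is made up for the example):

    from ray.rllib.env.env_context import EnvContext

    ctx = EnvContext({"difficulty": 3}, worker_index=1)
    per_env_ctx = ctx.copy_with_overrides(vector_index=2, remote=True)

    assert per_env_ctx["difficulty"] == 3   # env_config entries are copied
    assert per_env_ctx.worker_index == 1    # unchanged field carries over
    assert per_env_ctx.vector_index == 2    # overridden
    assert per_env_ctx.remote is True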
python/ray/rllib/env/vector_env.py (98 changes)
@@ -2,8 +2,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import logging
+
+import ray
 from ray.rllib.utils.annotations import override, PublicAPI

+logger = logging.getLogger(__name__)
+

 @PublicAPI
 class VectorEnv(object):
@@ -18,8 +23,17 @@ class VectorEnv(object):
     """

     @staticmethod
-    def wrap(make_env=None, existing_envs=None, num_envs=1):
-        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
+    def wrap(make_env=None,
+             existing_envs=None,
+             num_envs=1,
+             remote_envs=False,
+             action_space=None,
+             observation_space=None):
+        if remote_envs:
+            return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
+                                           observation_space)
+        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
+                                 action_space, observation_space)

     @PublicAPI
     def vector_reset(self):
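
wrap() now chooses between the in-process and the remote vectorized implementation. A minimal sketch of calling it directly with an ordinary gym env, passing the spaces explicitly as the new call site in BaseEnv.to_base_env does:

    import gym
    from ray.rllib.env.vector_env import VectorEnv

    env = gym.make("CartPole-v0")
    vector_env = VectorEnv.wrap(
        make_env=lambda i: gym.make("CartPole-v0"),
        existing_envs=[env],
        num_envs=4,
        remote_envs=False,  # True would return the remote-actor variant instead
        action_space=env.action_space,
        observation_space=env.observation_space)
    obs_batch = vector_env.vector_reset()  # list of 4 observations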
@@ -70,14 +84,20 @@ class _VectorizedGymEnv(VectorEnv):
         num_envs (int): Desired num gym envs to keep total.
     """

-    def __init__(self, make_env, existing_envs, num_envs):
+    def __init__(self,
+                 make_env,
+                 existing_envs,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
         self.make_env = make_env
         self.envs = existing_envs
         self.num_envs = num_envs
         while len(self.envs) < self.num_envs:
             self.envs.append(self.make_env(len(self.envs)))
-        self.action_space = self.envs[0].action_space
-        self.observation_space = self.envs[0].observation_space
+        self.action_space = action_space or self.envs[0].action_space
+        self.observation_space = observation_space or \
+            self.envs[0].observation_space

     @override(VectorEnv)
     def vector_reset(self):
@@ -101,3 +121,71 @@ class _VectorizedGymEnv(VectorEnv):
     @override(VectorEnv)
     def get_unwrapped(self):
         return self.envs
+
+
+@ray.remote(num_cpus=0)
+class _RemoteEnv(object):
+    """Wrapper class for making a gym env a remote actor."""
+
+    def __init__(self, make_env, i):
+        self.env = make_env(i)
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, action):
+        return self.env.step(action)
+
+
+class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
+    """Internal wrapper for gym envs to implement VectorEnv as remote workers.
+    """
+
+    def __init__(self,
+                 make_env,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
+        self.make_local_env = make_env
+        self.num_envs = num_envs
+        self.initialized = False
+        self.action_space = action_space
+        self.observation_space = observation_space
+
+    def _initialize_if_needed(self):
+        if self.initialized:
+            return
+
+        self.initialized = True
+
+        def make_remote_env(i):
+            logger.info("Launching env {} in remote actor".format(i))
+            return _RemoteEnv.remote(self.make_local_env, i)
+
+        _VectorizedGymEnv.__init__(self, make_remote_env, [], self.num_envs,
+                                   self.action_space, self.observation_space)
+
+        for env in self.envs:
+            assert isinstance(env, ray.actor.ActorHandle), env
+
+    @override(_VectorizedGymEnv)
+    def vector_reset(self):
+        self._initialize_if_needed()
+        return ray.get([env.reset.remote() for env in self.envs])
+
+    @override(_VectorizedGymEnv)
+    def reset_at(self, index):
+        return ray.get(self.envs[index].reset.remote())
+
+    @override(_VectorizedGymEnv)
+    def vector_step(self, actions):
+        step_outs = ray.get(
+            [env.step.remote(act) for env, act in zip(self.envs, actions)])
+
+        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
+        for obs, rew, done, info in step_outs:
+            obs_batch.append(obs)
+            rew_batch.append(rew)
+            done_batch.append(done)
+            info_batch.append(info)
+        return obs_batch, rew_batch, done_batch, info_batch
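
The remote variant defers actor creation until the first vector_reset(), so constructing it stays cheap. The underlying pattern is plain Ray: wrap each env in a num_cpus=0 actor and gather batched resets and steps with ray.get. A self-contained sketch of that pattern, where EnvActor is a hypothetical stand-in for the private _RemoteEnv above:

    import gym
    import ray

    ray.init()

    @ray.remote(num_cpus=0)
    class EnvActor(object):
        """Hypothetical stand-in mirroring _RemoteEnv."""

        def __init__(self, env_name):
            self.env = gym.make(env_name)

        def reset(self):
            return self.env.reset()

        def step(self, action):
            return self.env.step(action)

    actors = [EnvActor.remote("CartPole-v0") for _ in range(2)]
    obs_batch = ray.get([a.reset.remote() for a in actors])   # batched reset
    step_outs = ray.get([a.step.remote(0) for a in actors])   # batched step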
@@ -117,7 +117,8 @@ class PolicyEvaluator(EvaluatorInterface):
                  callbacks=None,
                  input_creator=lambda ioctx: ioctx.default_sampler_input(),
                  input_evaluation_method=None,
-                 output_creator=lambda ioctx: NoopOutput()):
+                 output_creator=lambda ioctx: NoopOutput(),
+                 remote_worker_envs=False):
         """Initialize a policy evaluator.

         Arguments:
@@ -192,6 +193,10 @@ class PolicyEvaluator(EvaluatorInterface):
                 use this data for evaluation only and never for learning.
             output_creator (func): Function that returns an OutputWriter object
                 for saving generated experiences.
+            remote_worker_envs (bool): If using num_envs > 1, whether to create
+                those new envs in remote processes instead of in the current
+                process. This adds overheads, but can make sense if your envs
+                are very CPU intensive (e.g., for StarCraft).
         """

         if log_level:
@@ -250,7 +255,9 @@ class PolicyEvaluator(EvaluatorInterface):

         def make_env(vector_index):
             return wrap(
-                env_creator(env_context.with_vector_index(vector_index)))
+                env_creator(
+                    env_context.copy_with_overrides(
+                        vector_index=vector_index, remote=remote_worker_envs)))

         self.tf_sess = None
         policy_dict = _validate_and_canonicalize(policy_graph, self.env)
@@ -293,7 +300,10 @@ class PolicyEvaluator(EvaluatorInterface):

         # Always use vector env for consistency even if num_envs = 1
         self.async_env = BaseEnv.to_base_env(
-            self.env, make_env=make_env, num_envs=num_envs)
+            self.env,
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_worker_envs)
         self.num_envs = num_envs

         if self.batch_mode == "truncate_episodes":
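
For completeness, a rough sketch of constructing a PolicyEvaluator directly with the new argument. The PG policy graph import path is an assumption about this release's layout; any policy graph class would do here:

    import gym
    import ray
    from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
    from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator

    ray.init()
    evaluator = PolicyEvaluator(
        env_creator=lambda ctx: gym.make("CartPole-v0"),
        policy_graph=PGPolicyGraph,
        num_envs=2,
        remote_worker_envs=True)  # each of the 2 envs runs in its own actor
    batch = evaluator.sample()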
@@ -71,6 +71,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}'

+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v1 \
+    --run PPO \
+    --stop '{"training_iteration": 2}' \
+    --config '{"remote_worker_envs": true, "num_envs_per_worker": 2, "num_workers": 1, "train_batch_size": 100, "sgd_minibatch_size": 50}'
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
     --env Pendulum-v0 \
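
Outside the Jenkins image, roughly the same smoke test can be reproduced from Python, assuming the run_experiments API of this release:

    import ray
    from ray import tune

    ray.init()
    tune.run_experiments({
        "remote_env_smoke_test": {
            "run": "PPO",
            "env": "CartPole-v1",
            "stop": {"training_iteration": 2},
            "config": {
                "remote_worker_envs": True,
                "num_envs_per_worker": 2,
                "num_workers": 1,
                "train_batch_size": 100,
                "sgd_minibatch_size": 50,
            },
        },
    })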