mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

[wingman -> rllib] Remote and entangled environments (#3968)

* added all our environment changes
* fixed merge request comments and remote env
* fixed remote check
* moved remote_worker_envs to correct config section
* lint
* auto wrap impl
* fix
* fixed the tests

parent b3f72e8a75
commit 0e37ac6d1d

6 changed files with 170 additions and 23 deletions
@@ -129,6 +129,10 @@ COMMON_CONFIG = {
     "compress_observations": False,
     # Drop metric batches from unresponsive workers after this many seconds
     "collect_metrics_timeout": 180,
+    # If using num_envs_per_worker > 1, whether to create those new envs in
+    # remote processes instead of in the same worker. This adds overheads, but
+    # can make sense if your envs are very CPU intensive (e.g., for StarCraft).
+    "remote_worker_envs": False,

     # === Offline Datasets ===
     # __sphinx_doc_input_begin__
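The new option only takes effect together with env vectorization. A minimal, illustrative config fragment (not part of the diff) that exercises it:

    # Illustrative only: remote_worker_envs requires num_envs_per_worker > 1,
    # otherwise BaseEnv.to_base_env() (see base_env.py below) raises ValueError.
    config = {
        "num_workers": 1,            # rollout workers (PolicyEvaluators)
        "num_envs_per_worker": 2,    # vectorize two envs per worker...
        "remote_worker_envs": True,  # ...and host each env in its own Ray actor
    }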
@@ -463,7 +467,9 @@ class Agent(Trainable):
                         "tf_session_args": self.
                         config["local_evaluator_tf_session_args"]
                     }),
-                extra_config or {}))
+                extra_config or {}),
+            remote_worker_envs=False,
+        )

     @DeveloperAPI
     def make_remote_evaluators(self, env_creator, policy_graph, count):
@@ -476,9 +482,16 @@ class Agent(Trainable):
         }

         cls = PolicyEvaluator.as_remote(**remote_args).remote

         return [
-            self._make_evaluator(cls, env_creator, policy_graph, i + 1,
-                                 self.config) for i in range(count)
+            self._make_evaluator(
+                cls,
+                env_creator,
+                policy_graph,
+                i + 1,
+                self.config,
+                remote_worker_envs=self.config["remote_worker_envs"])
+            for i in range(count)
         ]

     @DeveloperAPI
@@ -544,8 +557,13 @@ class Agent(Trainable):
             raise ValueError(
                 "`input_evaluation` should not be set when input=sampler")

-    def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
-                        config):
+    def _make_evaluator(self,
+                        cls,
+                        env_creator,
+                        policy_graph,
+                        worker_index,
+                        config,
+                        remote_worker_envs=False):
         def session_creator():
             logger.debug("Creating TF session {}".format(
                 config["tf_session_args"]))
@@ -573,10 +591,10 @@ class Agent(Trainable):
                 compress_columns=config["output_compress_columns"]))
         else:
             output_creator = (lambda ioctx: JsonWriter(
                 config["output"],
                 ioctx,
                 max_file_size=config["output_max_file_size"],
                 compress_columns=config["output_compress_columns"]))

         return cls(
             env_creator,
@@ -605,7 +623,8 @@ class Agent(Trainable):
             callbacks=config["callbacks"],
             input_creator=input_creator,
             input_evaluation_method=config["input_evaluation"],
-            output_creator=output_creator)
+            output_creator=output_creator,
+            remote_worker_envs=remote_worker_envs)

     @override(Trainable)
     def _export_model(self, export_formats, export_dir):
python/ray/rllib/env/base_env.py (17 changed lines)
@@ -66,10 +66,18 @@ class BaseEnv(object):
     """

     @staticmethod
-    def to_base_env(env, make_env=None, num_envs=1):
+    def to_base_env(env, make_env=None, num_envs=1, remote_envs=False):
         """Wraps any env type as needed to expose the async interface."""
+        if remote_envs and num_envs == 1:
+            raise ValueError(
+                "Remote envs only make sense to use if num_envs > 1 "
+                "(i.e. vectorization is enabled).")
         if not isinstance(env, BaseEnv):
             if isinstance(env, MultiAgentEnv):
+                if remote_envs:
+                    raise NotImplementedError(
+                        "Remote multiagent environments are not implemented")
+
                 env = _MultiAgentEnvToBaseEnv(
                     make_env=make_env, existing_envs=[env], num_envs=num_envs)
             elif isinstance(env, ExternalEnv):
@@ -81,7 +89,12 @@
                 env = _VectorEnvToBaseEnv(env)
             else:
                 env = VectorEnv.wrap(
-                    make_env=make_env, existing_envs=[env], num_envs=num_envs)
+                    make_env=make_env,
+                    existing_envs=[env],
+                    num_envs=num_envs,
+                    remote_envs=remote_envs,
+                    action_space=env.action_space,
+                    observation_space=env.observation_space)
                 env = _VectorEnvToBaseEnv(env)
         assert isinstance(env, BaseEnv)
         return env
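A quick, hedged sketch of the new guard in to_base_env(); the gym import and env id are illustrative, not part of the diff:

    import gym
    from ray.rllib.env.base_env import BaseEnv

    env = gym.make("CartPole-v0")
    # Raises ValueError: remote envs require num_envs > 1 (vectorization).
    BaseEnv.to_base_env(
        env,
        make_env=lambda i: gym.make("CartPole-v0"),
        num_envs=1,
        remote_envs=True)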
python/ray/rllib/env/env_context.py (16 changed lines)
@@ -20,13 +20,23 @@ class EnvContext(dict):
             uniquely identifies the worker the env is created in.
         vector_index (int): When there are multiple envs per worker, this
             uniquely identifies the env index within the worker.
+        remote (bool): Whether environment should be remote or not.
     """

-    def __init__(self, env_config, worker_index, vector_index=0):
+    def __init__(self, env_config, worker_index, vector_index=0, remote=False):
         dict.__init__(self, env_config)
         self.worker_index = worker_index
         self.vector_index = vector_index
+        self.remote = remote

-    def with_vector_index(self, vector_index):
+    def copy_with_overrides(self,
+                            env_config=None,
+                            worker_index=None,
+                            vector_index=None,
+                            remote=None):
         return EnvContext(
-            self, worker_index=self.worker_index, vector_index=vector_index)
+            env_config if env_config is not None else self,
+            worker_index if worker_index is not None else self.worker_index,
+            vector_index if vector_index is not None else self.vector_index,
+            remote if remote is not None else self.remote,
+        )
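A hedged usage sketch of the new copy_with_overrides() helper (the config values are illustrative); it mirrors how PolicyEvaluator derives a per-env context further down:

    from ray.rllib.env.env_context import EnvContext

    ctx = EnvContext({"difficulty": 2}, worker_index=1)
    child = ctx.copy_with_overrides(vector_index=3, remote=True)
    assert child["difficulty"] == 2   # env_config entries are inherited
    assert child.worker_index == 1    # unset fields fall back to the parent
    assert child.vector_index == 3 and child.remote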
python/ray/rllib/env/vector_env.py (98 changed lines)
@@ -2,8 +2,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import logging
+
+import ray
 from ray.rllib.utils.annotations import override, PublicAPI

+logger = logging.getLogger(__name__)
+

 @PublicAPI
 class VectorEnv(object):
@@ -18,8 +23,17 @@ class VectorEnv(object):
     """

     @staticmethod
-    def wrap(make_env=None, existing_envs=None, num_envs=1):
-        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
+    def wrap(make_env=None,
+             existing_envs=None,
+             num_envs=1,
+             remote_envs=False,
+             action_space=None,
+             observation_space=None):
+        if remote_envs:
+            return _RemoteVectorizedGymEnv(make_env, num_envs, action_space,
+                                           observation_space)
+        return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
+                                 action_space, observation_space)

     @PublicAPI
     def vector_reset(self):
@@ -70,14 +84,20 @@ class _VectorizedGymEnv(VectorEnv):
         num_envs (int): Desired num gym envs to keep total.
     """

-    def __init__(self, make_env, existing_envs, num_envs):
+    def __init__(self,
+                 make_env,
+                 existing_envs,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
         self.make_env = make_env
         self.envs = existing_envs
         self.num_envs = num_envs
         while len(self.envs) < self.num_envs:
             self.envs.append(self.make_env(len(self.envs)))
-        self.action_space = self.envs[0].action_space
-        self.observation_space = self.envs[0].observation_space
+        self.action_space = action_space or self.envs[0].action_space
+        self.observation_space = observation_space or \
+            self.envs[0].observation_space

     @override(VectorEnv)
     def vector_reset(self):
@@ -101,3 +121,71 @@ class _VectorizedGymEnv(VectorEnv):
     @override(VectorEnv)
     def get_unwrapped(self):
         return self.envs
+
+
+@ray.remote(num_cpus=0)
+class _RemoteEnv(object):
+    """Wrapper class for making a gym env a remote actor."""
+
+    def __init__(self, make_env, i):
+        self.env = make_env(i)
+
+    def reset(self):
+        return self.env.reset()
+
+    def step(self, action):
+        return self.env.step(action)
+
+
+class _RemoteVectorizedGymEnv(_VectorizedGymEnv):
+    """Internal wrapper for gym envs to implement VectorEnv as remote workers.
+    """
+
+    def __init__(self,
+                 make_env,
+                 num_envs,
+                 action_space=None,
+                 observation_space=None):
+        self.make_local_env = make_env
+        self.num_envs = num_envs
+        self.initialized = False
+        self.action_space = action_space
+        self.observation_space = observation_space
+
+    def _initialize_if_needed(self):
+        if self.initialized:
+            return
+
+        self.initialized = True
+
+        def make_remote_env(i):
+            logger.info("Launching env {} in remote actor".format(i))
+            return _RemoteEnv.remote(self.make_local_env, i)
+
+        _VectorizedGymEnv.__init__(self, make_remote_env, [], self.num_envs,
+                                   self.action_space, self.observation_space)
+
+        for env in self.envs:
+            assert isinstance(env, ray.actor.ActorHandle), env
+
+    @override(_VectorizedGymEnv)
+    def vector_reset(self):
+        self._initialize_if_needed()
+        return ray.get([env.reset.remote() for env in self.envs])
+
+    @override(_VectorizedGymEnv)
+    def reset_at(self, index):
+        return ray.get(self.envs[index].reset.remote())
+
+    @override(_VectorizedGymEnv)
+    def vector_step(self, actions):
+        step_outs = ray.get(
+            [env.step.remote(act) for env, act in zip(self.envs, actions)])
+
+        obs_batch, rew_batch, done_batch, info_batch = [], [], [], []
+        for obs, rew, done, info in step_outs:
+            obs_batch.append(obs)
+            rew_batch.append(rew)
+            done_batch.append(done)
+            info_batch.append(info)
+        return obs_batch, rew_batch, done_batch, info_batch
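The remote path can be exercised directly through VectorEnv.wrap(). A hedged sketch (env id, probe env, and action values are illustrative); note that the _RemoteEnv actors are created lazily on the first vector_reset():

    import gym
    import ray
    from ray.rllib.env.vector_env import VectorEnv

    ray.init()
    probe = gym.make("CartPole-v0")  # only used to read the spaces
    vec = VectorEnv.wrap(
        make_env=lambda i: gym.make("CartPole-v0"),
        num_envs=2,
        remote_envs=True,
        action_space=probe.action_space,
        observation_space=probe.observation_space)
    obs = vec.vector_reset()  # launches one _RemoteEnv actor per env
    obs, rewards, dones, infos = vec.vector_step([0, 1])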
@@ -117,7 +117,8 @@ class PolicyEvaluator(EvaluatorInterface):
                  callbacks=None,
                  input_creator=lambda ioctx: ioctx.default_sampler_input(),
                  input_evaluation_method=None,
-                 output_creator=lambda ioctx: NoopOutput()):
+                 output_creator=lambda ioctx: NoopOutput(),
+                 remote_worker_envs=False):
         """Initialize a policy evaluator.

         Arguments:
@@ -192,6 +193,10 @@ class PolicyEvaluator(EvaluatorInterface):
                 use this data for evaluation only and never for learning.
             output_creator (func): Function that returns an OutputWriter object
                 for saving generated experiences.
+            remote_worker_envs (bool): If using num_envs > 1, whether to create
+                those new envs in remote processes instead of in the current
+                process. This adds overheads, but can make sense if your envs
+                are very CPU intensive (e.g., for StarCraft).
         """

         if log_level:
@@ -250,7 +255,9 @@ class PolicyEvaluator(EvaluatorInterface):

         def make_env(vector_index):
             return wrap(
-                env_creator(env_context.with_vector_index(vector_index)))
+                env_creator(
+                    env_context.copy_with_overrides(
+                        vector_index=vector_index, remote=remote_worker_envs)))

         self.tf_sess = None
         policy_dict = _validate_and_canonicalize(policy_graph, self.env)
@@ -293,7 +300,10 @@ class PolicyEvaluator(EvaluatorInterface):

         # Always use vector env for consistency even if num_envs = 1
         self.async_env = BaseEnv.to_base_env(
-            self.env, make_env=make_env, num_envs=num_envs)
+            self.env,
+            make_env=make_env,
+            num_envs=num_envs,
+            remote_envs=remote_worker_envs)
         self.num_envs = num_envs

         if self.batch_mode == "truncate_episodes":
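For context, a rough, hedged sketch of constructing an evaluator with the new flag; the PGPolicyGraph class and both import paths are assumptions based on RLlib of this era and are not part of the diff:

    import gym
    import ray
    from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph  # assumed path
    from ray.rllib.evaluation import PolicyEvaluator               # assumed path

    ray.init()
    ev = PolicyEvaluator(
        env_creator=lambda ctx: gym.make("CartPole-v0"),
        policy_graph=PGPolicyGraph,
        num_envs=2,               # vectorize two envs in this evaluator...
        remote_worker_envs=True)  # ...each hosted in its own Ray actor
    batch = ev.sample()           # rollouts now step the envs via ray.get()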
@@ -71,6 +71,13 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     --stop '{"training_iteration": 2}' \
     --config '{"kl_coeff": 1.0, "num_sgd_iter": 10, "lr": 1e-4, "sgd_minibatch_size": 64, "train_batch_size": 2000, "num_workers": 1, "use_gae": false, "batch_mode": "complete_episodes"}'

+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/rllib/train.py \
+    --env CartPole-v1 \
+    --run PPO \
+    --stop '{"training_iteration": 2}' \
+    --config '{"remote_worker_envs": true, "num_envs_per_worker": 2, "num_workers": 1, "train_batch_size": 100, "sgd_minibatch_size": 50}'
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
     --env Pendulum-v0 \
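A roughly equivalent Python run of the same smoke test, hedged: PPOAgent and its import path are assumptions based on the RLlib API of this era, not part of the diff:

    import ray
    from ray.rllib.agents.ppo import PPOAgent  # assumed import path

    ray.init()
    agent = PPOAgent(
        env="CartPole-v1",
        config={
            "remote_worker_envs": True,
            "num_envs_per_worker": 2,
            "num_workers": 1,
            "train_batch_size": 100,
            "sgd_minibatch_size": 50,
        })
    for _ in range(2):  # mirrors --stop '{"training_iteration": 2}'
        print(agent.train()["episode_reward_mean"])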