[RLlib] Do not create env on driver iff num_workers > 0. (#11307)

Sven Mika committed on 2020-10-15 18:21:30 +02:00 (committed by GitHub)
parent 60a4be4a59
commit 414041c6dd
17 changed files with 308 additions and 189 deletions

View file

@ -28,6 +28,8 @@ DEFAULT_CONFIG = with_common_config({
"kl_coeff": 0.0005,
# Size of batches collected from each worker
"rollout_fragment_length": 200,
# Do create an actual env on the local worker (worker-idx=0).
"create_env_on_driver": True,
# Stepsize of SGD
"lr": 1e-3,
# Share layers for value function
@ -209,15 +211,18 @@ def get_policy_class(config):
def validate_config(config):
if config["inner_adaptation_steps"] <= 0:
raise ValueError("Inner Adaptation Steps must be >=1.")
raise ValueError("Inner Adaptation Steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
raise ValueError("PPO steps for meta-update needs to be >=0")
raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
raise ValueError("entropy_coeff must be >=0")
raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
raise ValueError("truncate_episodes not supported")
raise ValueError("`batch_mode`=truncate_episodes not supported!")
if config["num_workers"] <= 0:
raise ValueError("Must have at least 1 worker/task.")
raise ValueError("Must have at least 1 worker/task!")
if config["create_env_on_driver"] is False:
raise ValueError("Must have an actual Env created on the driver "
"(local) worker! Set `create_env_on_driver` to True.")
MAMLTrainer = build_trainer(
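For reference, a minimal config sketch that satisfies the checks above; values other than the asserted bounds are placeholders:

# Illustrative MAML config fragment; key names follow the diff above.
maml_config = {
    "inner_adaptation_steps": 1,        # must be >= 1
    "maml_optimizer_steps": 5,          # must be >= 1
    "entropy_coeff": 0.0,               # must be >= 0.0
    "batch_mode": "complete_episodes",  # truncate_episodes is rejected
    "num_workers": 2,                   # must be >= 1
    "create_env_on_driver": True,       # MAML needs an env on the local worker
}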

View file

@ -51,6 +51,8 @@ DEFAULT_CONFIG = with_common_config({
"kl_coeff": 0.0005,
# Size of batches collected from each worker.
"rollout_fragment_length": 200,
# Do create an actual env on the local worker (worker-idx=0).
"create_env_on_driver": True,
# Step size of SGD.
"lr": 1e-3,
# Share layers for value function.
@ -414,15 +416,18 @@ def validate_config(config):
"`framework=torch`.")
config["framework"] = "torch"
if config["inner_adaptation_steps"] <= 0:
raise ValueError("Inner Adaptation Steps must be >=1.")
raise ValueError("Inner adaptation steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
raise ValueError("PPO steps for meta-update needs to be >=0")
raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
raise ValueError("`entropy_coeff` must be >=0.")
raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
raise ValueError("`batch_mode=truncate_episodes` not supported.")
raise ValueError("`batch_mode=truncate_episodes` not supported!")
if config["num_workers"] <= 0:
raise ValueError("Must have at least 1 worker/task.")
if config["create_env_on_driver"] is False:
raise ValueError("Must have an actual Env created on the driver "
"(local) worker! Set `create_env_on_driver` to True.")
def validate_env(env: EnvType, env_context: EnvContext):

View file

@ -53,6 +53,11 @@ COMMON_CONFIG: TrainerConfigDict = {
# model inference batching, which can improve performance for inference
# bottlenecked workloads.
"num_envs_per_worker": 1,
# When `num_workers` > 0, the driver (local_worker; worker-idx=0) does not
# need an environment, because it neither samples (done by remote_workers;
# worker_indices > 0) nor evaluates (done by evaluation workers; see below).
"create_env_on_driver": False,
# Divide episodes into fragments of this many steps each during rollouts.
# Sample batches of this size are collected from rollout workers and
# combined into a larger batch of `train_batch_size` for learning.
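As a usage sketch of the new `create_env_on_driver` flag (trainer and env choices here are illustrative, not part of this diff): keep a driver-side env even with remote workers, e.g. for local debugging:

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
config = {
    "env": "CartPole-v0",
    "num_workers": 2,              # sampling happens on the remote workers
    "create_env_on_driver": True,  # but keep an env on the local worker too
}
trainer = PPOTrainer(config=config)
print(trainer.workers.local_worker().env)  # not None with the flag set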
@ -308,16 +313,17 @@ COMMON_CONFIG: TrainerConfigDict = {
# === Offline Datasets ===
# Specify how to generate experiences:
# - "sampler": generate experiences via online simulation (default)
# - a local directory or file glob expression (e.g., "/tmp/*.json")
# - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
# "s3://bucket/2.json"])
# - a dict with string keys and sampling probabilities as values (e.g.,
# - "sampler": Generate experiences via online (env) simulation (default).
# - A local directory or file glob expression (e.g., "/tmp/*.json").
# - A list of individual file paths/URIs (e.g., ["/tmp/1.json",
# "s3://bucket/2.json"]).
# - A dict with string keys and sampling probabilities as values (e.g.,
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
# - a function that returns a rllib.offline.InputReader
# - A callable that returns a ray.rllib.offline.InputReader.
"input": "sampler",
# Specify how to evaluate the current policy. This only has an effect when
# reading offline experiences. Available options:
# reading offline experiences ("input" is not "sampler").
# Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
@ -557,12 +563,12 @@ class Trainer(Trainable):
# A class specifier.
elif "." in env:
self.env_creator = \
lambda env_config: from_config(env, env_config)
lambda env_context: from_config(env, env_context)
# Try gym.
else:
import gym # soft dependency
self.env_creator = \
lambda env_config: gym.make(env, **env_config)
lambda env_context: gym.make(env, **env_context)
else:
self.env_creator = lambda env_config: None
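A simplified sketch of the env-creator resolution shown above (not the full Trainer logic; `from_config` and `gym.make` are the calls used in the diff):

import gym
from ray.rllib.utils.from_config import from_config

def make_env_creator(env):
    # `env` may be None, a "module.Class" path, or a gym env id.
    if env is None:
        return lambda env_context: None
    if "." in env:  # a class path such as "my_pkg.envs.MyEnv"
        return lambda env_context: from_config(env, env_context)
    # Otherwise try gym (soft dependency in the original code).
    return lambda env_context: gym.make(env, **env_context)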

View file

@ -32,6 +32,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd
@ -84,7 +85,7 @@ class RolloutWorker(ParallelIteratorWorker):
>>> # Create a rollout worker and use it to collect experiences.
>>> worker = RolloutWorker(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy=PGTFPolicy)
... policy_spec=PGTFPolicy)
>>> print(worker.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
@ -93,7 +94,7 @@ class RolloutWorker(ParallelIteratorWorker):
>>> # Create a multi-agent rollout worker.
>>> worker = RolloutWorker(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policies={
... policy_spec={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
@ -135,9 +136,10 @@ class RolloutWorker(ParallelIteratorWorker):
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType, EnvContext],
None]] = None,
policy: Union[type, Dict[str, Tuple[Optional[
type], gym.Space, gym.Space, PartialTrainerConfigDict]]],
policy_mapping_fn: Callable[[AgentID], PolicyID] = None,
policy_spec: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
rollout_fragment_length: int = 100,
@ -172,7 +174,13 @@ class RolloutWorker(ParallelIteratorWorker):
no_done_at_end: bool = False,
seed: int = None,
extra_python_environs: dict = None,
fake_sampler: bool = False):
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
policy: Union[type, Dict[
str, Tuple[Optional[type], gym.Space, gym.Space,
PartialTrainerConfigDict]]] = None,
):
"""Initialize a rollout worker.
Args:
@ -181,16 +189,19 @@ class RolloutWorker(ParallelIteratorWorker):
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
Optional callable to validate the generated environment (only
on worker=0).
policy (Union[type, Dict[str, Tuple[Optional[type], gym.Space,
policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
or a dict of policy id strings to
(Policy (None for default), obs_space, action_space,
config)-tuples. If a dict is specified, then we are in
multi-agent mode and a policy_mapping_fn should also be set.
policy_mapping_fn (Callable[[AgentID], PolicyID]): A function that
maps agent ids to policy ids in multi-agent mode. This function
will be called each time a new agent appears in an episode, to
bind that agent to a policy for the duration of the episode.
(Policy class, obs_space, action_space, config)-tuples. If a
dict is specified, then we are in multi-agent mode and a
policy_mapping_fn can also be set (if not, will map all agents
to DEFAULT_POLICY_ID).
policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
callable that maps agent ids to policy ids in multi-agent mode.
This function will be called each time a new agent appears in
an episode, to bind that agent to a policy for the duration of
the episode. If not provided, will map all agents to
DEFAULT_POLICY_ID.
policies_to_train (Optional[List[PolicyID]]): Optional list of
policies to train, or None for all policies.
tf_session_creator (Optional[Callable[[], tf1.Session]]): A
@ -236,7 +247,7 @@ class RolloutWorker(ParallelIteratorWorker):
policy model.
policy_config (TrainerConfigDict): Config to pass to the policy.
In the multi-agent case, this config will be merged with the
per-policy configs specified by `policy`.
per-policy configs specified by `policy_spec`.
worker_index (int): For remote workers, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
@ -277,7 +288,19 @@ class RolloutWorker(ParallelIteratorWorker):
extra_python_environs (dict): Extra python environment variables
that need to be set.
fake_sampler (bool): Use a fake (inf speed) sampler for testing.
spaces (Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]]): An optional space dict mapping policy IDs
to (obs_space, action_space)-tuples. This is used in case no
Env is created on this RolloutWorker.
policy: Obsoleted arg. Use `policy_spec` instead.
"""
# Deprecated arg.
if policy is not None:
deprecation_warning("policy", "policy_spec", error=False)
policy_spec = policy
assert policy_spec is not None, "Must provide `policy_spec` when " \
"creating RolloutWorker!"
self._original_kwargs: dict = locals().copy()
del self._original_kwargs["self"]
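New call sites should pass `policy_spec`; the deprecated `policy` kwarg is only mapped onto it with a warning. A minimal construction sketch (import paths may vary slightly across Ray versions):

import gym
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.evaluation.rollout_worker import RolloutWorker

# New-style construction: `policy_spec` replaces the deprecated `policy` kwarg.
worker = RolloutWorker(
    env_creator=lambda env_ctx: gym.make("CartPole-v0"),
    policy_spec=PGTFPolicy,
    rollout_fragment_length=100,
)
print(worker.sample())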
@ -334,6 +357,12 @@ class RolloutWorker(ParallelIteratorWorker):
self.global_vars: dict = None
self.fake_sampler: bool = fake_sampler
# No Env will be used in this particular worker (not needed).
if worker_index == 0 and num_workers > 0 and \
policy_config["create_env_on_driver"] is False:
self.env = None
# Create an env for this worker.
else:
self.env = _validate_env(env_creator(env_context))
if validate_env is not None:
validate_env(self.env, self.env_context)
@ -385,7 +414,8 @@ class RolloutWorker(ParallelIteratorWorker):
self.make_env_fn = make_env
self.tf_sess = None
policy_dict = _validate_and_canonicalize(policy, self.env)
policy_dict = _validate_and_canonicalize(
policy_spec, self.env, spaces=spaces)
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
self.policy_map: Dict[PolicyID, Policy] = None
@ -446,7 +476,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.multiagent: bool = set(
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
if self.multiagent and self.env is not None:
if not ((isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):
@ -466,7 +496,9 @@ class RolloutWorker(ParallelIteratorWorker):
self.num_envs: int = num_envs
if "custom_vector_env" in policy_config:
if self.env is None:
self.async_env = None
elif "custom_vector_env" in policy_config:
custom_vec_wrapper = policy_config["custom_vector_env"]
self.async_env = custom_vec_wrapper(self.env)
else:
@ -494,7 +526,7 @@ class RolloutWorker(ParallelIteratorWorker):
self.io_context: IOContext = IOContext(log_dir, policy_config,
worker_index, self)
self.reward_estimators: OffPolicyEstimator = []
self.reward_estimators: List[OffPolicyEstimator] = []
for method in input_evaluation:
if method == "simulation":
logger.warning(
@ -512,7 +544,9 @@ class RolloutWorker(ParallelIteratorWorker):
raise ValueError(
"Unknown evaluation method: {}".format(method))
if sample_async:
if self.env is None:
self.sampler = None
elif sample_async:
self.sampler = AsyncSampler(
worker=self,
env=self.async_env,
@ -816,7 +850,12 @@ class RolloutWorker(ParallelIteratorWorker):
def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
"""Returns a list of new RolloutMetric objects from evaluation."""
# Get metrics from sampler (if any).
if self.sampler is not None:
out = self.sampler.get_metrics()
else:
out = []
# Get metrics from our reward-estimators (if any).
for m in self.reward_estimators:
out.extend(m.get_metrics())
return out
@ -825,6 +864,9 @@ class RolloutWorker(ParallelIteratorWorker):
def foreach_env(self, func: Callable[[BaseEnv], T]) -> List[T]:
"""Apply the given function to each underlying env instance."""
if self.async_env is None:
return []
envs = self.async_env.get_unwrapped()
if not envs:
return [func(self.async_env)]
@ -957,6 +999,7 @@ class RolloutWorker(ParallelIteratorWorker):
@DeveloperAPI
def stop(self) -> None:
if self.env:
self.async_env.stop()
@DeveloperAPI
@ -1054,8 +1097,12 @@ class RolloutWorker(ParallelIteratorWorker):
self.sampler.shutdown = True
def _validate_and_canonicalize(policy: Policy,
env: EnvType) -> MultiAgentPolicyConfigDict:
def _validate_and_canonicalize(
policy: Union[Type[Policy], MultiAgentPolicyConfigDict],
env: Optional[EnvType] = None,
spaces: Optional[Dict[PolicyID, Tuple[
gym.spaces.Space, gym.spaces.Space]]] = None) -> \
MultiAgentPolicyConfigDict:
if isinstance(policy, dict):
_validate_multiagent_config(policy)
return policy
@ -1067,14 +1114,20 @@ def _validate_and_canonicalize(policy: Policy,
raise ValueError(
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
if env is not None:
return {
DEFAULT_POLICY_ID: (policy, env.observation_space,
env.action_space, {})
}
else:
return {
DEFAULT_POLICY_ID: (policy, spaces[DEFAULT_POLICY_ID][0],
spaces[DEFAULT_POLICY_ID][1], {})
}
def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
allow_none_graph: bool = False):
allow_none_graph: bool = False) -> None:
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy keys must be strs, got {}".format(

View file

@ -33,7 +33,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
assert len(view_req_policy) == 11, view_req_policy
assert len(view_req_policy) == 12, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
@ -64,8 +64,8 @@ class TestTrajectoryViewAPI(unittest.TestCase):
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 7, view_req_model
assert len(view_req_policy) == 17, view_req_policy
assert len(view_req_model) == 5, view_req_model
assert len(view_req_policy) == 18, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,
@ -220,7 +220,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": True}),
rollout_fragment_length=rollout_fragment_length,
policy=policies,
policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
@ -228,7 +228,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": False}),
rollout_fragment_length=rollout_fragment_length,
policy=policies,
policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)

View file

@ -1,6 +1,7 @@
import gym
import logging
from types import FunctionType
from typing import Callable, List, Optional, Type, TypeVar, Union
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
import ray
from ray.rllib.utils.annotations import DeveloperAPI
@ -72,14 +73,27 @@ class WorkerSet:
self._remote_workers = []
self.add_workers(num_workers)
# If num_workers > 0, fetch the action_spaces and observation_spaces from a
# remote worker, so we are not forced to create an Env on the driver.
if self._remote_workers:
remote_spaces = ray.get(self.remote_workers(
)[0].foreach_policy.remote(
lambda p, pid: (pid, p.observation_space, p.action_space)))
spaces = {e[0]: (e[1], e[2]) for e in remote_spaces}
else:
spaces = None
# Always create a local worker.
self._local_worker = self._make_worker(
cls=RolloutWorker,
env_creator=env_creator,
validate_env=validate_env,
policy=self._policy_class,
policy_cls=self._policy_class,
worker_index=0,
config=self._local_config)
num_workers=num_workers,
config=self._local_config,
spaces=spaces,
)
def local_worker(self) -> RolloutWorker:
"""Return the local rollout worker."""
@ -118,8 +132,9 @@ class WorkerSet:
cls=cls,
env_creator=self._env_creator,
validate_env=None,
policy=self._policy_class,
policy_cls=self._policy_class,
worker_index=i + 1,
num_workers=num_workers,
config=self._remote_config) for i in range(num_workers)
])
@ -217,11 +232,18 @@ class WorkerSet:
return workers
def _make_worker(
self, *, cls: Callable,
self,
*,
cls: Callable,
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType], None]],
policy: Type[Policy], worker_index: int,
config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]:
policy_cls: Type[Policy],
worker_index: int,
num_workers: int,
config: TrainerConfigDict,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
gym.spaces.Space]]] = None,
) -> Union[RolloutWorker, "ActorHandle"]:
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
@ -263,14 +285,20 @@ class WorkerSet:
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy if 'None' is specified in multiagent.
# Fill in the default policy_cls if 'None' is specified in multiagent.
if config["multiagent"]["policies"]:
tmp = config["multiagent"]["policies"]
_validate_multiagent_config(tmp, allow_none_graph=True)
# TODO: (sven) Allow for setting observation and action spaces to
# None as well, in which case, spaces are taken from env.
# It's tedious to have to provide these in a multi-agent config.
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy, v[1], v[2], v[3])
policy = tmp
tmp[k] = (policy_cls, v[1], v[2], v[3])
policy_spec = tmp
# Otherwise, policy spec is simply the policy class itself.
else:
policy_spec = policy_cls
if worker_index == 0:
extra_python_environs = config.get(
@ -282,7 +310,7 @@ class WorkerSet:
worker = cls(
env_creator=env_creator,
validate_env=validate_env,
policy=policy,
policy_spec=policy_spec,
policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
policies_to_train=config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
@ -302,7 +330,7 @@ class WorkerSet:
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
num_workers=config["num_workers"],
num_workers=num_workers,
monitor_path=self._logdir if config["monitor"] else None,
log_dir=self._logdir,
log_level=config["log_level"],
@ -317,6 +345,8 @@ class WorkerSet:
seed=(config["seed"] + worker_index)
if config["seed"] is not None else None,
fake_sampler=config["fake_sampler"],
extra_python_environs=extra_python_environs)
extra_python_environs=extra_python_environs,
spaces=spaces,
)
return worker
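The spaces handshake above is what frees the local worker from creating an env: the first remote worker reports each policy's spaces, which `_make_worker` then forwards as `spaces`. A simplified sketch of that round trip (the helper name is hypothetical):

import ray

# Assumes `remote_workers` is a non-empty list of RolloutWorker actor handles.
def fetch_policy_spaces(remote_workers):
    if not remote_workers:
        return None  # no remote workers -> the driver builds its own env
    # Ask the first remote worker for (policy_id, obs_space, action_space).
    remote_spaces = ray.get(remote_workers[0].foreach_policy.remote(
        lambda policy, pid: (pid, policy.observation_space,
                             policy.action_space)))
    return {pid: (obs, act) for pid, obs, act in remote_spaces}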

View file

@ -267,6 +267,9 @@ def run(args, parser):
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
# Make sure worker 0 has an Env.
config["create_env_on_driver"] = True
# Merge with `evaluation_config` (first try from command line, then from
# pkl file).
evaluation_config = copy.deepcopy(

View file

@ -28,12 +28,12 @@ def iter_list(values):
def make_workers(n):
local = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=PPOTFPolicy,
policy_spec=PPOTFPolicy,
rollout_fragment_length=100)
remotes = [
RolloutWorker.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=PPOTFPolicy,
policy_spec=PPOTFPolicy,
rollout_fragment_length=100) for _ in range(n)
]
workers = WorkerSet._from_existing(local, remotes)

View file

@ -126,7 +126,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_complete_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):
@ -136,7 +136,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_truncate_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="truncate_episodes")
for _ in range(3):
@ -146,7 +146,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_off_policy(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):
@ -158,7 +158,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_bad_actions(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=BadPolicy,
policy_spec=BadPolicy,
sample_async=True,
rollout_fragment_length=40,
batch_mode="truncate_episodes")
@ -226,7 +226,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_horizon_not_supported(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy=MockPolicy,
policy_spec=MockPolicy,
episode_horizon=20,
rollout_fragment_length=10,
batch_mode="complete_episodes")

View file

@ -25,7 +25,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):
@ -37,7 +37,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="truncate_episodes")
for _ in range(3):
@ -51,7 +51,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

View file

@ -172,7 +172,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
@ -192,7 +192,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
@ -211,7 +211,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
@ -227,7 +227,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
@ -242,7 +242,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: EarlyDoneMultiAgent(),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
@ -258,7 +258,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(10)
ev = RolloutWorker(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy={
policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
@ -309,7 +309,7 @@ class TestMultiAgentEnv(unittest.TestCase):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=StatefulPolicy,
policy_spec=StatefulPolicy,
rollout_fragment_length=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)
@ -354,7 +354,7 @@ class TestMultiAgentEnv(unittest.TestCase):
act_space = single_env.action_space
ev = RolloutWorker(
env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
policy={
policy_spec={
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
},

View file

@ -23,7 +23,7 @@ class TestPerf(unittest.TestCase):
for _ in range(20):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=100)
start = time.time()
count = 0

View file

@ -6,7 +6,7 @@ from ray.tune.registry import register_env
from ray.rllib.env import PettingZooEnv
from ray.rllib.agents.registry import get_agent_class
from pettingzoo.mpe import simple_spread_v0
from pettingzoo.mpe import simple_spread_v1
class TestPettingZooEnv(unittest.TestCase):
@ -17,13 +17,13 @@ class TestPettingZooEnv(unittest.TestCase):
ray.shutdown()
def test_pettingzoo_env(self):
register_env("prison", lambda _: PettingZooEnv(simple_spread_v0.env()))
register_env("prison", lambda _: PettingZooEnv(simple_spread_v1.env()))
agent_class = get_agent_class("PPO")
config = deepcopy(agent_class._default_config)
test_env = PettingZooEnv(simple_spread_v0.env())
test_env = PettingZooEnv(simple_spread_v1.env())
obs_space = test_env.observation_space
act_space = test_env.action_space
test_env.close()

View file

@ -147,7 +147,8 @@ class TestRolloutWorker(unittest.TestCase):
def test_basic(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
env_creator=lambda _: gym.make("CartPole-v0"),
policy_spec=MockPolicy)
batch = ev.sample()
for key in [
"obs", "actions", "rewards", "dones", "advantages",
@ -173,7 +174,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_batch_ids(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=1)
batch1 = ev.sample()
batch2 = ev.sample()
@ -264,6 +265,7 @@ class TestRolloutWorker(unittest.TestCase):
"rollout_fragment_length": 5,
"num_envs_per_worker": 2,
"framework": fw,
"create_env_on_driver": True,
})
results = pg.workers.foreach_worker(
lambda ev: ev.rollout_fragment_length)
@ -288,7 +290,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
policy=RandomPolicy,
policy_spec=RandomPolicy,
policy_config=dict(
action_space=action_space,
ignore_action_bounds=True,
@ -312,7 +314,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
policy=RandomPolicy,
policy_spec=RandomPolicy,
policy_config=dict(
action_space=action_space,
ignore_action_bounds=True,
@ -331,7 +333,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
policy=RandomPolicy,
policy_spec=RandomPolicy,
policy_config=dict(action_space=action_space),
# Should not be a problem as RandomPolicy abides to bounds.
clip_actions=False,
@ -345,7 +347,7 @@ class TestRolloutWorker(unittest.TestCase):
# Clipping: True (clip between -1.0 and 1.0).
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
clip_rewards=True,
batch_mode="complete_episodes")
self.assertEqual(max(ev.sample()["rewards"]), 1)
@ -363,7 +365,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
max_episode_len=10,
)),
policy=MockPolicy,
policy_spec=MockPolicy,
clip_rewards=2.0,
batch_mode="complete_episodes")
sample = ev2.sample()
@ -376,7 +378,7 @@ class TestRolloutWorker(unittest.TestCase):
# Clipping: Off.
ev2 = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
clip_rewards=False,
batch_mode="complete_episodes")
self.assertEqual(max(ev2.sample()["rewards"]), 100)
@ -387,7 +389,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_hard_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,
@ -406,7 +408,7 @@ class TestRolloutWorker(unittest.TestCase):
# A gym env's max_episode_steps is smaller than Trainer's horizon.
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=6,
@ -427,7 +429,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_soft_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,
@ -442,11 +444,11 @@ class TestRolloutWorker(unittest.TestCase):
def test_metrics(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="complete_episodes")
remote_ev = RolloutWorker.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="complete_episodes")
ev.sample()
ray.get(remote_ev.sample.remote())
@ -459,7 +461,7 @@ class TestRolloutWorker(unittest.TestCase):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy=MockPolicy)
policy_spec=MockPolicy)
batch = ev.sample()
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
self.assertIn(key, batch)
@ -469,7 +471,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_auto_vectorization(self):
ev = RolloutWorker(
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=2,
num_envs=8)
@ -493,7 +495,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_batches_larger_when_vectorized(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=8),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=4,
num_envs=4)
@ -509,7 +511,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_vector_env_support(self):
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
policy=MockPolicy,
policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=10)
for _ in range(8):
@ -527,7 +529,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_truncate_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=15,
batch_mode="truncate_episodes")
batch = ev.sample()
@ -537,7 +539,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_complete_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=5,
batch_mode="complete_episodes")
batch = ev.sample()
@ -547,7 +549,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_complete_episodes_packing(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
policy_spec=MockPolicy,
rollout_fragment_length=15,
batch_mode="complete_episodes")
batch = ev.sample()
@ -560,7 +562,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_filter_sync(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
time.sleep(2)
@ -574,7 +576,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_get_filters(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
self.sample_and_flush(ev)
@ -590,7 +592,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_sync_filter(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
policy=MockPolicy,
policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
obs_f = self.sample_and_flush(ev)
@ -616,7 +618,7 @@ class TestRolloutWorker(unittest.TestCase):
self.assertFalse("env_key_2" in os.environ)
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
policy=MockPolicy,
policy_spec=MockPolicy,
extra_python_environs=extra_envs)
self.assertTrue("env_key_1" in os.environ)
self.assertTrue("env_key_2" in os.environ)
@ -629,7 +631,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_no_env_seed(self):
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
policy=MockPolicy,
policy_spec=MockPolicy,
seed=1)
assert not hasattr(ev.env, "seed")
ev.stop()

View file

@ -7,6 +7,7 @@ import unittest
import ray
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.test_utils import check_learning_achieved, \
framework_iterator
@ -14,6 +15,23 @@ from ray.rllib.utils.numpy import one_hot
from ray.tune import register_env
class MyCallBack(DefaultCallbacks):
def __init__(self):
super().__init__()
self.deltas = []
def on_postprocess_trajectory(self, *, worker, episode, agent_id,
policy_id, policies, postprocessed_batch,
original_batches, **kwargs):
pos = np.argmax(postprocessed_batch["obs"], -1)
x, y = pos % 10, pos // 10
self.deltas.extend((x**2 + y**2)**0.5)
def on_sample_end(self, *, worker, samples, **kwargs):
print("mean. distance from origin={}".format(np.mean(self.deltas)))
self.deltas = []
class OneHotWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)
@ -114,38 +132,35 @@ class TestCuriosity(unittest.TestCase):
config["env"] = "FrozenLake-v0"
config["env_config"] = {
"desc": [
"SFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFF",
"FFFFFFFFFFFFFFFG",
"SFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFF",
"FFFFFFFFFG",
],
"is_slippery": False
}
# Print out observations to see how far we already get inside the Env.
config["callbacks"] = MyCallBack
# Limit the horizon to make it really hard for a non-curious agent to
# reach the goal state.
config["horizon"] = 40
# config["train_batch_size"] = 2048
# config["num_sgd_iter"] = 15
config["num_workers"] = 0 # local only
config["horizon"] = 23
# Local only.
config["num_workers"] = 0
config["lr"] = 0.001
num_iterations = 40
num_iterations = 10
for _ in framework_iterator(config, frameworks="torch"):
# W/ Curiosity. Expect to learn something.
config["exploration_config"] = {
"type": "Curiosity",
"lr": 0.0003,
"eta": 0.2,
"lr": 0.001,
"feature_dim": 128,
"feature_net_config": {
"fcnet_hiddens": [],
@ -157,29 +172,28 @@ class TestCuriosity(unittest.TestCase):
}
trainer = ppo.PPOTrainer(config=config)
learnt = False
for _ in range(num_iterations):
for i in range(num_iterations):
result = trainer.train()
print(result)
if result["episode_reward_mean"] >= 0.001:
print("Learnt something!")
if result["episode_reward_max"] > 0.0:
print("Reached goal after {} iters!".format(i))
learnt = True
break
trainer.stop()
self.assertTrue(learnt)
# # W/o Curiosity. Expect to learn nothing.
# config["exploration_config"] = {
# "type": "StochasticSampling",
# }
# trainer = ppo.PPOTrainer(config=config)
# rewards_wo = 0.0
# for _ in range(num_iterations):
# result = trainer.train()
# rewards_wo += result["episode_reward_mean"]
# print(result)
# trainer.stop()
# self.assertTrue(rewards_wo == 0.0)
# W/o Curiosity. Expect to learn nothing.
config["exploration_config"] = {
"type": "StochasticSampling",
}
trainer = ppo.PPOTrainer(config=config)
rewards_wo = 0.0
for _ in range(num_iterations):
result = trainer.train()
rewards_wo += result["episode_reward_mean"]
print(result)
trainer.stop()
self.assertTrue(rewards_wo == 0.0)
def test_curiosity_on_partially_observable_domain(self):
config = ppo.DEFAULT_CONFIG.copy()

View file

@ -298,7 +298,8 @@ def check_compute_single_action(trainer,
except AttributeError:
pass
else:
obs_space = worker_set.local_worker().env.observation_space
obs_space = worker_set.local_worker().for_policy(
lambda p: p.observation_space)
else:
method_to_test = pol.compute_single_action
obs_space = pol.observation_space

View file

@ -34,8 +34,8 @@ AgentID = Any
PolicyID = str
# Type of the config["multiagent"]["policies"] dict for multi-agent training.
MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[type, gym.Space, gym.Space,
PartialTrainerConfigDict]]
MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[
type, None], gym.Space, gym.Space, PartialTrainerConfigDict]]
# Represents an environment id.
EnvID = int
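With the widened type, the policy class slot may be None and is filled in later with the Trainer's default policy class; an illustrative dict (names and spaces are placeholders):

import gym

# Illustrative MultiAgentPolicyConfigDict: a None policy class is allowed
# and will be replaced by the Trainer's default policy class later on.
policies = {
    "car_policy": (None, gym.spaces.Box(-1.0, 1.0, (4,)),
                   gym.spaces.Discrete(2), {"gamma": 0.99}),
    "light_policy": (None, gym.spaces.Discrete(5),
                     gym.spaces.Discrete(2), {}),
}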