Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)
[RLlib] Do not create env on driver iff num_workers > 0. (#11307)
This commit is contained in:
parent 60a4be4a59
commit 414041c6dd
17 changed files with 308 additions and 189 deletions
@@ -28,6 +28,8 @@ DEFAULT_CONFIG = with_common_config({
"kl_coeff": 0.0005,
# Size of batches collected from each worker
"rollout_fragment_length": 200,
+ # Do create an actual env on the local worker (worker-idx=0).
+ "create_env_on_driver": True,
# Stepsize of SGD
"lr": 1e-3,
# Share layers for value function

@@ -209,15 +211,18 @@ def get_policy_class(config):
def validate_config(config):
if config["inner_adaptation_steps"] <= 0:
- raise ValueError("Inner Adaptation Steps must be >=1.")
+ raise ValueError("Inner Adaptation Steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
- raise ValueError("PPO steps for meta-update needs to be >=0")
+ raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
- raise ValueError("entropy_coeff must be >=0")
+ raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
- raise ValueError("truncate_episodes not supported")
+ raise ValueError("`batch_mode`=truncate_episodes not supported!")
if config["num_workers"] <= 0:
- raise ValueError("Must have at least 1 worker/task.")
+ raise ValueError("Must have at least 1 worker/task!")
+ if config["create_env_on_driver"] is False:
+ raise ValueError("Must have an actual Env created on the driver "
+ "(local) worker! Set `create_env_on_driver` to True.")

MAMLTrainer = build_trainer(
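For orientation, a minimal sketch (not part of the diff) of config overrides that would satisfy the stricter MAML checks above; all values are illustrative placeholders, not the trainer's defaults:

# Illustrative config fragment only; values are placeholders.
maml_config_overrides = {
    "inner_adaptation_steps": 1,          # must be >= 1
    "maml_optimizer_steps": 5,            # must be > 0
    "entropy_coeff": 0.0,                 # must be >= 0.0
    "batch_mode": "complete_episodes",    # truncate_episodes is rejected
    "num_workers": 1,                     # at least one rollout worker/task
    "create_env_on_driver": True,         # MAML requires an env on the local worker
}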
@@ -51,6 +51,8 @@ DEFAULT_CONFIG = with_common_config({
"kl_coeff": 0.0005,
# Size of batches collected from each worker.
"rollout_fragment_length": 200,
+ # Do create an actual env on the local worker (worker-idx=0).
+ "create_env_on_driver": True,
# Step size of SGD.
"lr": 1e-3,
# Share layers for value function.

@@ -414,15 +416,18 @@ def validate_config(config):
"`framework=torch`.")
config["framework"] = "torch"
if config["inner_adaptation_steps"] <= 0:
- raise ValueError("Inner Adaptation Steps must be >=1.")
+ raise ValueError("Inner adaptation steps must be >=1!")
if config["maml_optimizer_steps"] <= 0:
- raise ValueError("PPO steps for meta-update needs to be >=0")
+ raise ValueError("PPO steps for meta-update needs to be >=0!")
if config["entropy_coeff"] < 0:
- raise ValueError("`entropy_coeff` must be >=0.")
+ raise ValueError("`entropy_coeff` must be >=0.0!")
if config["batch_mode"] != "complete_episodes":
- raise ValueError("`batch_mode=truncate_episodes` not supported.")
+ raise ValueError("`batch_mode=truncate_episodes` not supported!")
if config["num_workers"] <= 0:
raise ValueError("Must have at least 1 worker/task.")
+ if config["create_env_on_driver"] is False:
+ raise ValueError("Must have an actual Env created on the driver "
+ "(local) worker! Set `create_env_on_driver` to True.")

def validate_env(env: EnvType, env_context: EnvContext):
@@ -53,6 +53,11 @@ COMMON_CONFIG: TrainerConfigDict = {
# model inference batching, which can improve performance for inference
# bottlenecked workloads.
"num_envs_per_worker": 1,
+ # When `num_workers` > 0, the driver (local_worker; worker-idx=0) does not
+ # need an environment. This is because it doesn't have to sample (done by
+ # remote_workers; worker_indices > 0) nor evaluate (done by evaluation
+ # workers; see below).
+ "create_env_on_driver": False,
# Divide episodes into fragments of this many steps each during rollouts.
# Sample batches of this size are collected from rollout workers and
# combined into a larger batch of `train_batch_size` for learning.
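A hedged sketch (not from the diff) of how the new flag interacts with `num_workers`; `PPOTrainer` and `CartPole-v0` are only example names and the assert reflects the behavior introduced further down in this commit:

# Illustrative sketch only: with num_workers > 0 the local worker (worker_index=0)
# no longer builds an env unless the flag is flipped back to True.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={
    "env": "CartPole-v0",            # placeholder env
    "num_workers": 2,                # sampling happens on the remote workers
    "create_env_on_driver": False,   # default: the driver gets no env ...
})
# ... so the local worker's env is expected to be None per this change:
assert trainer.workers.local_worker().env is None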
@@ -308,16 +313,17 @@ COMMON_CONFIG: TrainerConfigDict = {

# === Offline Datasets ===
# Specify how to generate experiences:
- # - "sampler": generate experiences via online simulation (default)
- # - a local directory or file glob expression (e.g., "/tmp/*.json")
- # - a list of individual file paths/URIs (e.g., ["/tmp/1.json",
- # "s3://bucket/2.json"])
- # - a dict with string keys and sampling probabilities as values (e.g.,
+ # - "sampler": Generate experiences via online (env) simulation (default).
+ # - A local directory or file glob expression (e.g., "/tmp/*.json").
+ # - A list of individual file paths/URIs (e.g., ["/tmp/1.json",
+ # "s3://bucket/2.json"]).
+ # - A dict with string keys and sampling probabilities as values (e.g.,
# {"sampler": 0.4, "/tmp/*.json": 0.4, "s3://bucket/expert.json": 0.2}).
- # - a function that returns a rllib.offline.InputReader
+ # - A callable that returns a ray.rllib.offline.InputReader.
"input": "sampler",
# Specify how to evaluate the current policy. This only has an effect when
- # reading offline experiences. Available options:
+ # reading offline experiences ("input" is not "sampler").
+ # Available options:
# - "wis": the weighted step-wise importance sampling estimator.
# - "is": the step-wise importance sampling estimator.
# - "simulation": run the environment in the background, but use
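A small, hedged illustration of the offline-data settings described in the comments above; paths and mixture weights are placeholders taken from the comment's own example:

# Illustrative values only; paths are placeholders.
offline_config = {
    # Mix online sampling with two offline sources, as in the comment above:
    "input": {
        "sampler": 0.4,
        "/tmp/*.json": 0.4,
        "s3://bucket/expert.json": 0.2,
    },
    # Evaluate the policy on the offline data via importance sampling:
    "input_evaluation": ["is", "wis"],
}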
@@ -557,12 +563,12 @@ class Trainer(Trainable):
# A class specifier.
elif "." in env:
self.env_creator = \
- lambda env_config: from_config(env, env_config)
+ lambda env_context: from_config(env, env_context)
# Try gym.
else:
import gym # soft dependency
self.env_creator = \
- lambda env_config: gym.make(env, **env_config)
+ lambda env_context: gym.make(env, **env_context)
else:
self.env_creator = lambda env_config: None
@@ -32,6 +32,7 @@ from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
+ from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.filter import get_filter, Filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.sgd import do_minibatch_sgd

@@ -84,7 +85,7 @@ class RolloutWorker(ParallelIteratorWorker):
>>> # Create a rollout worker and using it to collect experiences.
>>> worker = RolloutWorker(
... env_creator=lambda _: gym.make("CartPole-v0"),
- ... policy=PGTFPolicy)
+ ... policy_spec=PGTFPolicy)
>>> print(worker.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],

@@ -93,7 +94,7 @@ class RolloutWorker(ParallelIteratorWorker):
>>> # Creating a multi-agent rollout worker
>>> worker = RolloutWorker(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
- ... policies={
+ ... policy_spec={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),

@@ -135,9 +136,10 @@ class RolloutWorker(ParallelIteratorWorker):
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType, EnvContext],
None]] = None,
- policy: Union[type, Dict[str, Tuple[Optional[
- type], gym.Space, gym.Space, PartialTrainerConfigDict]]],
- policy_mapping_fn: Callable[[AgentID], PolicyID] = None,
+ policy_spec: Union[type, Dict[
+ str, Tuple[Optional[type], gym.Space, gym.Space,
+ PartialTrainerConfigDict]]] = None,
+ policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None,
policies_to_train: Optional[List[PolicyID]] = None,
tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None,
rollout_fragment_length: int = 100,
@@ -172,7 +174,13 @@ class RolloutWorker(ParallelIteratorWorker):
no_done_at_end: bool = False,
seed: int = None,
extra_python_environs: dict = None,
- fake_sampler: bool = False):
+ fake_sampler: bool = False,
+ spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
+ gym.spaces.Space]]] = None,
+ policy: Union[type, Dict[
+ str, Tuple[Optional[type], gym.Space, gym.Space,
+ PartialTrainerConfigDict]]] = None,
+ ):
"""Initialize a rollout worker.

Args:

@@ -181,16 +189,19 @@ class RolloutWorker(ParallelIteratorWorker):
validate_env (Optional[Callable[[EnvType, EnvContext], None]]):
Optional callable to validate the generated environment (only
on worker=0).
- policy (Union[type, Dict[str, Tuple[Optional[type], gym.Space,
+ policy_spec (Union[type, Dict[str, Tuple[Type[Policy], gym.Space,
gym.Space, PartialTrainerConfigDict]]]): Either a Policy class
or a dict of policy id strings to
- (Policy (None for default), obs_space, action_space,
- config)-tuples. If a dict is specified, then we are in
- multi-agent mode and a policy_mapping_fn should also be set.
- policy_mapping_fn (Callable[[AgentID], PolicyID]): A function that
- maps agent ids to policy ids in multi-agent mode. This function
- will be called each time a new agent appears in an episode, to
- bind that agent to a policy for the duration of the episode.
+ (Policy class, obs_space, action_space, config)-tuples. If a
+ dict is specified, then we are in multi-agent mode and a
+ policy_mapping_fn can also be set (if not, will map all agents
+ to DEFAULT_POLICY_ID).
+ policy_mapping_fn (Optional[Callable[[AgentID], PolicyID]]): A
+ callable that maps agent ids to policy ids in multi-agent mode.
+ This function will be called each time a new agent appears in
+ an episode, to bind that agent to a policy for the duration of
+ the episode. If not provided, will map all agents to
+ DEFAULT_POLICY_ID.
policies_to_train (Optional[List[PolicyID]]): Optional list of
policies to train, or None for all policies.
tf_session_creator (Optional[Callable[[], tf1.Session]]): A

@@ -236,7 +247,7 @@ class RolloutWorker(ParallelIteratorWorker):
policy model.
policy_config (TrainerConfigDict): Config to pass to the policy.
In the multi-agent case, this config will be merged with the
- per-policy configs specified by `policy`.
+ per-policy configs specified by `policy_spec`.
worker_index (int): For remote workers, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
@@ -277,7 +288,19 @@ class RolloutWorker(ParallelIteratorWorker):
extra_python_environs (dict): Extra python environments need to
be set.
fake_sampler (bool): Use a fake (inf speed) sampler for testing.
+ spaces (Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
+ gym.spaces.Space]]]): An optional space dict mapping policy IDs
+ to (obs_space, action_space)-tuples. This is used in case no
+ Env is created on this RolloutWorker.
+ policy: Obsoleted arg. Use `policy_spec` instead.
"""
+ # Deprecated arg.
+ if policy is not None:
+ deprecation_warning("policy", "policy_spec", error=False)
+ policy_spec = policy
+ assert policy_spec is not None, "Must provide `policy_spec` when " \
+ "creating RolloutWorker!"

self._original_kwargs: dict = locals().copy()
del self._original_kwargs["self"]
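A hedged migration sketch mirroring the deprecation shim above; the env and policy names are the placeholder ones already used in the class docstring, and the import paths are assumptions as of this commit:

# Sketch only; both calls construct the same worker, the first now warns.
import gym
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy  # assumed import path

# Old call: still accepted, but emits a deprecation warning via the shim above.
worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy=PGTFPolicy)

# New call:
worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"),
    policy_spec=PGTFPolicy)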
@@ -334,6 +357,12 @@ class RolloutWorker(ParallelIteratorWorker):
self.global_vars: dict = None
self.fake_sampler: bool = fake_sampler

+ # No Env will be used in this particular worker (not needed).
+ if worker_index == 0 and num_workers > 0 and \
+ policy_config["create_env_on_driver"] is False:
+ self.env = None
+ # Create an env for this worker.
+ else:
self.env = _validate_env(env_creator(env_context))
if validate_env is not None:
validate_env(self.env, self.env_context)
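To make the truth table of the condition above explicit, a small self-contained sketch (the helper function is hypothetical, written only for illustration):

# Hypothetical helper mirroring the condition introduced above.
def driver_gets_env(worker_index, num_workers, create_env_on_driver):
    # Only worker 0, with remote workers present and the flag left at False,
    # skips env creation.
    return not (worker_index == 0 and num_workers > 0
                and create_env_on_driver is False)

assert driver_gets_env(0, 2, False) is False  # driver, remote workers exist
assert driver_gets_env(0, 0, False) is True   # num_workers == 0: driver must sample
assert driver_gets_env(1, 2, False) is True   # remote workers always get an env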
@@ -385,7 +414,8 @@ class RolloutWorker(ParallelIteratorWorker):
self.make_env_fn = make_env

self.tf_sess = None
- policy_dict = _validate_and_canonicalize(policy, self.env)
+ policy_dict = _validate_and_canonicalize(
+ policy_spec, self.env, spaces=spaces)
self.policies_to_train: List[PolicyID] = policies_to_train or list(
policy_dict.keys())
self.policy_map: Dict[PolicyID, Policy] = None

@@ -446,7 +476,7 @@ class RolloutWorker(ParallelIteratorWorker):

self.multiagent: bool = set(
self.policy_map.keys()) != {DEFAULT_POLICY_ID}
- if self.multiagent:
+ if self.multiagent and self.env is not None:
if not ((isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):

@@ -466,7 +496,9 @@ class RolloutWorker(ParallelIteratorWorker):

self.num_envs: int = num_envs

- if "custom_vector_env" in policy_config:
+ if self.env is None:
+ self.async_env = None
+ elif "custom_vector_env" in policy_config:
custom_vec_wrapper = policy_config["custom_vector_env"]
self.async_env = custom_vec_wrapper(self.env)
else:

@@ -494,7 +526,7 @@ class RolloutWorker(ParallelIteratorWorker):

self.io_context: IOContext = IOContext(log_dir, policy_config,
worker_index, self)
- self.reward_estimators: OffPolicyEstimator = []
+ self.reward_estimators: List[OffPolicyEstimator] = []
for method in input_evaluation:
if method == "simulation":
logger.warning(

@@ -512,7 +544,9 @@ class RolloutWorker(ParallelIteratorWorker):
raise ValueError(
"Unknown evaluation method: {}".format(method))

- if sample_async:
+ if self.env is None:
+ self.sampler = None
+ elif sample_async:
self.sampler = AsyncSampler(
worker=self,
env=self.async_env,

@@ -816,7 +850,12 @@ class RolloutWorker(ParallelIteratorWorker):
def get_metrics(self) -> List[Union[RolloutMetrics, OffPolicyEstimate]]:
"""Returns a list of new RolloutMetric objects from evaluation."""

+ # Get metrics from sampler (if any).
+ if self.sampler is not None:
out = self.sampler.get_metrics()
+ else:
+ out = []
+ # Get metrics from our reward-estimators (if any).
for m in self.reward_estimators:
out.extend(m.get_metrics())
return out

@@ -825,6 +864,9 @@ class RolloutWorker(ParallelIteratorWorker):
def foreach_env(self, func: Callable[[BaseEnv], T]) -> List[T]:
"""Apply the given function to each underlying env instance."""

+ if self.async_env is None:
+ return []

envs = self.async_env.get_unwrapped()
if not envs:
return [func(self.async_env)]

@@ -957,6 +999,7 @@ class RolloutWorker(ParallelIteratorWorker):

@DeveloperAPI
def stop(self) -> None:
+ if self.env:
self.async_env.stop()

@DeveloperAPI

@@ -1054,8 +1097,12 @@ class RolloutWorker(ParallelIteratorWorker):
self.sampler.shutdown = True


- def _validate_and_canonicalize(policy: Policy,
- env: EnvType) -> MultiAgentPolicyConfigDict:
+ def _validate_and_canonicalize(
+ policy: Union[Type[Policy], MultiAgentPolicyConfigDict],
+ env: Optional[EnvType] = None,
+ spaces: Optional[Dict[PolicyID, Tuple[
+ gym.spaces.Space, gym.spaces.Space]]] = None) -> \
+ MultiAgentPolicyConfigDict:
if isinstance(policy, dict):
_validate_multiagent_config(policy)
return policy

@@ -1067,14 +1114,20 @@ def _validate_and_canonicalize(policy: Policy,
raise ValueError(
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
+ if env is not None:
return {
DEFAULT_POLICY_ID: (policy, env.observation_space,
env.action_space, {})
}
+ else:
+ return {
+ DEFAULT_POLICY_ID: (policy, spaces[DEFAULT_POLICY_ID][0],
+ spaces[DEFAULT_POLICY_ID][1], {})
+ }


def _validate_multiagent_config(policy: MultiAgentPolicyConfigDict,
- allow_none_graph: bool = False):
+ allow_none_graph: bool = False) -> None:
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy keys must be strs, got {}".format(
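For reference, a hedged sketch of the single-policy case handled above when no env exists on the worker; the spaces are placeholders and the policy class name is hypothetical:

# Illustrative only: what the spaces dict and the canonical output look like.
import gym
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

obs_space = gym.spaces.Box(-1.0, 1.0, (4,))
act_space = gym.spaces.Discrete(2)

# Spaces handed down from WorkerSet when this worker has no env:
spaces = {DEFAULT_POLICY_ID: (obs_space, act_space)}
# For a single (hypothetical) policy class MyPolicy, the canonical dict becomes:
# {DEFAULT_POLICY_ID: (MyPolicy, obs_space, act_space, {})}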
@@ -33,7 +33,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
assert len(view_req_model) == 1, view_req_model
- assert len(view_req_policy) == 11, view_req_policy
+ assert len(view_req_policy) == 12, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,

@@ -64,8 +64,8 @@ class TestTrajectoryViewAPI(unittest.TestCase):
policy = trainer.get_policy()
view_req_model = policy.model.inference_view_requirements
view_req_policy = policy.view_requirements
- assert len(view_req_model) == 7, view_req_model
- assert len(view_req_policy) == 17, view_req_policy
+ assert len(view_req_model) == 5, view_req_model
+ assert len(view_req_policy) == 18, view_req_policy
for key in [
SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
SampleBatch.DONES, SampleBatch.NEXT_OBS,

@@ -220,7 +220,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": True}),
rollout_fragment_length=rollout_fragment_length,
- policy=policies,
+ policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)

@@ -228,7 +228,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
policy_config=dict(config, **{"_use_trajectory_view_api": False}),
rollout_fragment_length=rollout_fragment_length,
- policy=policies,
+ policy_spec=policies,
policy_mapping_fn=policy_fn,
num_envs=1,
)
@@ -1,6 +1,7 @@
+ import gym
import logging
from types import FunctionType
- from typing import Callable, List, Optional, Type, TypeVar, Union
+ from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.rllib.utils.annotations import DeveloperAPI

@@ -72,14 +73,27 @@ class WorkerSet:
self._remote_workers = []
self.add_workers(num_workers)

+ # If num_workers > 0, get the action_spaces and observation_spaces
+ # to not be forced to create an Env on the driver.
+ if self._remote_workers:
+ remote_spaces = ray.get(self.remote_workers(
+ )[0].foreach_policy.remote(
+ lambda p, pid: (pid, p.observation_space, p.action_space)))
+ spaces = {e[0]: (e[1], e[2]) for e in remote_spaces}
+ else:
+ spaces = None

+ # Always create a local worker.
self._local_worker = self._make_worker(
cls=RolloutWorker,
env_creator=env_creator,
validate_env=validate_env,
- policy=self._policy_class,
+ policy_cls=self._policy_class,
worker_index=0,
- config=self._local_config)
+ num_workers=num_workers,
+ config=self._local_config,
+ spaces=spaces,
+ )

def local_worker(self) -> RolloutWorker:
"""Return the local rollout worker."""
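A hedged illustration (not part of the diff) of the data shapes involved in the gather-and-fold step above; the spaces and policy ID are placeholders standing in for whatever the remote policies report:

# Illustrative shapes only.
import gym

obs_space = gym.spaces.Box(-1.0, 1.0, (4,))
act_space = gym.spaces.Discrete(2)

# foreach_policy above yields (policy_id, obs_space, action_space) triples:
remote_spaces = [("default_policy", obs_space, act_space)]

# Folded into the dict that is passed to the local RolloutWorker via `spaces=`:
spaces = {e[0]: (e[1], e[2]) for e in remote_spaces}
assert spaces["default_policy"] == (obs_space, act_space)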
@@ -118,8 +132,9 @@ class WorkerSet:
cls=cls,
env_creator=self._env_creator,
validate_env=None,
- policy=self._policy_class,
+ policy_cls=self._policy_class,
worker_index=i + 1,
+ num_workers=num_workers,
config=self._remote_config) for i in range(num_workers)
])

@@ -217,11 +232,18 @@ class WorkerSet:
return workers

def _make_worker(
- self, *, cls: Callable,
+ self,
+ *,
+ cls: Callable,
env_creator: Callable[[EnvContext], EnvType],
validate_env: Optional[Callable[[EnvType], None]],
- policy: Type[Policy], worker_index: int,
- config: TrainerConfigDict) -> Union[RolloutWorker, "ActorHandle"]:
+ policy_cls: Type[Policy],
+ worker_index: int,
+ num_workers: int,
+ config: TrainerConfigDict,
+ spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
+ gym.spaces.Space]]] = None,
+ ) -> Union[RolloutWorker, "ActorHandle"]:
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))

@@ -263,14 +285,20 @@ class WorkerSet:
else:
input_evaluation = config["input_evaluation"]

- # Fill in the default policy if 'None' is specified in multiagent.
+ # Fill in the default policy_cls if 'None' is specified in multiagent.
if config["multiagent"]["policies"]:
tmp = config["multiagent"]["policies"]
_validate_multiagent_config(tmp, allow_none_graph=True)
+ # TODO: (sven) Allow for setting observation and action spaces to
+ # None as well, in which case, spaces are taken from env.
+ # It's tedious to have to provide these in a multi-agent config.
for k, v in tmp.items():
if v[0] is None:
- tmp[k] = (policy, v[1], v[2], v[3])
- policy = tmp
+ tmp[k] = (policy_cls, v[1], v[2], v[3])
+ policy_spec = tmp
+ # Otherwise, policy spec is simply the policy class itself.
+ else:
+ policy_spec = policy_cls

if worker_index == 0:
extra_python_environs = config.get(

@@ -282,7 +310,7 @@ class WorkerSet:
worker = cls(
env_creator=env_creator,
validate_env=validate_env,
- policy=policy,
+ policy_spec=policy_spec,
policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
policies_to_train=config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator

@@ -302,7 +330,7 @@ class WorkerSet:
model_config=config["model"],
policy_config=config,
worker_index=worker_index,
- num_workers=config["num_workers"],
+ num_workers=num_workers,
monitor_path=self._logdir if config["monitor"] else None,
log_dir=self._logdir,
log_level=config["log_level"],

@@ -317,6 +345,8 @@ class WorkerSet:
seed=(config["seed"] + worker_index)
if config["seed"] is not None else None,
fake_sampler=config["fake_sampler"],
- extra_python_environs=extra_python_environs)
+ extra_python_environs=extra_python_environs,
+ spaces=spaces,
+ )

return worker
@@ -267,6 +267,9 @@ def run(args, parser):
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])

+ # Make sure worker 0 has an Env.
+ config["create_env_on_driver"] = True

# Merge with `evaluation_config` (first try from command line, then from
# pkl file).
evaluation_config = copy.deepcopy(
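A brief, hedged sketch of the equivalent programmatic override when rolling out a restored checkpoint yourself (the keys are from this commit; the surrounding restore code is omitted):

# Illustrative config fragment only.
rollout_overrides = {
    "num_workers": 2,              # remote workers may still exist ...
    "create_env_on_driver": True,  # ... but rollout steps the driver's own env
}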
@@ -28,12 +28,12 @@ def iter_list(values):
def make_workers(n):
local = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=PPOTFPolicy,
+ policy_spec=PPOTFPolicy,
rollout_fragment_length=100)
remotes = [
RolloutWorker.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=PPOTFPolicy,
+ policy_spec=PPOTFPolicy,
rollout_fragment_length=100) for _ in range(n)
]
workers = WorkerSet._from_existing(local, remotes)

@@ -126,7 +126,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_complete_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):

@@ -136,7 +136,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_truncate_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="truncate_episodes")
for _ in range(3):

@@ -146,7 +146,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_off_policy(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):

@@ -158,7 +158,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_bad_actions(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
- policy=BadPolicy,
+ policy_spec=BadPolicy,
sample_async=True,
rollout_fragment_length=40,
batch_mode="truncate_episodes")

@@ -226,7 +226,7 @@ class TestExternalEnv(unittest.TestCase):
def test_external_env_horizon_not_supported(self):
ev = RolloutWorker(
env_creator=lambda _: SimpleServing(MockEnv(25)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
episode_horizon=20,
rollout_fragment_length=10,
batch_mode="complete_episodes")
@@ -25,7 +25,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="complete_episodes")
for _ in range(3):

@@ -37,7 +37,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=40,
batch_mode="truncate_episodes")
for _ in range(3):

@@ -51,7 +51,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -172,7 +172,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -192,7 +192,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -211,7 +211,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -227,7 +227,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: BasicMultiAgent(5),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -242,7 +242,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = RolloutWorker(
env_creator=lambda _: EarlyDoneMultiAgent(),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},

@@ -258,7 +258,7 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(10)
ev = RolloutWorker(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
- policy={
+ policy_spec={
"p0": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",

@@ -309,7 +309,7 @@ class TestMultiAgentEnv(unittest.TestCase):

ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=StatefulPolicy,
+ policy_spec=StatefulPolicy,
rollout_fragment_length=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)

@@ -354,7 +354,7 @@ class TestMultiAgentEnv(unittest.TestCase):
act_space = single_env.action_space
ev = RolloutWorker(
env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
- policy={
+ policy_spec={
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
},
@@ -23,7 +23,7 @@ class TestPerf(unittest.TestCase):
for _ in range(20):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=100)
start = time.time()
count = 0

@@ -6,7 +6,7 @@ from ray.tune.registry import register_env
from ray.rllib.env import PettingZooEnv
from ray.rllib.agents.registry import get_agent_class

- from pettingzoo.mpe import simple_spread_v0
+ from pettingzoo.mpe import simple_spread_v1


class TestPettingZooEnv(unittest.TestCase):

@@ -17,13 +17,13 @@ class TestPettingZooEnv(unittest.TestCase):
ray.shutdown()

def test_pettingzoo_env(self):
- register_env("prison", lambda _: PettingZooEnv(simple_spread_v0.env()))
+ register_env("prison", lambda _: PettingZooEnv(simple_spread_v1.env()))

agent_class = get_agent_class("PPO")

config = deepcopy(agent_class._default_config)

- test_env = PettingZooEnv(simple_spread_v0.env())
+ test_env = PettingZooEnv(simple_spread_v1.env())
obs_space = test_env.observation_space
act_space = test_env.action_space
test_env.close()
@@ -147,7 +147,8 @@ class TestRolloutWorker(unittest.TestCase):

def test_basic(self):
ev = RolloutWorker(
- env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
+ env_creator=lambda _: gym.make("CartPole-v0"),
+ policy_spec=MockPolicy)
batch = ev.sample()
for key in [
"obs", "actions", "rewards", "dones", "advantages",

@@ -173,7 +174,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_batch_ids(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=1)
batch1 = ev.sample()
batch2 = ev.sample()

@@ -264,6 +265,7 @@ class TestRolloutWorker(unittest.TestCase):
"rollout_fragment_length": 5,
"num_envs_per_worker": 2,
"framework": fw,
+ "create_env_on_driver": True,
})
results = pg.workers.foreach_worker(
lambda ev: ev.rollout_fragment_length)

@@ -288,7 +290,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
- policy=RandomPolicy,
+ policy_spec=RandomPolicy,
policy_config=dict(
action_space=action_space,
ignore_action_bounds=True,

@@ -312,7 +314,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
- policy=RandomPolicy,
+ policy_spec=RandomPolicy,
policy_config=dict(
action_space=action_space,
ignore_action_bounds=True,

@@ -331,7 +333,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
check_action_bounds=True,
)),
- policy=RandomPolicy,
+ policy_spec=RandomPolicy,
policy_config=dict(action_space=action_space),
# Should not be a problem as RandomPolicy abides to bounds.
clip_actions=False,

@@ -345,7 +347,7 @@ class TestRolloutWorker(unittest.TestCase):
# Clipping: True (clip between -1.0 and 1.0).
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
clip_rewards=True,
batch_mode="complete_episodes")
self.assertEqual(max(ev.sample()["rewards"]), 1)

@@ -363,7 +365,7 @@ class TestRolloutWorker(unittest.TestCase):
p_done=0.0,
max_episode_len=10,
)),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
clip_rewards=2.0,
batch_mode="complete_episodes")
sample = ev2.sample()

@@ -376,7 +378,7 @@ class TestRolloutWorker(unittest.TestCase):
# Clipping: Off.
ev2 = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
clip_rewards=False,
batch_mode="complete_episodes")
self.assertEqual(max(ev2.sample()["rewards"]), 100)

@@ -387,7 +389,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_hard_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv2(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,

@@ -406,7 +408,7 @@ class TestRolloutWorker(unittest.TestCase):
# A gym env's max_episode_steps is smaller than Trainer's horizon.
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=6,

@@ -427,7 +429,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_soft_horizon(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="complete_episodes",
rollout_fragment_length=10,
episode_horizon=4,

@@ -442,11 +444,11 @@ class TestRolloutWorker(unittest.TestCase):
def test_metrics(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="complete_episodes")
remote_ev = RolloutWorker.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="complete_episodes")
ev.sample()
ray.get(remote_ev.sample.remote())

@@ -459,7 +461,7 @@ class TestRolloutWorker(unittest.TestCase):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
- policy=MockPolicy)
+ policy_spec=MockPolicy)
batch = ev.sample()
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
self.assertIn(key, batch)

@@ -469,7 +471,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_auto_vectorization(self):
ev = RolloutWorker(
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=2,
num_envs=8)

@@ -493,7 +495,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_batches_larger_when_vectorized(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(episode_length=8),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=4,
num_envs=4)

@@ -509,7 +511,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_vector_env_support(self):
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
batch_mode="truncate_episodes",
rollout_fragment_length=10)
for _ in range(8):

@@ -527,7 +529,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_truncate_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=15,
batch_mode="truncate_episodes")
batch = ev.sample()

@@ -537,7 +539,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_complete_episodes(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=5,
batch_mode="complete_episodes")
batch = ev.sample()

@@ -547,7 +549,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_complete_episodes_packing(self):
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
rollout_fragment_length=15,
batch_mode="complete_episodes")
batch = ev.sample()

@@ -560,7 +562,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_filter_sync(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
time.sleep(2)

@@ -574,7 +576,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_get_filters(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
self.sample_and_flush(ev)

@@ -590,7 +592,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_sync_filter(self):
ev = RolloutWorker(
env_creator=lambda _: gym.make("CartPole-v0"),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
obs_f = self.sample_and_flush(ev)

@@ -616,7 +618,7 @@ class TestRolloutWorker(unittest.TestCase):
self.assertFalse("env_key_2" in os.environ)
ev = RolloutWorker(
env_creator=lambda _: MockEnv(10),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
extra_python_environs=extra_envs)
self.assertTrue("env_key_1" in os.environ)
self.assertTrue("env_key_2" in os.environ)

@@ -629,7 +631,7 @@ class TestRolloutWorker(unittest.TestCase):
def test_no_env_seed(self):
ev = RolloutWorker(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
- policy=MockPolicy,
+ policy_spec=MockPolicy,
seed=1)
assert not hasattr(ev.env, "seed")
ev.stop()
@@ -7,6 +7,7 @@ import unittest

import ray
from ray import tune
+ from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.utils.test_utils import check_learning_achieved, \
framework_iterator

@@ -14,6 +15,23 @@ from ray.rllib.utils.numpy import one_hot
from ray.tune import register_env


+ class MyCallBack(DefaultCallbacks):
+ def __init__(self):
+ super().__init__()
+ self.deltas = []
+
+ def on_postprocess_trajectory(self, *, worker, episode, agent_id,
+ policy_id, policies, postprocessed_batch,
+ original_batches, **kwargs):
+ pos = np.argmax(postprocessed_batch["obs"], -1)
+ x, y = pos % 10, pos // 10
+ self.deltas.extend((x**2 + y**2)**0.5)
+
+ def on_sample_end(self, *, worker, samples, **kwargs):
+ print("mean. distance from origin={}".format(np.mean(self.deltas)))
+ self.deltas = []


class OneHotWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)

@@ -114,38 +132,35 @@ class TestCuriosity(unittest.TestCase):
config["env"] = "FrozenLake-v0"
config["env_config"] = {
"desc": [
- "SFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFF",
- "FFFFFFFFFFFFFFFG",
+ "SFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFF",
+ "FFFFFFFFFG",
],
"is_slippery": False
}
+ # Print out observations to see how far we already get inside the Env.
+ config["callbacks"] = MyCallBack
# Limit horizon to make it really hard for non-curious agent to reach
# the goal state.
- config["horizon"] = 40
- # config["train_batch_size"] = 2048
- # config["num_sgd_iter"] = 15
- config["num_workers"] = 0 # local only
+ config["horizon"] = 23
+ # Local only.
+ config["num_workers"] = 0
+ config["lr"] = 0.001

- num_iterations = 40
+ num_iterations = 10
for _ in framework_iterator(config, frameworks="torch"):
# W/ Curiosity. Expect to learn something.
config["exploration_config"] = {
"type": "Curiosity",
- "lr": 0.0003,
+ "eta": 0.2,
+ "lr": 0.001,
"feature_dim": 128,
"feature_net_config": {
"fcnet_hiddens": [],

@@ -157,29 +172,28 @@ class TestCuriosity(unittest.TestCase):
}
trainer = ppo.PPOTrainer(config=config)
learnt = False
- for _ in range(num_iterations):
+ for i in range(num_iterations):
result = trainer.train()
print(result)
- if result["episode_reward_mean"] >= 0.001:
- print("Learnt something!")
+ if result["episode_reward_max"] > 0.0:
+ print("Reached goal after {} iters!".format(i))
learnt = True
break
trainer.stop()
self.assertTrue(learnt)

- # # W/o Curiosity. Expect to learn nothing.
- # config["exploration_config"] = {
- # "type": "StochasticSampling",
- # }
- # trainer = ppo.PPOTrainer(config=config)
- # rewards_wo = 0.0
- # for _ in range(num_iterations):
- # result = trainer.train()
- # rewards_wo += result["episode_reward_mean"]
- # print(result)
- # trainer.stop()
- # self.assertTrue(rewards_wo == 0.0)
+ # W/o Curiosity. Expect to learn nothing.
+ config["exploration_config"] = {
+ "type": "StochasticSampling",
+ }
+ trainer = ppo.PPOTrainer(config=config)
+ rewards_wo = 0.0
+ for _ in range(num_iterations):
+ result = trainer.train()
+ rewards_wo += result["episode_reward_mean"]
+ print(result)
+ trainer.stop()
+ self.assertTrue(rewards_wo == 0.0)

def test_curiosity_on_partially_observable_domain(self):
config = ppo.DEFAULT_CONFIG.copy()
@@ -298,7 +298,8 @@ def check_compute_single_action(trainer,
except AttributeError:
pass
else:
- obs_space = worker_set.local_worker().env.observation_space
+ obs_space = worker_set.local_worker().for_policy(
+ lambda p: p.observation_space)
else:
method_to_test = pol.compute_single_action
obs_space = pol.observation_space

@@ -34,8 +34,8 @@ AgentID = Any
PolicyID = str

# Type of the config["multiagent"]["policies"] dict for multi-agent training.
- MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[type, gym.Space, gym.Space,
- PartialTrainerConfigDict]]
+ MultiAgentPolicyConfigDict = Dict[PolicyID, Tuple[Union[
+ type, None], gym.Space, gym.Space, PartialTrainerConfigDict]]

# Represents an environment id.
EnvID = int
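For illustration, a hedged instance of the loosened MultiAgentPolicyConfigDict: a None policy class entry is now type-valid and gets filled in with the trainer's default policy class by WorkerSet._make_worker (see the hunk above). The spaces and policy IDs are placeholders:

# Illustrative only.
import gym

obs_space = gym.spaces.Discrete(2)
act_space = gym.spaces.Discrete(2)

multiagent_policies = {
    "p0": (None, obs_space, act_space, {}),             # None -> default policy class
    "p1": (None, obs_space, act_space, {"gamma": 0.95}),
}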