2020-10-15 18:21:30 +02:00
|
|
|
import gym
|
2019-06-03 06:49:24 +08:00
|
|
|
import logging
|
|
|
|
from types import FunctionType
|
2020-10-15 18:21:30 +02:00
|
|
|
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
2019-06-03 06:49:24 +08:00
|
|
|
|
2020-03-07 14:47:58 -08:00
|
|
|
import ray
|
2019-06-03 06:49:24 +08:00
|
|
|
from ray.rllib.utils.annotations import DeveloperAPI
|
|
|
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker, \
|
|
|
|
_validate_multiagent_config
|
|
|
|
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
|
|
|
ShuffledInput
|
2020-06-19 13:09:05 -07:00
|
|
|
from ray.rllib.env.env_context import EnvContext
|
|
|
|
from ray.rllib.policy import Policy
|
2020-06-11 14:29:57 +02:00
|
|
|
from ray.rllib.utils import merge_dicts
|
|
|
|
from ray.rllib.utils.framework import try_import_tf
|
2020-08-15 13:24:22 +02:00
|
|
|
from ray.rllib.utils.typing import PolicyID, TrainerConfigDict, EnvType
|
2019-06-03 06:49:24 +08:00
|
|
|
|
2020-06-30 10:13:20 +02:00
|
|
|
# try_import_tf() yields three values; `tf1` is the handle used further down
# to build `Session` / `ConfigProto` objects for the workers.
tf1, tf, tfv = try_import_tf()
|
2019-06-03 06:49:24 +08:00
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
2020-06-19 13:09:05 -07:00
|
|
|
# Generic type var for foreach_* methods.
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
2019-06-03 06:49:24 +08:00
|
|
|
|
|
|
|
@DeveloperAPI
class WorkerSet:
    """Represents a set of RolloutWorkers.

    There must be one local worker copy, and zero or more remote workers.
    """

    def __init__(self,
                 *,
                 env_creator: Optional[Callable[[EnvContext], EnvType]] = None,
                 validate_env: Optional[Callable[[EnvType], None]] = None,
                 policy_class: Optional[Type[Policy]] = None,
                 trainer_config: Optional[TrainerConfigDict] = None,
                 num_workers: int = 0,
                 logdir: Optional[str] = None,
                 _setup: bool = True):
        """Create a new WorkerSet and initialize its workers.

        Args:
            env_creator (Optional[Callable[[EnvContext], EnvType]]): Function
                that returns env given env config.
            validate_env (Optional[Callable[[EnvType], None]]): Optional
                callable to validate the generated environment (only passed
                to the local worker, i.e. worker_index=0).
            policy_class (Optional[Type[Policy]]): A rllib.policy.Policy
                class.
            trainer_config (Optional[TrainerConfigDict]): Optional dict that
                extends the common config of the Trainer class. If empty or
                None, COMMON_CONFIG is used.
            num_workers (int): Number of remote rollout workers to create.
            logdir (Optional[str]): Optional logging directory for workers.
            _setup (bool): Whether to setup workers. This is only for testing.
                If False, neither the local nor the remote workers are
                created (see `_from_existing`).
        """

        if not trainer_config:
            # Imported lazily, inside the function, as in the original code.
            from ray.rllib.agents.trainer import COMMON_CONFIG
            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._logdir = logdir

        if _setup:
            # The local worker uses its own TF session args.
            self._local_config = merge_dicts(
                trainer_config,
                {"tf_session_args": trainer_config["local_tf_session_args"]})

            # Create a number of remote workers.
            self._remote_workers = []
            self.add_workers(num_workers)

            # If num_workers > 0, get the action_spaces and observation_spaces
            # to not be forced to create an Env on the driver.
            if self._remote_workers:
                remote_spaces = ray.get(
                    self.remote_workers()[0].foreach_policy.remote(
                        lambda p, pid: (pid, p.observation_space,
                                        p.action_space)))
                spaces = {e[0]: (e[1], e[2]) for e in remote_spaces}
            else:
                spaces = None

            # Always create a local worker.
            self._local_worker = self._make_worker(
                cls=RolloutWorker,
                env_creator=env_creator,
                validate_env=validate_env,
                policy_cls=self._policy_class,
                worker_index=0,
                num_workers=num_workers,
                config=self._local_config,
                spaces=spaces,
            )

    def local_worker(self) -> RolloutWorker:
        """Return the local rollout worker."""
        return self._local_worker

    def remote_workers(self) -> List["ActorHandle"]:
        """Return a list of remote rollout workers.

        Note: This is the internal list itself, not a copy.
        """
        return self._remote_workers

    def sync_weights(self) -> None:
        """Syncs weights of remote workers with the local worker."""
        if self.remote_workers():
            # Put the weights into the object store once and hand the same
            # reference to every remote worker.
            weights = ray.put(self.local_worker().get_weights())
            for e in self.remote_workers():
                e.set_weights.remote(weights)

    def add_workers(self, num_workers: int) -> None:
        """Creates and adds a number of remote workers to this worker set.

        New workers are indexed after the remote workers this set already
        holds, so calling this method more than once does not hand out the
        same `worker_index` (and hence the same derived seed) twice.

        Args:
            num_workers (int): The number of remote Workers to add to this
                WorkerSet.
        """
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            # memory=0 is an error, but memory=None means no limits.
            "memory": self._remote_config["memory_per_worker"] or None,
            "object_store_memory": self.
            _remote_config["object_store_memory_per_worker"] or None,
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        # Continue numbering after the workers that already exist. On the
        # first call (from __init__) this offset is 0, i.e. indices 1..n.
        old_num_workers = len(self._remote_workers)
        self._remote_workers.extend([
            self._make_worker(
                cls=cls,
                env_creator=self._env_creator,
                validate_env=None,
                policy_cls=self._policy_class,
                worker_index=old_num_workers + i + 1,
                num_workers=old_num_workers + num_workers,
                config=self._remote_config) for i in range(num_workers)
        ])

    def reset(self, new_remote_workers: List["ActorHandle"]) -> None:
        """Called to change the set of remote workers."""
        self._remote_workers = new_remote_workers

    def stop(self) -> None:
        """Stop all rollout workers.

        Errors raised while stopping are logged, not propagated. The remote
        actors are asked to terminate in any case.
        """
        try:
            self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Apply the given function to each worker instance.

        Returns:
            List[T]: The local worker's result, followed by the remote
                workers' results in `remote_workers()` order.
        """

        local_result = [func(self.local_worker())]
        remote_results = ray.get(
            [w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_worker_with_index(
            self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
        """Apply the given function to each worker instance.

        The index will be passed as the second arg to the given function.
        The local worker gets index 0; remote workers get their position in
        `remote_workers()` plus one.
        """
        local_result = [func(self.local_worker(), 0)]
        remote_results = ray.get([
            w.apply.remote(func, i + 1)
            for i, w in enumerate(self.remote_workers())
        ])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply the given function to each worker's (policy, policy_id) tuple.

        Args:
            func (callable): A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            List[any]: The list of return values of func over all workers'
                policies (local worker first, then remote workers in order).
        """
        local_results = self.local_worker().foreach_policy(func)
        # Submit the call to every remote worker first and block only once,
        # rather than waiting for each worker in turn.
        per_worker_results = ray.get([
            worker.apply.remote(lambda w: w.foreach_policy(func))
            for worker in self.remote_workers()
        ])
        remote_results = []
        for res in per_worker_results:
            remote_results.extend(res)
        return local_results + remote_results

    @DeveloperAPI
    def trainable_policies(self) -> List[PolicyID]:
        """Return the list of trainable policy ids."""
        return self.local_worker().foreach_trainable_policy(lambda _, pid: pid)

    @DeveloperAPI
    def foreach_trainable_policy(
            self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func (callable): A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
            List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls (local worker first,
                then remote workers in order).
        """
        local_results = self.local_worker().foreach_trainable_policy(func)
        # Submit the call to every remote worker first and block only once,
        # rather than waiting for each worker in turn.
        per_worker_results = ray.get([
            worker.apply.remote(lambda w: w.foreach_trainable_policy(func))
            for worker in self.remote_workers()
        ])
        remote_results = []
        for res in per_worker_results:
            remote_results.extend(res)
        return local_results + remote_results

    @staticmethod
    def _from_existing(
            local_worker: RolloutWorker,
            remote_workers: Optional[List["ActorHandle"]] = None
    ) -> "WorkerSet":
        """Wrap already-created workers in a WorkerSet without any setup.

        Args:
            local_worker (RolloutWorker): The worker to use as local worker.
            remote_workers (Optional[List[ActorHandle]]): The remote workers;
                None is treated as an empty list.

        Returns:
            WorkerSet: A set holding exactly the given workers.
        """
        workers = WorkerSet(
            env_creator=None,
            policy_class=None,
            trainer_config={},
            _setup=False)
        workers._local_worker = local_worker
        workers._remote_workers = remote_workers or []
        return workers

    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: Callable[[EnvContext], EnvType],
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, "ActorHandle"]:
        """Build one worker by calling `cls` with settings taken from `config`.

        Args:
            cls (Callable): The constructor to call: `RolloutWorker` for the
                local worker, or the `.remote` constructor of the remote
                actor class for remote workers.
            env_creator (Callable[[EnvContext], EnvType]): Passed through to
                the worker.
            validate_env (Optional[Callable[[EnvType], None]]): Passed through
                to the worker.
            policy_cls (Type[Policy]): Policy class used as the policy spec,
                or used to fill in multiagent policies whose class is None.
            worker_index (int): Index of this worker (0 for the local
                worker). Also added to `config["seed"]` if a seed is set.
            num_workers (int): Passed through to the worker.
            config (TrainerConfigDict): The config to read all worker
                settings from. Note: `config["multiagent"]["policies"]` is
                updated in place when policy classes are filled in.
            spaces (Optional[Dict[PolicyID, Tuple[Space, Space]]]): Passed
                through to the worker.

        Returns:
            Union[RolloutWorker, ActorHandle]: Whatever `cls(...)` returns.
        """

        def session_creator():
            logger.debug("Creating TF session %s", config["tf_session_args"])
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        # Input: a user function, the default sampler, a dict (mixed
        # inputs), or anything else, which is handed to JsonReader.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx), config[
                    "shuffle_buffer_size"]))
        else:
            input_creator = (lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx), config[
                    "shuffle_buffer_size"]))

        # Output: a user function, nothing, the worker's log dir, or
        # anything else, which is handed to JsonWriter as the target.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # No input evaluation methods are passed when sampling.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Fill in the default policy_cls if 'None' is specified in multiagent.
        if config["multiagent"]["policies"]:
            tmp = config["multiagent"]["policies"]
            _validate_multiagent_config(tmp, allow_none_graph=True)
            # TODO: (sven) Allow for setting observation and action spaces to
            #  None as well, in which case, spaces are taken from env.
            #  It's tedious to have to provide these in a multi-agent config.
            # Only values of existing keys are replaced, so mutating the
            # dict while iterating over it is safe here.
            for k, v in tmp.items():
                if v[0] is None:
                    tmp[k] = (policy_cls, v[1], v[2], v[3])
            policy_spec = tmp
        # Otherwise, policy spec is simply the policy class itself.
        else:
            policy_spec = policy_cls

        # The driver (index 0) and the other workers read different keys.
        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policy_spec,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            model_config=config["model"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            monitor_path=self._logdir if config["monitor"] else None,
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Give each worker its own seed, derived from the base seed.
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker
|