import gym
import logging
import importlib.util
import os
from types import FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.actor import ActorHandle
from ray.exceptions import RayError
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.offline import (
    NoopOutput,
    JsonReader,
    MixedInput,
    JsonWriter,
    ShuffledInput,
    D4RLReader,
    DatasetReader,
    DatasetWriter,
    get_dataset_and_shards,
)
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.typing import (
    EnvCreator,
    EnvType,
    PolicyID,
    SampleBatchType,
    TensorType,
    AlgorithmConfigDict,
)
from ray.tune.registry import registry_contains_input, registry_get_input

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Generic type var for foreach_* methods.
T = TypeVar("T")


@DeveloperAPI
class WorkerSet:
    """Set of RolloutWorkers with n @ray.remote workers and zero or one local worker.

    Where: n >= 0.
    """

    def __init__(
        self,
        *,
        env_creator: Optional[EnvCreator] = None,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        policy_class: Optional[Type[Policy]] = None,
        trainer_config: Optional[AlgorithmConfigDict] = None,
        num_workers: int = 0,
        local_worker: bool = True,
        logdir: Optional[str] = None,
        _setup: bool = True,
    ):
        """Initializes a WorkerSet instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            policy_class: An optional Policy class. If None, PolicySpecs can be
                generated automatically by using the Algorithm's default class
                or via a given multi-agent policy config dict.
            trainer_config: Optional dict that extends the common config of
                the Algorithm class.
            num_workers: Number of remote rollout workers to create.
            local_worker: Whether to create a local (non @ray.remote) worker
                in the returned set as well (default: True). If `num_workers`
                is 0, always create a local worker.
            logdir: Optional logging directory for workers.
            _setup: Whether to setup workers. This is only for testing.
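
        Examples:
            A minimal construction sketch (not taken from this module);
            `MyPolicy` stands in for any concrete Policy subclass:

            >>> import gym  # doctest: +SKIP
            >>> workers = WorkerSet(  # doctest: +SKIP
            ...     env_creator=lambda ctx: gym.make("CartPole-v0"),
            ...     policy_class=MyPolicy,
            ...     num_workers=2)
            >>> len(workers.remote_workers())  # doctest: +SKIP
            2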
        """

        if not trainer_config:
            from ray.rllib.algorithms.algorithm import COMMON_CONFIG

            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        self._cls = RolloutWorker.as_remote(**self._remote_args).remote
        self._logdir = logdir
        if _setup:
            # Force a local worker if num_workers == 0 (no remote workers).
            # Otherwise, this WorkerSet would be empty.
            self._local_worker = None
            if num_workers == 0:
                local_worker = True
            self._local_config = merge_dicts(
                trainer_config,
                {"tf_session_args": trainer_config["local_tf_session_args"]},
            )

            if trainer_config["input"] == "dataset":
                # Create the set of dataset readers to be shared by all the
                # rollout workers.
                self._ds, self._ds_shards = get_dataset_and_shards(
                    trainer_config, num_workers
                )
            else:
                self._ds = None
                self._ds_shards = None

            # Create a number of @ray.remote workers.
            self._remote_workers = []
            self.add_workers(
                num_workers,
                validate=trainer_config.get("validate_workers_after_construction"),
            )

            # Create a local worker, if needed.
            # If num_workers > 0 and we don't have an env on the local worker,
            # get the observation- and action spaces for each policy from
            # the first remote worker (which does have an env).
            if (
                local_worker
                and self._remote_workers
                and not trainer_config.get("create_env_on_driver")
                and (
                    not trainer_config.get("observation_space")
                    or not trainer_config.get("action_space")
                )
            ):
                remote_spaces = ray.get(
                    self.remote_workers()[0].foreach_policy.remote(
                        lambda p, pid: (pid, p.observation_space, p.action_space)
                    )
                )
                spaces = {
                    e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                    for e in remote_spaces
                }
                # Try to add the actual env's obs/action spaces.
                try:
                    env_spaces = ray.get(
                        self.remote_workers()[0].foreach_env.remote(
                            lambda env: (env.observation_space, env.action_space)
                        )
                    )[0]
                    spaces["__env__"] = env_spaces
                except Exception:
                    pass

                logger.info(
                    "Inferred observation/action spaces from remote "
                    f"worker (local worker has no env): {spaces}"
                )
            else:
                spaces = None

            if local_worker:
                self._local_worker = self._make_worker(
                    cls=RolloutWorker,
                    env_creator=env_creator,
                    validate_env=validate_env,
                    policy_cls=self._policy_class,
                    worker_index=0,
                    num_workers=num_workers,
                    config=self._local_config,
                    spaces=spaces,
                )

    def local_worker(self) -> RolloutWorker:
        """Returns the local rollout worker."""
        return self._local_worker

    def remote_workers(self) -> List[ActorHandle]:
        """Returns a list of remote rollout workers."""
        return self._remote_workers

    def sync_weights(
        self,
        policies: Optional[List[PolicyID]] = None,
        from_worker: Optional[RolloutWorker] = None,
        global_vars: Optional[Dict[str, TensorType]] = None,
    ) -> None:
        """Syncs model weights from the local worker to all remote workers.

        Args:
            policies: Optional list of PolicyIDs to sync weights for.
                If None (default), sync weights to/from all policies.
            from_worker: Optional RolloutWorker instance to sync from.
                If None (default), sync from this WorkerSet's local worker.
            global_vars: An optional global vars dict to set this
                worker to. If None, do not update the global_vars.
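
        Examples:
            A usage sketch (assumes `workers` is an already constructed
            WorkerSet whose local worker's policies were just updated):

            >>> workers.sync_weights(policies=["default_policy"])  # doctest: +SKIP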
        """
        if self.local_worker() is None and from_worker is None:
            raise TypeError(
                "No `local_worker` in WorkerSet, must provide `from_worker` "
                "arg in `sync_weights()`!"
            )

        # Only sync if we have remote workers or `from_worker` is provided.
        weights = None
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker or self.local_worker()).get_weights(policies)
            # Put weights only once into object store and use same object
            # ref to synch to all workers.
            weights_ref = ray.put(weights)
            # Sync to all remote workers in this WorkerSet.
            for to_worker in self.remote_workers():
                to_worker.set_weights.remote(weights_ref, global_vars=global_vars)

        # If `from_worker` is provided, also sync to this WorkerSet's
        # local worker.
        if from_worker is not None and self.local_worker() is not None:
            self.local_worker().set_weights(weights, global_vars=global_vars)
        # If `global_vars` is provided and local worker exists -> Update its
        # global_vars.
        elif self.local_worker() is not None and global_vars is not None:
            self.local_worker().set_global_vars(global_vars)

    def add_workers(self, num_workers: int, validate: bool = False) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same WorkerSet to add more
        RolloutWorkers to the set.

        Args:
            num_workers: The number of remote Workers to add to this
                WorkerSet.
            validate: Whether to validate remote workers after their construction
                process.

        Raises:
            RayError: If any of the constructed remote workers is not up and running
                properly.
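
        Examples:
            A usage sketch (assumes `workers` is an existing WorkerSet):

            >>> before = len(workers.remote_workers())  # doctest: +SKIP
            >>> workers.add_workers(2, validate=True)  # doctest: +SKIP
            >>> len(workers.remote_workers()) == before + 2  # doctest: +SKIP
            True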
        """
        old_num_workers = len(self._remote_workers)
        self._remote_workers.extend(
            [
                self._make_worker(
                    cls=self._cls,
                    env_creator=self._env_creator,
                    validate_env=None,
                    policy_cls=self._policy_class,
                    worker_index=old_num_workers + i + 1,
                    num_workers=old_num_workers + num_workers,
                    config=self._remote_config,
                )
                for i in range(num_workers)
            ]
        )

        # Validate here, whether all remote workers have been constructed properly
        # and are "up and running". If not, the following will throw a RayError
        # which needs to be handled by this WorkerSet's owner (usually
        # a RLlib Algorithm instance).
        if validate:
            self.foreach_worker(lambda w: w.assert_healthy())

    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote workers in this set with the given ones.

        Args:
            new_remote_workers: A list of new RolloutWorkers
                (as `ActorHandles`) to use as remote workers.
        """
        self._remote_workers = new_remote_workers

    def remove_failed_workers(self):
        faulty_indices = self._worker_health_check()
        removed_workers = []
        # Terminate faulty workers.
        for worker_index in faulty_indices:
            worker = self.remote_workers()[worker_index - 1]
            logger.info(f"Trying to terminate faulty worker {worker_index}.")
            try:
                worker.__ray_terminate__.remote()
                removed_workers.append(worker)
            except Exception:
                logger.exception("Error terminating faulty worker.")

        # Remove all faulty workers from self._remote_workers.
        for worker_index in reversed(faulty_indices):
            del self._remote_workers[worker_index - 1]
            # TODO: Should we also change each healthy worker's num_workers counter and
            # worker_index property?

        if len(self.remote_workers()) == 0:
            raise RuntimeError(
                f"No healthy workers remaining (worker indices {faulty_indices} have "
                f"died)! Can't continue training."
            )
        return removed_workers

    def recreate_failed_workers(
        self, local_worker_for_synching: RolloutWorker
    ) -> Tuple[List[ActorHandle], List[ActorHandle]]:
        """Recreates any failed workers (after health check).

        Args:
            local_worker_for_synching: RolloutWorker to use to synchronize the weights
                after recreation.

        Returns:
            A tuple consisting of two items: The list of removed workers and the list of
            newly added ones.
        """
        faulty_indices = self._worker_health_check()
        removed_workers = []
        new_workers = []
        for worker_index in faulty_indices:
            worker = self.remote_workers()[worker_index - 1]
            removed_workers.append(worker)
            logger.info(f"Trying to recreate faulty worker {worker_index}")
            try:
                worker.__ray_terminate__.remote()
            except Exception:
                logger.exception("Error terminating faulty worker.")
            # Try to recreate the failed worker (start a new one).
            new_worker = self._make_worker(
                cls=self._cls,
                env_creator=self._env_creator,
                validate_env=None,
                policy_cls=self._policy_class,
                worker_index=worker_index,
                num_workers=len(self._remote_workers),
                recreated_worker=True,
                config=self._remote_config,
            )

            # Sync new worker from provided one (or local one).
            new_worker.set_weights.remote(
                weights=local_worker_for_synching.get_weights(),
                global_vars=local_worker_for_synching.get_global_vars(),
            )

            # Add new worker to list of remote workers.
            self._remote_workers[worker_index - 1] = new_worker
            new_workers.append(new_worker)

        return removed_workers, new_workers

    def stop(self) -> None:
        """Calls `stop` on all rollout workers (including the local one)."""
        try:
            if self.local_worker():
                self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def is_policy_to_train(
        self, policy_id: PolicyID, batch: Optional[SampleBatchType] = None
    ) -> bool:
        """Whether given PolicyID (optionally inside some batch) is trainable."""
        local_worker = self.local_worker()
        if local_worker:
            return local_worker.is_policy_to_train(policy_id, batch)
        else:
            raise NotImplementedError

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Calls the given function with each worker instance as arg.

        Args:
            func: The function to call for each worker (as only arg).

        Returns:
            The list of return values of all calls to `func([worker])`.
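
        Examples:
            A usage sketch: collect one rollout (SampleBatch) from every
            worker, local and remote:

            >>> batches = workers.foreach_worker(lambda w: w.sample())  # doctest: +SKIP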
        """
        local_result = []
        if self.local_worker() is not None:
            local_result = [func(self.local_worker())]
        remote_results = ray.get([w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_worker_with_index(
        self, func: Callable[[RolloutWorker, int], T]
    ) -> List[T]:
        """Calls `func` with each worker instance and worker idx as args.

        The index will be passed as the second arg to the given function.

        Args:
            func: The function to call for each worker and its index
                (as args). The local worker has index 0, all remote workers
                have indices > 0.

        Returns:
            The list of return values of all calls to `func([worker, idx])`.
            The first entry in this list is the local worker's result,
            followed by all remote workers' results.
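
        Examples:
            A usage sketch: pair each worker's index with the policy IDs in
            its `policy_map`:

            >>> workers.foreach_worker_with_index(  # doctest: +SKIP
            ...     lambda worker, idx: (idx, list(worker.policy_map.keys())))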
        """
        local_result = []
        # Local worker: Index=0.
        if self.local_worker() is not None:
            local_result = [func(self.local_worker(), 0)]
        # Remote workers: Index > 0.
        remote_results = ray.get(
            [w.apply.remote(func, i + 1) for i, w in enumerate(self.remote_workers())]
        )
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Calls `func` with each worker's (policy, PolicyID) tuple.

        Note that in the multi-agent case, each worker may have more than one
        policy.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            The list of return values of func over all workers' policies. The
            length of this list is:
            (num_workers + 1 (local-worker)) *
            [num policies in the multi-agent config dict].
            The local worker's results are first, followed by all remote
            workers' results.
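
        Examples:
            A usage sketch mirroring the space lookup this WorkerSet does in
            `__init__`:

            >>> workers.foreach_policy(  # doctest: +SKIP
            ...     lambda policy, pid: (pid, policy.observation_space))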
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.apply.remote(lambda w: w.foreach_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_policy_to_train(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies, for which
                `worker.is_policy_to_train()` returns True.

        Returns:
            List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls.
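
        Examples:
            A usage sketch: list the IDs of all currently trainable policies:

            >>> workers.foreach_policy_to_train(lambda policy, pid: pid)  # doctest: +SKIP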
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy_to_train(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(lambda w: w.foreach_policy_to_train(func))
            )
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment as arg, e.g. a
        gym.Env object.

        Args:
            func: A function - taking an EnvType (normally a gym.Env object)
                as only arg. Its return values are collected, one per
                underlying sub-environment of each worker.

        Returns:
            The list (workers) of lists (sub environments) of results.
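
        Examples:
            A usage sketch mirroring the env-space lookup done in `__init__`:

            >>> workers.foreach_env(  # doctest: +SKIP
            ...     lambda env: (env.observation_space, env.action_space))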
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env.remote(func))
        return local_results + ray.get(ray_gets)

    @DeveloperAPI
    def foreach_env_with_context(
        self, func: Callable[[BaseEnv, EnvContext], List[T]]
    ) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments and env_ctx as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment and the env_context
        as args.

        Args:
            func: A function - taking a BaseEnv object and an EnvContext as
                args. Its return values are collected, one per sub-environment
                of each worker.

        Returns:
            The list (1 item per worker) of lists (1 item per sub-environment)
            of results.
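
        Examples:
            A usage sketch (assumes the standard EnvContext fields, e.g.
            `worker_index`, are present on the passed-in context):

            >>> workers.foreach_env_with_context(  # doctest: +SKIP
            ...     lambda env, ctx: (ctx.worker_index, type(env).__name__))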
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env_with_context(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env_with_context.remote(func))
        return local_results + ray.get(ray_gets)

    @staticmethod
    def _from_existing(
        local_worker: RolloutWorker, remote_workers: List[ActorHandle] = None
    ):
        workers = WorkerSet(
            env_creator=None, policy_class=None, trainer_config={}, _setup=False
        )
        workers._local_worker = local_worker
        workers._remote_workers = remote_workers or []
        return workers

    def _make_worker(
        self,
        *,
        cls: Callable,
        env_creator: EnvCreator,
        validate_env: Optional[Callable[[EnvType], None]],
        policy_cls: Type[Policy],
        worker_index: int,
        num_workers: int,
        recreated_worker: bool = False,
        config: AlgorithmConfigDict,
        spaces: Optional[
            Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]]
        ] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        def session_creator():
            logger.debug("Creating TF session {}".format(config["tf_session_args"]))
            return tf1.Session(config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            if (
                isinstance(class_path, str)
                and not os.path.isfile(class_path)
                and "." in class_path
            ):
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}"
                    )
            return False

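        # Overview of the `config["input"]` values handled below (illustrative
        # sketch only; paths and names are placeholders):
        #   "sampler"                 -> collect experiences from the env
        #   "dataset"                 -> read the Ray Dataset shards prepared above
        #   {"sampler": .4, "x.json": .6} -> MixedInput with the given ratios
        #   "my_registered_input"     -> input registered in tune's registry
        #   "d4rl.hopper-medium-v0"   -> D4RLReader for the named env
        #   "my.module.MyReader"      -> importable class path via from_config
        #   "/tmp/*.json"             -> JsonReader over the given file(s)
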
        # A callable returning an InputReader object to use.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        # Use RLlib's Sampler classes (SyncSampler or AsyncSampler, depending
        # on `config.sample_async` setting).
        elif config["input"] == "sampler":
            input_creator = lambda ioctx: ioctx.default_sampler_input()
        # Ray Dataset input -> Use `config.input_config` to construct DatasetReader.
        elif config["input"] == "dataset":
            # Input dataset shards should have already been prepared.
            # We just need to take the proper shard here.
            input_creator = lambda ioctx: DatasetReader(
                self._ds_shards[worker_index], ioctx
            )
        # Dict: Mix of different input methods with different ratios.
        elif isinstance(config["input"], dict):
            input_creator = lambda ioctx: ShuffledInput(
                MixedInput(config["input"], ioctx), config["shuffle_buffer_size"]
            )
        # A pre-registered input descriptor (str).
        elif isinstance(config["input"], str) and registry_contains_input(
            config["input"]
        ):
            input_creator = registry_get_input(config["input"])
        # D4RL input.
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = lambda ioctx: D4RLReader(env_name, ioctx)
        # Valid python module (class path) -> Create using `from_config`.
        elif valid_module(config["input"]):
            input_creator = lambda ioctx: ShuffledInput(
                from_config(config["input"], ioctx=ioctx)
            )
        # JSON file or list of JSON files -> Use JsonReader (shuffled).
        else:
            input_creator = lambda ioctx: ShuffledInput(
                JsonReader(config["input"], ioctx), config["shuffle_buffer_size"]
            )

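        # `config["output"]` mirrors the same idea for writers (summary of the
        # branches below): a callable, None (NoopOutput), "dataset"
        # (DatasetWriter), "logdir" (JSON files in this worker's log dir), or
        # any other string, treated as an output path/URI for JsonWriter.
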
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = lambda ioctx: NoopOutput()
        elif config["output"] == "dataset":
            output_creator = lambda ioctx: DatasetWriter(
                ioctx, compress_columns=config["output_compress_columns"]
            )
        elif config["output"] == "logdir":
            output_creator = lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )
        else:
            output_creator = lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"],
            )

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, PolicySpec)
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid].policy_class = policy_cls
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        if worker_index == 0:
            extra_python_environs = config.get("extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get("extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            recreated_worker=recreated_worker,
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None
            else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
            disable_env_checking=config["disable_env_checking"],
        )

        return worker

    def _worker_health_check(self) -> List[int]:
        """Performs a health-check on each remote worker.

        Returns:
            List of indices (into `self._remote_workers` list) of faulty workers.
            Note that index=1 is the 0th item in `self._remote_workers`.
        """
        logger.info("Health checking all workers ...")
        checks = []
        for worker in self.remote_workers():
            # TODO: Maybe find a better way to probe for healthiness. Performing an
            # entire `sample()` step may be costly. Then again, we only do this
            # upon any worker failure during the `step_attempt()`, not regularly.
            _, obj_ref = worker.sample_with_count.remote()
            checks.append(obj_ref)

        faulty_worker_indices = []
        for i, obj_ref in enumerate(checks):
            try:
                ray.get(obj_ref)
                logger.info("Worker {} looks healthy.".format(i + 1))
            except RayError:
                logger.exception("Worker {} is faulty.".format(i + 1))
                faulty_worker_indices.append(i + 1)

        return faulty_worker_indices

    @classmethod
    def _valid_module(cls, class_path):
        del cls
        if (
            isinstance(class_path, str)
            and not os.path.isfile(class_path)
            and "." in class_path
        ):
            module_path, class_name = class_path.rsplit(".", 1)
            try:
                spec = importlib.util.find_spec(module_path)
                if spec is not None:
                    return True
            except (ModuleNotFoundError, ValueError):
                print(
                    f"module {module_path} not found while trying to get "
                    f"input {class_path}"
                )
        return False

    @Deprecated(new="WorkerSet.foreach_policy_to_train", error=False)
    def foreach_trainable_policy(self, func):
        return self.foreach_policy_to_train(func)

    @Deprecated(new="WorkerSet.is_policy_to_train([pid], [batch]?)", error=False)
    def trainable_policies(self):
        local_worker = self.local_worker()
        if local_worker is not None:
            return [
                pid
                for pid in local_worker.policy_map.keys()
                if local_worker.is_policy_to_train(pid, None)
            ]
        else:
            raise NotImplementedError