mirror of https://github.com/vale981/ray (synced 2025-03-05 10:01:43 -05:00)
[RLlib] Initial code/comment cleanups in preparation for decentralized multi-agent learner. (#21420)
parent 4eaf70942d
commit 92f030331e
17 changed files with 128 additions and 55 deletions
@@ -366,7 +366,7 @@ class ImpalaTrainer(Trainer):
                 # Evaluation (remote) workers.
                 # Note: The local eval worker is located on the driver
-                # CPU.
+                # CPU or not even created iff >0 eval workers.
                 "CPU": eval_config.get("num_cpus_per_worker",
                                        cf["num_cpus_per_worker"]),
                 "GPU": eval_config.get("num_gpus_per_worker",
@@ -120,12 +120,6 @@ class APPOTrainer(impala.ImpalaTrainer):
         self.workers.local_worker().foreach_trainable_policy(
             lambda p, _: p.update_target())

-    # TODO: Remove this once ImpalaTrainer directly inherits from Trainer
-    # (instead of being created by `build_trainer()` utility).
-    @override(impala.ImpalaTrainer)
-    def _init(self, *args, **kwargs):
-        raise NotImplementedError
-
     @classmethod
     @override(Trainer)
     def get_default_config(cls) -> TrainerConfigDict:
@@ -29,6 +29,7 @@ from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
 from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
     MultiAgentReplayBuffer
+from ray.rllib.execution.common import WORKER_UPDATE_TIMER
 from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \
     synchronous_parallel_sample
 from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \
@@ -775,8 +776,8 @@ class Trainer(Trainable):

         # Set Trainer's seed after we have - if necessary - enabled
         # tf eager-execution.
-        update_global_seed_if_necessary(
-            config.get("framework"), config.get("seed"))
+        update_global_seed_if_necessary(self.config["framework"],
+                                        self.config["seed"])

         self.validate_config(self.config)
         if not callable(self.config["callbacks"]):
@@ -844,6 +845,14 @@ class Trainer(Trainable):
                 self.workers, self.config,
                 **self._kwargs_for_execution_plan())

+            # TODO: Now that workers have been created, update our policy
+            # specs in the config[multiagent] dict with the correct spaces.
+            # However, this leads to a problem with the evaluation
+            # workers' observation one-hot preprocessor in
+            # `examples/documentation/rllib_in_6sec.py` script.
+            # self.config["multiagent"]["policies"] = \
+            #     self.workers.local_worker().policy_map.policy_specs
+
         # Evaluation WorkerSet setup.
         # User would like to setup a separate evaluation worker set.
@@ -1295,6 +1304,12 @@ class Trainer(Trainable):
         else:
             train_results = multi_gpu_train_one_step(self, train_batch)

+        # Update weights - after learning on the local worker - on all remote
+        # workers.
+        if self.workers.remote_workers():
+            with self._timers[WORKER_UPDATE_TIMER]:
+                self.workers.sync_weights()
+
         return train_results

     @DeveloperAPI
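
The hunk above settles on a "learn on the local worker, then broadcast" pattern inside the Trainer's own training iteration. A minimal conceptual sketch of that pattern (not RLlib's actual implementation; only `local_worker()`, `remote_workers()`, and `sync_weights()` are real WorkerSet calls, the rest is assumed for illustration):

    def training_iteration_sketch(workers, sample_fn, learn_fn):
        # 1) Collect experiences from the rollout workers.
        batch = sample_fn(workers)
        # 2) Update the policy/policies on the local worker only.
        results = learn_fn(workers.local_worker(), batch)
        # 3) Broadcast the freshly updated weights to all remote workers.
        if workers.remote_workers():
            workers.sync_weights()
        return results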
@@ -1976,6 +1991,22 @@ class Trainer(Trainable):
             config2: PartialTrainerConfigDict,
             _allow_unknown_configs: Optional[bool] = None
     ) -> TrainerConfigDict:
+        """Merges a complete Trainer config with a partial override dict.
+
+        Respects nested structures within the config dicts. The values in the
+        partial override dict take priority.
+
+        Args:
+            config1: The complete Trainer's dict to be merged (overridden)
+                with `config2`.
+            config2: The partial override config dict to merge on top of
+                `config1`.
+            _allow_unknown_configs: If True, keys in `config2` that don't exist
+                in `config1` are allowed and will be added to the final config.
+
+        Returns:
+            The merged full trainer config dict.
+        """
         config1 = copy.deepcopy(config1)
         if "callbacks" in config2 and type(config2["callbacks"]) is dict:
             legacy_callbacks_dict = config2["callbacks"]
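
The merging semantics documented above can be shown with a small usage sketch (assuming the PPO default config of this RLlib era; the asserted keys are illustrative, not part of the diff):

    from ray.rllib.agents.ppo import DEFAULT_CONFIG as PPO_DEFAULT_CONFIG
    from ray.rllib.agents.trainer import Trainer

    partial_override = {"lr": 1e-4, "model": {"fcnet_hiddens": [64, 64]}}
    merged = Trainer.merge_trainer_configs(PPO_DEFAULT_CONFIG, partial_override)

    assert merged["lr"] == 1e-4                          # override wins
    assert merged["model"]["fcnet_hiddens"] == [64, 64]  # nested override wins
    assert "train_batch_size" in merged                  # untouched defaults remain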
@@ -6,7 +6,7 @@ import numpy as np
 import platform
 import os
 import tree  # pip install dm_tree
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
     TYPE_CHECKING, Union

 import ray
@@ -46,7 +46,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
 from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
     ModelConfigDict, ModelGradients, ModelWeights, \
     MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
-    SampleBatchType
+    SampleBatchType, T
 from ray.util.debug import log_once, disable_log_once_globally, \
     enable_periodic_logging
 from ray.util.iter import ParallelIteratorWorker
@@ -56,9 +56,6 @@ if TYPE_CHECKING:
     from ray.rllib.evaluation.observation_function import ObservationFunction
     from ray.rllib.agents.callbacks import DefaultCallbacks  # noqa

-# Generic type var for foreach_* methods.
-T = TypeVar("T")
-
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -1436,19 +1433,26 @@ class RolloutWorker(ParallelIteratorWorker):
         sess.close()

     @DeveloperAPI
-    def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T],
-              *args) -> T:
+    def apply(
+            self,
+            func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T],
+            *args, **kwargs) -> T:
         """Calls the given function with this rollout worker instance.

         Useful for when the RolloutWorker class has been converted into a
         ActorHandle and the user needs to execute some functionality (e.g.
         add a property) on the underlying policy object.

         Args:
-            func: The function to call with this RolloutWorker as first
-                argument.
+            func: The function to call, with this RolloutWorker as first
+                argument, followed by args, and kwargs.
             args: Optional additional args to pass to the function call.
+            kwargs: Optional additional kwargs to pass to the function call.

         Returns:
             The return value of the function call.
         """
-        return func(self, *args)
+        return func(self, *args, **kwargs)

     def setup_torch_data_parallel(self, url: str, world_rank: int,
                                   world_size: int, backend: str) -> None:
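
A hypothetical usage sketch for the extended `apply()` signature, called on remote RolloutWorker actor handles (the `describe` helper and the `workers` WorkerSet variable are assumptions for illustration; `remote_workers()` and `policy_map` are real RLlib attributes):

    import ray

    def describe(worker, prefix, sep=": "):
        # `worker` is the RolloutWorker instance; extra positional and keyword
        # arguments are forwarded by `apply()`.
        return prefix + sep + str(len(worker.policy_map))

    results = ray.get([
        # "num_policies" travels via *args, sep=... via **kwargs (new in this commit).
        w.apply.remote(describe, "num_policies", sep=" -> ")
        for w in workers.remote_workers()
    ])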
@@ -159,13 +159,15 @@ class WorkerSet:
         if self.remote_workers() or from_worker is not None:
             weights = (from_worker
                        or self.local_worker()).get_weights(policies)
+            # Put weights only once into object store and use same object
+            # ref to synch to all workers.
             weights_ref = ray.put(weights)
             # Sync to all remote workers in this WorkerSet.
             for to_worker in self.remote_workers():
                 to_worker.set_weights.remote(weights_ref)

-            # If from_worker is provided, also sync to this WorkerSet's local
-            # worker.
+            # If `from_worker` is provided, also sync to this WorkerSet's
+            # local worker.
             if from_worker is not None and self.local_worker() is not None:
                 self.local_worker().set_weights(weights)
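
The pattern in this hunk - put the weights into the object store once, then hand the same reference to every remote worker - can be shown in isolation. The `Worker` actor below is a stand-in, not RLlib code:

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def set_weights(self, weights):
            # Ray resolves the ObjectRef argument to the stored weights dict.
            self.weights = weights
            return len(weights)

    workers = [Worker.remote() for _ in range(4)]
    weights = {"default_policy": [0.0] * 10_000}
    weights_ref = ray.put(weights)  # serialized into the object store exactly once
    ray.get([w.set_weights.remote(weights_ref) for w in workers])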
@@ -1,11 +1,11 @@
 import collections
 import platform
-from typing import Dict, Any
+from typing import Any, Dict

 import numpy as np
 import ray
 from ray.rllib import SampleBatch
-from ray.rllib.execution import PrioritizedReplayBuffer
+from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
 from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
 from ray.rllib.policy.rnn_sequencing import \
     timeslice_along_seq_lens_with_overlap
@@ -54,7 +54,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
                 `self.replay_batch_size` will be set to the number of
                 sequences sampled (B).
             prioritized_replay_alpha (float): Alpha parameter for a prioritized
-                replay buffer.
+                replay buffer. Use 0.0 for no prioritization.
             prioritized_replay_beta (float): Beta parameter for a prioritized
                 replay buffer.
             prioritized_replay_eps (float): Epsilon parameter for a prioritized
@@ -108,8 +108,11 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
         ParallelIteratorWorker.__init__(self, gen_replay, False)

         def new_buffer():
-            return PrioritizedReplayBuffer(
-                self.capacity, alpha=prioritized_replay_alpha)
+            if prioritized_replay_alpha == 0.0:
+                return ReplayBuffer(self.capacity)
+            else:
+                return PrioritizedReplayBuffer(
+                    self.capacity, alpha=prioritized_replay_alpha)

         self.replay_buffers = collections.defaultdict(new_buffer)
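
The factory/`defaultdict` pattern above creates one buffer lazily per policy ID and falls back to uniform replay when alpha is 0.0. A standalone sketch with simplified stand-in buffers (the dict-based "buffers" are illustrative only):

    import collections

    def make_buffer_factory(capacity, prioritized_replay_alpha):
        def new_buffer():
            if prioritized_replay_alpha == 0.0:
                # Plain FIFO buffer, sampled uniformly.
                return {"kind": "uniform", "capacity": capacity, "items": []}
            # Prioritized buffer; alpha controls how strongly priorities skew sampling.
            return {"kind": "prioritized", "alpha": prioritized_replay_alpha,
                    "capacity": capacity, "items": []}
        return new_buffer

    replay_buffers = collections.defaultdict(make_buffer_factory(50000, 0.0))
    replay_buffers["policy_1"]["items"].append("batch")  # buffer created on first access
    print(replay_buffers["policy_1"]["kind"])  # -> "uniform"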
@@ -52,7 +52,7 @@ class ReplayBuffer:
     def __init__(self,
                  capacity: int = 10000,
                  size: Optional[int] = DEPRECATED_VALUE):
-        """Initializes a Replaybuffer instance.
+        """Initializes a ReplayBuffer instance.

         Args:
             capacity: Max number of timesteps to store in the FIFO
@@ -84,6 +84,7 @@ class ReplayBuffer:
         self._est_size_bytes = 0

     def __len__(self) -> int:
+        """Returns the number of items currently stored in this buffer."""
         return len(self._storage)

     @DeveloperAPI
@@ -147,7 +148,7 @@ class ReplayBuffer:
         """Returns the stats of this buffer.

         Args:
-            debug: If true, adds sample eviction statistics to the returned
+            debug: If True, adds sample eviction statistics to the returned
                 stats dict.

         Returns:
@@ -253,7 +254,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     @DeveloperAPI
     @override(ReplayBuffer)
     def sample(self, num_items: int, beta: float) -> SampleBatchType:
-        """Sample a batch of experiences and return priority weights, indices.
+        """Sample `num_items` items from this buffer, including prio. weights.
+
+        If less than `num_items` records are in this buffer, some samples in
+        the results may be repeated to fulfil the batch size (`num_items`)
+        request.

         Args:
             num_items: Number of items to sample from this buffer.
@@ -272,11 +277,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         weights = []
         batch_indexes = []
         p_min = self._it_min.min() / self._it_sum.sum()
-        max_weight = (p_min * len(self._storage))**(-beta)
+        max_weight = (p_min * len(self))**(-beta)

         for idx in idxes:
             p_sample = self._it_sum[idx] / self._it_sum.sum()
-            weight = (p_sample * len(self._storage))**(-beta)
+            weight = (p_sample * len(self))**(-beta)
             count = self._storage[idx].count
             # If zero-padded, count will not be the actual batch size of the
             # data.
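
The two rewritten lines compute the standard prioritized-replay importance-sampling correction w_i = (N * P(i))^(-beta), normalized by the largest weight. A standalone numeric sketch with made-up priorities, independent of the buffer internals:

    priorities = [4.0, 1.0, 1.0, 2.0]         # per-item (alpha-exponentiated) priorities
    N = len(priorities)
    beta = 0.4
    total = sum(priorities)
    probs = [p / total for p in priorities]   # P(i) = p_i / sum_j p_j
    max_weight = (min(probs) * N) ** (-beta)  # weight of the least likely item
    weights = [((p_i * N) ** (-beta)) / max_weight for p_i in probs]
    # Frequently sampled (high-priority) items get weights < 1.0, damping their
    # gradient contribution; the rarest item gets weight exactly 1.0.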
@@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict:
     trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
     trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()

-    # Update weights - after learning on the local worker - on all remote
-    # workers.
-    if workers.remote_workers():
-        with trainer._timers[WORKER_UPDATE_TIMER]:
-            weights = ray.put(workers.local_worker().get_weights(policies))
-            for e in workers.remote_workers():
-                e.set_weights.remote(weights)
     return info
@@ -4,8 +4,9 @@ import gym
 from gym.spaces import Box
 import logging
 import numpy as np
+import platform
 import tree  # pip install dm_tree
-from typing import Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING

 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.catalog import ModelCatalog
@@ -21,7 +22,7 @@ from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
     get_dummy_batch_for_space, unbatch
 from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
-    TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
+    T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -638,6 +639,27 @@ class Policy(metaclass=ABCMeta):
         self.set_weights(state["weights"])
         self.global_timestep = state["global_timestep"]

+    @ExperimentalAPI
+    def apply(self,
+              func: Callable[["Policy", Optional[Any], Optional[Any]], T],
+              *args, **kwargs) -> T:
+        """Calls the given function with this Policy instance.
+
+        Useful for when the Policy class has been converted into a ActorHandle
+        and the user needs to execute some functionality (e.g. add a property)
+        on the underlying policy object.
+
+        Args:
+            func: The function to call, with this Policy as first
+                argument, followed by args, and kwargs.
+            args: Optional additional args to pass to the function call.
+            kwargs: Optional additional kwargs to pass to the function call.
+
+        Returns:
+            The return value of the function call.
+        """
+        return func(self, *args, **kwargs)
+
     @DeveloperAPI
     def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
         """Called on an update to global vars.
@@ -697,6 +719,15 @@ class Policy(metaclass=ABCMeta):
         """
         return None

+    def get_host(self) -> str:
+        """Returns the computer's network name.
+
+        Returns:
+            The computer's networks name or an empty string, if the network
+            name could not be determined.
+        """
+        return platform.node()
+
     def _create_exploration(self) -> Exploration:
         """Creates the Policy's Exploration object.
@@ -294,7 +294,7 @@ def chop_into_sequences(
             f = np.array(f)

         length = len(seq_lens) * max_seq_len
-        if f.dtype == np.object or f.dtype.type is np.str_:
+        if f.dtype == object or f.dtype.type is np.str_:
             f_pad = [None] * length
         else:
             # Make sure type doesn't change.
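
Background for this and the similar dtype hunks below: `np.object` has always been a plain alias for the builtin `object` and is deprecated in newer NumPy releases, so the checks now compare against `object` directly. A minimal illustration (the array contents are made up):

    import numpy as np

    # Ragged or dict-valued payloads (e.g. per-step info dicts) end up as object dtype.
    arr = np.array([{"reward": 1.0}, {"reward": 2.0}])
    assert arr.dtype == object  # version-proof spelling of the same check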
@@ -620,7 +620,7 @@ class SampleBatch(dict):
                 or path[0] == SampleBatch.SEQ_LENS:
             return
         # Generate zero-filled primer of len=max_seq_len.
-        if value.dtype == np.object or value.dtype.type is np.str_:
+        if value.dtype == object or value.dtype.type is np.str_:
             f_pad = [None] * length
         else:
             # Make sure type doesn't change.
@@ -651,13 +651,13 @@ class SampleBatch(dict):

         return self

-    # Experimental method.
+    @ExperimentalAPI
     def to_device(self, device, framework="torch"):
         """TODO: transfer batch to given device as framework tensor."""
         if framework == "torch":
             assert torch is not None
             for k, v in self.items():
-                if isinstance(v, np.ndarray) and v.dtype != np.object:
+                if isinstance(v, np.ndarray) and v.dtype != object:
                     self[k] = torch.from_numpy(v).to(device)
         else:
             raise NotImplementedError
@@ -1170,7 +1170,7 @@ class EntropyCoeffSchedule:

 @DeveloperAPI
 class DirectStepOptimizer:
-    """Typesafe method for indicating apply gradients can directly step the
+    """Typesafe method for indicating `apply_gradients` can directly step the
     optimizers with in-place gradients.
     """
     _instance = None
@@ -42,7 +42,7 @@ def _summarize(obj):
         if obj.size == 0:
             return _StringValue("np.ndarray({}, dtype={})".format(
                 obj.shape, obj.dtype))
-        elif obj.dtype == np.object or obj.dtype.type is np.str_:
+        elif obj.dtype == object or obj.dtype.type is np.str_:
             return _StringValue("np.ndarray({}, dtype={}, head={})".format(
                 obj.shape, obj.dtype, _summarize(obj[0])))
         else:
@@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
             "ERROR: x ({}) is not the same as y ({})!".format(x, y)
     # String/byte comparisons.
     elif hasattr(x, "dtype") and \
-            (x.dtype == np.object or str(x.dtype).startswith("<U")):
+            (x.dtype == object or str(x.dtype).startswith("<U")):
         try:
             np.testing.assert_array_equal(x, y)
             if false is True:
@@ -307,11 +307,15 @@ def check_compute_single_action(trainer,
         ValueError: If anything unexpected happens.
     """
     # Have to import this here to avoid circular dependency.
-    from ray.rllib.policy.sample_batch import SampleBatch
+    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

     # Some Trainers may not abide to the standard API.
+    pid = DEFAULT_POLICY_ID
     try:
-        pol = trainer.get_policy()
+        # Multi-agent: Pick any policy (or DEFAULT_POLICY if it's the only
+        # one).
+        pid = next(iter(trainer.workers.local_worker().policy_map))
+        pol = trainer.get_policy(pid)
     except AttributeError:
         pol = trainer.policy
     # Get the policy's model.
@@ -324,6 +328,7 @@ def check_compute_single_action(trainer,
         call_kwargs = {}
         if what is trainer:
             call_kwargs["full_fetch"] = full_fetch
+            call_kwargs["policy_id"] = pid

         obs = obs_space.sample()
         if isinstance(obs_space, Box):
@@ -429,10 +434,10 @@ def check_compute_single_action(trainer,
             worker_set = getattr(trainer, "_workers", None)
         assert worker_set
         if isinstance(worker_set, list):
-            obs_space = trainer.get_policy().observation_space
+            obs_space = trainer.get_policy(pid).observation_space
         else:
             obs_space = worker_set.local_worker().for_policy(
-                lambda p: p.observation_space)
+                lambda p: p.observation_space, policy_id=pid)
             obs_space = getattr(obs_space, "original_space", obs_space)
     else:
         obs_space = pol.observation_space
@@ -255,7 +255,8 @@ def get_tf_eager_cls_if_necessary(
             cls = orig_cls.as_eager()
             if config.get("eager_tracing"):
                 cls = cls.with_tracing()
-        # Could be some other type of policy.
+        # Could be some other type of policy or already
+        # eager-ized.
         elif not issubclass(orig_cls, TFPolicy):
             pass
         else:
@@ -136,8 +136,8 @@ def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None):
                 item.max_len)
     # Numpy arrays.
     if isinstance(item, np.ndarray):
-        # np.object_ type (e.g. info dicts in train batch): leave as-is.
-        if item.dtype == np.object_:
+        # Object type (e.g. info dicts in train batch): leave as-is.
+        if item.dtype == object:
             return item
         # Non-writable numpy-arrays will cause PyTorch warning.
         elif item.flags.writeable is False:
@@ -1,5 +1,6 @@
 import gym
-from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Tuple, Union, TypeVar, \
+    TYPE_CHECKING

 if TYPE_CHECKING:
     from ray.rllib.utils import try_import_tf, try_import_torch
@@ -123,3 +124,6 @@ TensorShape = Union[Tuple[int], List[int]]
 # A (possibly nested) space struct: Either a gym.spaces.Space or a
 # (possibly nested) dict|tuple of gym.space.Spaces.
 SpaceStruct = Union[gym.spaces.Space, dict, tuple]
+
+# Generic type var.
+T = TypeVar("T")
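
Centralizing the generic `T` in `ray.rllib.utils.typing` lets pass-through helpers such as `apply()` and the `foreach_*` methods preserve their callback's return type. A generic illustration (the names below are not RLlib's):

    from typing import Callable, TypeVar

    T = TypeVar("T")

    def apply_to(obj: object, func: Callable[..., T], *args, **kwargs) -> T:
        # The helper's return type mirrors whatever `func` returns.
        return func(obj, *args, **kwargs)

    count: int = apply_to("abc", lambda o, extra: len(o) + extra, 2)  # T == int
    name: str = apply_to("abc", lambda o: o.upper())                  # T == str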