[RLlib] Initial code/comment cleanups in preparation for decentralized multi-agent learner. (#21420)

Sven Mika 2022-01-10 11:22:55 +01:00 committed by GitHub
parent 4eaf70942d
commit 92f030331e
17 changed files with 128 additions and 55 deletions


@ -366,7 +366,7 @@ class ImpalaTrainer(Trainer):
{
# Evaluation (remote) workers.
# Note: The local eval worker is located on the driver
# CPU.
# CPU, or not even created at all if >0 (remote) eval workers exist.
"CPU": eval_config.get("num_cpus_per_worker",
cf["num_cpus_per_worker"]),
"GPU": eval_config.get("num_gpus_per_worker",


@ -120,12 +120,6 @@ class APPOTrainer(impala.ImpalaTrainer):
self.workers.local_worker().foreach_trainable_policy(
lambda p, _: p.update_target())
# TODO: Remove this once ImpalaTrainer directly inherits from Trainer
# (instead of being created by `build_trainer()` utility).
@override(impala.ImpalaTrainer)
def _init(self, *args, **kwargs):
raise NotImplementedError
@classmethod
@override(Trainer)
def get_default_config(cls) -> TrainerConfigDict:


@ -29,6 +29,7 @@ from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
MultiAgentReplayBuffer
from ray.rllib.execution.common import WORKER_UPDATE_TIMER
from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \
synchronous_parallel_sample
from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \
@ -775,8 +776,8 @@ class Trainer(Trainable):
# Set Trainer's seed after we have - if necessary - enabled
# tf eager-execution.
update_global_seed_if_necessary(
config.get("framework"), config.get("seed"))
update_global_seed_if_necessary(self.config["framework"],
self.config["seed"])
self.validate_config(self.config)
if not callable(self.config["callbacks"]):
@ -844,6 +845,14 @@ class Trainer(Trainable):
self.workers, self.config,
**self._kwargs_for_execution_plan())
# TODO: Now that workers have been created, update our policy
# specs in the config[multiagent] dict with the correct spaces.
# However, this leads to a problem with the evaluation
# workers' observation one-hot preprocessor in
# `examples/documentation/rllib_in_6sec.py` script.
# self.config["multiagent"]["policies"] = \
# self.workers.local_worker().policy_map.policy_specs
# Evaluation WorkerSet setup.
# User would like to setup a separate evaluation worker set.
@ -1295,6 +1304,12 @@ class Trainer(Trainable):
else:
train_results = multi_gpu_train_one_step(self, train_batch)
# Update weights - after learning on the local worker - on all remote
# workers.
if self.workers.remote_workers():
with self._timers[WORKER_UPDATE_TIMER]:
self.workers.sync_weights()
return train_results
@DeveloperAPI
@ -1976,6 +1991,22 @@ class Trainer(Trainable):
config2: PartialTrainerConfigDict,
_allow_unknown_configs: Optional[bool] = None
) -> TrainerConfigDict:
"""Merges a complete Trainer config with a partial override dict.
Respects nested structures within the config dicts. The values in the
partial override dict take priority.
Args:
config1: The complete Trainer's dict to be merged (overridden)
with `config2`.
config2: The partial override config dict to merge on top of
`config1`.
_allow_unknown_configs: If True, keys in `config2` that don't exist
in `config1` are allowed and will be added to the final config.
Returns:
The merged full trainer config dict.
"""
config1 = copy.deepcopy(config1)
if "callbacks" in config2 and type(config2["callbacks"]) is dict:
legacy_callbacks_dict = config2["callbacks"]
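
A hedged usage sketch of `merge_trainer_configs` as documented above (assuming the standard `COMMON_CONFIG` default dict; the override values are illustrative):

    from ray.rllib.agents.trainer import COMMON_CONFIG, Trainer

    # Deep-merge a partial override dict on top of a complete default config;
    # values from the partial dict win, nested dicts are merged recursively.
    merged = Trainer.merge_trainer_configs(
        COMMON_CONFIG,
        {"train_batch_size": 64, "model": {"fcnet_hiddens": [64, 64]}})
    assert merged["train_batch_size"] == 64
    assert merged["num_workers"] == COMMON_CONFIG["num_workers"]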


@ -6,7 +6,7 @@ import numpy as np
import platform
import os
import tree # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \
TYPE_CHECKING, Union
import ray
@ -46,7 +46,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \
ModelConfigDict, ModelGradients, ModelWeights, \
MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \
SampleBatchType
SampleBatchType, T
from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.iter import ParallelIteratorWorker
@ -56,9 +56,6 @@ if TYPE_CHECKING:
from ray.rllib.evaluation.observation_function import ObservationFunction
from ray.rllib.agents.callbacks import DefaultCallbacks # noqa
# Generic type var for foreach_* methods.
T = TypeVar("T")
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -1436,19 +1433,26 @@ class RolloutWorker(ParallelIteratorWorker):
sess.close()
@DeveloperAPI
def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T],
*args) -> T:
def apply(
self,
func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T],
*args, **kwargs) -> T:
"""Calls the given function with this rollout worker instance.
Useful for when the RolloutWorker class has been converted into an
ActorHandle and the user needs to execute some functionality (e.g.
add a property) on the underlying (remote) RolloutWorker object.
Args:
func: The function to call with this RolloutWorker as first
argument.
func: The function to call, with this RolloutWorker as first
argument, followed by args and kwargs.
args: Optional additional args to pass to the function call.
kwargs: Optional additional kwargs to pass to the function call.
Returns:
The return value of the function call.
"""
return func(self, *args)
return func(self, *args, **kwargs)
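
A hedged sketch of the resulting calling convention (`local_worker` and `remote_worker` stand for a RolloutWorker instance and its ActorHandle, respectively; the `tag_policies` helper is hypothetical):

    def tag_policies(worker, prefix, suffix=""):
        # Executed inside the (possibly remote) worker process.
        return [prefix + pid + suffix for pid in worker.policy_map.keys()]

    # Direct call on a local RolloutWorker:
    tagged = local_worker.apply(tag_policies, "policy:")
    # Call on a remote worker ActorHandle (returns an ObjectRef):
    tagged_ref = remote_worker.apply.remote(tag_policies, "policy:", suffix="-v2")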
def setup_torch_data_parallel(self, url: str, world_rank: int,
world_size: int, backend: str) -> None:


@ -159,13 +159,15 @@ class WorkerSet:
if self.remote_workers() or from_worker is not None:
weights = (from_worker
or self.local_worker()).get_weights(policies)
# Put the weights only once into the object store and use the same
# object ref to sync all remote workers.
weights_ref = ray.put(weights)
# Sync to all remote workers in this WorkerSet.
for to_worker in self.remote_workers():
to_worker.set_weights.remote(weights_ref)
# If from_worker is provided, also sync to this WorkerSet's local
# worker.
# If `from_worker` is provided, also sync to this WorkerSet's
# local worker.
if from_worker is not None and self.local_worker() is not None:
self.local_worker().set_weights(weights)
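
The put-once/broadcast pattern above can be sketched in isolation with a toy actor (a minimal sketch; `Worker` here is not RLlib's RolloutWorker):

    import ray

    ray.init()

    @ray.remote
    class Worker:
        def __init__(self):
            self.weights = None

        def set_weights(self, weights):
            self.weights = weights

    workers = [Worker.remote() for _ in range(4)]
    weights = {"default_policy": [0.1, 0.2, 0.3]}
    # Serialize the (possibly large) weights dict into the object store once ...
    weights_ref = ray.put(weights)
    # ... and hand the same object ref to every remote worker.
    ray.get([w.set_weights.remote(weights_ref) for w in workers])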


@ -1,11 +1,11 @@
import collections
import platform
from typing import Dict, Any
from typing import Any, Dict
import numpy as np
import ray
from ray.rllib import SampleBatch
from ray.rllib.execution import PrioritizedReplayBuffer
from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer
from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES
from ray.rllib.policy.rnn_sequencing import \
timeslice_along_seq_lens_with_overlap
@ -54,7 +54,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
`self.replay_batch_size` will be set to the number of
sequences sampled (B).
prioritized_replay_alpha (float): Alpha parameter for a prioritized
replay buffer.
replay buffer. Use 0.0 for no prioritization.
prioritized_replay_beta (float): Beta parameter for a prioritized
replay buffer.
prioritized_replay_eps (float): Epsilon parameter for a prioritized
@ -108,8 +108,11 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker):
ParallelIteratorWorker.__init__(self, gen_replay, False)
def new_buffer():
return PrioritizedReplayBuffer(
self.capacity, alpha=prioritized_replay_alpha)
if prioritized_replay_alpha == 0.0:
return ReplayBuffer(self.capacity)
else:
return PrioritizedReplayBuffer(
self.capacity, alpha=prioritized_replay_alpha)
self.replay_buffers = collections.defaultdict(new_buffer)
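
A hedged usage sketch of the alpha=0.0 behavior introduced above (constructor arguments other than `capacity` and `prioritized_replay_alpha` are left at their defaults, which may differ across versions):

    from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
        MultiAgentReplayBuffer

    # alpha > 0.0: each policy gets a PrioritizedReplayBuffer underneath.
    prio = MultiAgentReplayBuffer(capacity=50000, prioritized_replay_alpha=0.6)
    # alpha == 0.0: plain (uniform) ReplayBuffers are used instead.
    uniform = MultiAgentReplayBuffer(capacity=50000, prioritized_replay_alpha=0.0)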


@ -52,7 +52,7 @@ class ReplayBuffer:
def __init__(self,
capacity: int = 10000,
size: Optional[int] = DEPRECATED_VALUE):
"""Initializes a Replaybuffer instance.
"""Initializes a ReplayBuffer instance.
Args:
capacity: Max number of timesteps to store in the FIFO
@ -84,6 +84,7 @@ class ReplayBuffer:
self._est_size_bytes = 0
def __len__(self) -> int:
"""Returns the number of items currently stored in this buffer."""
return len(self._storage)
@DeveloperAPI
@ -147,7 +148,7 @@ class ReplayBuffer:
"""Returns the stats of this buffer.
Args:
debug: If true, adds sample eviction statistics to the returned
debug: If True, adds sample eviction statistics to the returned
stats dict.
Returns:
@ -253,7 +254,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
@DeveloperAPI
@override(ReplayBuffer)
def sample(self, num_items: int, beta: float) -> SampleBatchType:
"""Sample a batch of experiences and return priority weights, indices.
"""Sample `num_items` items from this buffer, including prio. weights.
If fewer than `num_items` records are in this buffer, some samples
in the result may be repeated to fulfill the batch size
(`num_items`) request.
Args:
num_items: Number of items to sample from this buffer.
@ -272,11 +277,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
weights = []
batch_indexes = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage))**(-beta)
max_weight = (p_min * len(self))**(-beta)
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage))**(-beta)
weight = (p_sample * len(self))**(-beta)
count = self._storage[idx].count
# If zero-padded, count will not be the actual batch size of the
# data.
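
A hedged sketch of the sampling behavior described above, assuming this version's `add(item, weight)` API and the `weights`/`batch_indexes` output columns built in this method:

    from ray.rllib.execution.buffers.replay_buffer import PrioritizedReplayBuffer
    from ray.rllib.policy.sample_batch import SampleBatch

    buf = PrioritizedReplayBuffer(capacity=100, alpha=0.6)
    for i in range(10):
        buf.add(SampleBatch({"obs": [i], "rewards": [float(i)]}), weight=1.0)

    # Ask for more items than are stored -> some samples will be repeated.
    batch = buf.sample(num_items=32, beta=0.4)
    print(batch["weights"], batch["batch_indexes"])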


@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict:
trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count
trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
# Update weights - after learning on the local worker - on all remote
# workers.
if workers.remote_workers():
with trainer._timers[WORKER_UPDATE_TIMER]:
weights = ray.put(workers.local_worker().get_weights(policies))
for e in workers.remote_workers():
e.set_weights.remote(weights)
return info


@ -4,8 +4,9 @@ import gym
from gym.spaces import Box
import logging
import numpy as np
import platform
import tree # pip install dm_tree
from typing import Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.catalog import ModelCatalog
@ -21,7 +22,7 @@ from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
get_dummy_batch_for_space, unbatch
from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \
TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@ -638,6 +639,27 @@ class Policy(metaclass=ABCMeta):
self.set_weights(state["weights"])
self.global_timestep = state["global_timestep"]
@ExperimentalAPI
def apply(self,
func: Callable[["Policy", Optional[Any], Optional[Any]], T],
*args, **kwargs) -> T:
"""Calls the given function with this Policy instance.
Useful for when the Policy class has been converted into an ActorHandle
and the user needs to execute some functionality (e.g. add a property)
on the underlying policy object.
Args:
func: The function to call, with this Policy as first
argument, followed by args and kwargs.
args: Optional additional args to pass to the function call.
kwargs: Optional additional kwargs to pass to the function call.
Returns:
The return value of the function call.
"""
return func(self, *args, **kwargs)
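
A brief hedged sketch at the Policy level (the `attach_attr` helper and the `trainer` handle are illustrative, not part of this diff):

    def attach_attr(policy, name, value=None):
        # Runs directly on the Policy instance (local or behind an ActorHandle).
        setattr(policy, name, value)
        return getattr(policy, name)

    flag = trainer.get_policy().apply(attach_attr, "custom_flag", value=True)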
@DeveloperAPI
def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None:
"""Called on an update to global vars.
@ -697,6 +719,15 @@ class Policy(metaclass=ABCMeta):
"""
return None
def get_host(self) -> str:
"""Returns the computer's network name.
Returns:
The computer's network name, or an empty string if the name
could not be determined.
"""
return platform.node()
def _create_exploration(self) -> Exploration:
"""Creates the Policy's Exploration object.


@ -294,7 +294,7 @@ def chop_into_sequences(
f = np.array(f)
length = len(seq_lens) * max_seq_len
if f.dtype == np.object or f.dtype.type is np.str_:
if f.dtype == object or f.dtype.type is np.str_:
f_pad = [None] * length
else:
# Make sure type doesn't change.
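
For context, a minimal sketch of the dtype check being updated here (NumPy 1.20+ deprecates the `np.object` alias in favor of the builtin `object`):

    import numpy as np

    # Ragged / non-numeric columns (e.g. info dicts) end up as object arrays.
    f = np.array([{"a": 1}, {"a": 2}, None])
    assert f.dtype == object  # works on all NumPy versions
    # f.dtype == np.object    # deprecated alias as of NumPy 1.20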


@ -620,7 +620,7 @@ class SampleBatch(dict):
or path[0] == SampleBatch.SEQ_LENS:
return
# Generate zero-filled primer of len=max_seq_len.
if value.dtype == np.object or value.dtype.type is np.str_:
if value.dtype == object or value.dtype.type is np.str_:
f_pad = [None] * length
else:
# Make sure type doesn't change.
@ -651,13 +651,13 @@ class SampleBatch(dict):
return self
# Experimental method.
@ExperimentalAPI
def to_device(self, device, framework="torch"):
"""TODO: transfer batch to given device as framework tensor."""
if framework == "torch":
assert torch is not None
for k, v in self.items():
if isinstance(v, np.ndarray) and v.dtype != np.object:
if isinstance(v, np.ndarray) and v.dtype != object:
self[k] = torch.from_numpy(v).to(device)
else:
raise NotImplementedError


@ -1170,7 +1170,7 @@ class EntropyCoeffSchedule:
@DeveloperAPI
class DirectStepOptimizer:
"""Typesafe method for indicating apply gradients can directly step the
"""Typesafe method for indicating `apply_gradients` can directly step the
optimizers with in-place gradients.
"""
_instance = None


@ -42,7 +42,7 @@ def _summarize(obj):
if obj.size == 0:
return _StringValue("np.ndarray({}, dtype={})".format(
obj.shape, obj.dtype))
elif obj.dtype == np.object or obj.dtype.type is np.str_:
elif obj.dtype == object or obj.dtype.type is np.str_:
return _StringValue("np.ndarray({}, dtype={}, head={})".format(
obj.shape, obj.dtype, _summarize(obj[0])))
else:


@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
"ERROR: x ({}) is not the same as y ({})!".format(x, y)
# String/byte comparisons.
elif hasattr(x, "dtype") and \
(x.dtype == np.object or str(x.dtype).startswith("<U")):
(x.dtype == object or str(x.dtype).startswith("<U")):
try:
np.testing.assert_array_equal(x, y)
if false is True:
@ -307,11 +307,15 @@ def check_compute_single_action(trainer,
ValueError: If anything unexpected happens.
"""
# Have to import this here to avoid circular dependency.
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
# Some Trainers may not abide by the standard API.
pid = DEFAULT_POLICY_ID
try:
pol = trainer.get_policy()
# Multi-agent: Pick any policy (or DEFAULT_POLICY if it's the only
# one).
pid = next(iter(trainer.workers.local_worker().policy_map))
pol = trainer.get_policy(pid)
except AttributeError:
pol = trainer.policy
# Get the policy's model.
@ -324,6 +328,7 @@ def check_compute_single_action(trainer,
call_kwargs = {}
if what is trainer:
call_kwargs["full_fetch"] = full_fetch
call_kwargs["policy_id"] = pid
obs = obs_space.sample()
if isinstance(obs_space, Box):
@ -429,10 +434,10 @@ def check_compute_single_action(trainer,
worker_set = getattr(trainer, "_workers", None)
assert worker_set
if isinstance(worker_set, list):
obs_space = trainer.get_policy().observation_space
obs_space = trainer.get_policy(pid).observation_space
else:
obs_space = worker_set.local_worker().for_policy(
lambda p: p.observation_space)
lambda p: p.observation_space, policy_id=pid)
obs_space = getattr(obs_space, "original_space", obs_space)
else:
obs_space = pol.observation_space
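
A hedged sketch of the `policy_id` pass-through being exercised above (`trainer` stands for any built Trainer; the lookup mirrors the code in this hunk):

    # Pick an arbitrary policy ID (multi-agent) or "default_policy" otherwise.
    pid = next(iter(trainer.workers.local_worker().policy_map))
    obs_space = trainer.get_policy(pid).observation_space
    obs_space = getattr(obs_space, "original_space", obs_space)
    action = trainer.compute_single_action(obs_space.sample(), policy_id=pid)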


@ -255,7 +255,8 @@ def get_tf_eager_cls_if_necessary(
cls = orig_cls.as_eager()
if config.get("eager_tracing"):
cls = cls.with_tracing()
# Could be some other type of policy.
# Could be some other type of policy or already
# eager-ized.
elif not issubclass(orig_cls, TFPolicy):
pass
else:


@ -136,8 +136,8 @@ def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None):
item.max_len)
# Numpy arrays.
if isinstance(item, np.ndarray):
# np.object_ type (e.g. info dicts in train batch): leave as-is.
if item.dtype == np.object_:
# Object type (e.g. info dicts in train batch): leave as-is.
if item.dtype == object:
return item
# Non-writable numpy-arrays will cause PyTorch warning.
elif item.flags.writeable is False:


@ -1,5 +1,6 @@
import gym
from typing import Any, Dict, List, Tuple, Union, TYPE_CHECKING
from typing import Any, Dict, List, Tuple, Union, TypeVar, \
TYPE_CHECKING
if TYPE_CHECKING:
from ray.rllib.utils import try_import_tf, try_import_torch
@ -123,3 +124,6 @@ TensorShape = Union[Tuple[int], List[int]]
# A (possibly nested) space struct: Either a gym.spaces.Space or a
# (possibly nested) dict|tuple of gym.space.Spaces.
SpaceStruct = Union[gym.spaces.Space, dict, tuple]
# Generic type var.
T = TypeVar("T")
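
A minimal sketch of how the shared type var added here can be consumed (it replaces the per-module `TypeVar("T")` definitions removed above; the `call_on` helper is hypothetical):

    from typing import Callable
    from ray.rllib.utils.typing import T

    def call_on(obj, func: Callable[..., T], *args, **kwargs) -> T:
        # Return type is inferred from whatever `func` returns.
        return func(obj, *args, **kwargs)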