From 92f030331e29549e9d1ed5edd43161313c83622b Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 10 Jan 2022 11:22:55 +0100 Subject: [PATCH] [RLlib] Initial code/comment cleanups in preparation for decentralized multi-agent learner. (#21420) --- rllib/agents/impala/impala.py | 2 +- rllib/agents/ppo/appo.py | 6 ---- rllib/agents/trainer.py | 35 +++++++++++++++++-- rllib/evaluation/rollout_worker.py | 24 +++++++------ rllib/evaluation/worker_set.py | 6 ++-- .../buffers/multi_agent_replay_buffer.py | 13 ++++--- rllib/execution/buffers/replay_buffer.py | 15 +++++--- rllib/execution/train_ops.py | 7 ---- rllib/policy/policy.py | 35 +++++++++++++++++-- rllib/policy/rnn_sequencing.py | 2 +- rllib/policy/sample_batch.py | 6 ++-- rllib/policy/torch_policy.py | 2 +- rllib/utils/debug.py | 2 +- rllib/utils/test_utils.py | 15 +++++--- rllib/utils/tf_utils.py | 3 +- rllib/utils/torch_utils.py | 4 +-- rllib/utils/typing.py | 6 +++- 17 files changed, 128 insertions(+), 55 deletions(-) diff --git a/rllib/agents/impala/impala.py b/rllib/agents/impala/impala.py index a0d31e609..12f3b7959 100644 --- a/rllib/agents/impala/impala.py +++ b/rllib/agents/impala/impala.py @@ -366,7 +366,7 @@ class ImpalaTrainer(Trainer): { # Evaluation (remote) workers. # Note: The local eval worker is located on the driver - # CPU. + # CPU or not even created iff >0 eval workers. "CPU": eval_config.get("num_cpus_per_worker", cf["num_cpus_per_worker"]), "GPU": eval_config.get("num_gpus_per_worker", diff --git a/rllib/agents/ppo/appo.py b/rllib/agents/ppo/appo.py index 1c14db3c2..e030de383 100644 --- a/rllib/agents/ppo/appo.py +++ b/rllib/agents/ppo/appo.py @@ -120,12 +120,6 @@ class APPOTrainer(impala.ImpalaTrainer): self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) - # TODO: Remove this once ImpalaTrainer directly inherits from Trainer - # (instead of being created by `build_trainer()` utility). - @override(impala.ImpalaTrainer) - def _init(self, *args, **kwargs): - raise NotImplementedError - @classmethod @override(Trainer) def get_default_config(cls) -> TrainerConfigDict: diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index e55ab358e..5b5ed2738 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -29,6 +29,7 @@ from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.execution.metric_ops import StandardMetricsReporting from ray.rllib.execution.buffers.multi_agent_replay_buffer import \ MultiAgentReplayBuffer +from ray.rllib.execution.common import WORKER_UPDATE_TIMER from ray.rllib.execution.rollout_ops import ConcatBatches, ParallelRollouts, \ synchronous_parallel_sample from ray.rllib.execution.train_ops import TrainOneStep, MultiGPUTrainOneStep, \ @@ -775,8 +776,8 @@ class Trainer(Trainable): # Set Trainer's seed after we have - if necessary - enabled # tf eager-execution. - update_global_seed_if_necessary( - config.get("framework"), config.get("seed")) + update_global_seed_if_necessary(self.config["framework"], + self.config["seed"]) self.validate_config(self.config) if not callable(self.config["callbacks"]): @@ -844,6 +845,14 @@ class Trainer(Trainable): self.workers, self.config, **self._kwargs_for_execution_plan()) + # TODO: Now that workers have been created, update our policy + # specs in the config[multiagent] dict with the correct spaces. + # However, this leads to a problem with the evaluation + # workers' observation one-hot preprocessor in + # `examples/documentation/rllib_in_6sec.py` script. 
+ # self.config["multiagent"]["policies"] = \ + # self.workers.local_worker().policy_map.policy_specs + # Evaluation WorkerSet setup. # User would like to setup a separate evaluation worker set. @@ -1295,6 +1304,12 @@ class Trainer(Trainable): else: train_results = multi_gpu_train_one_step(self, train_batch) + # Update weights - after learning on the local worker - on all remote + # workers. + if self.workers.remote_workers(): + with self._timers[WORKER_UPDATE_TIMER]: + self.workers.sync_weights() + return train_results @DeveloperAPI @@ -1976,6 +1991,22 @@ class Trainer(Trainable): config2: PartialTrainerConfigDict, _allow_unknown_configs: Optional[bool] = None ) -> TrainerConfigDict: + """Merges a complete Trainer config with a partial override dict. + + Respects nested structures within the config dicts. The values in the + partial override dict take priority. + + Args: + config1: The complete Trainer's dict to be merged (overridden) + with `config2`. + config2: The partial override config dict to merge on top of + `config1`. + _allow_unknown_configs: If True, keys in `config2` that don't exist + in `config1` are allowed and will be added to the final config. + + Returns: + The merged full trainer config dict. + """ config1 = copy.deepcopy(config1) if "callbacks" in config2 and type(config2["callbacks"]) is dict: legacy_callbacks_dict = config2["callbacks"] diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index ef3f4d27e..0531c9c8e 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -6,7 +6,7 @@ import numpy as np import platform import os import tree # pip install dm_tree -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, \ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, \ TYPE_CHECKING, Union import ray @@ -46,7 +46,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.typing import AgentID, EnvConfigDict, EnvType, \ ModelConfigDict, ModelGradients, ModelWeights, \ MultiAgentPolicyConfigDict, PartialTrainerConfigDict, PolicyID, \ - SampleBatchType + SampleBatchType, T from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.iter import ParallelIteratorWorker @@ -56,9 +56,6 @@ if TYPE_CHECKING: from ray.rllib.evaluation.observation_function import ObservationFunction from ray.rllib.agents.callbacks import DefaultCallbacks # noqa -# Generic type var for foreach_* methods. -T = TypeVar("T") - tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -1436,19 +1433,26 @@ class RolloutWorker(ParallelIteratorWorker): sess.close() @DeveloperAPI - def apply(self, func: Callable[["RolloutWorker", Optional[Any]], T], - *args) -> T: + def apply( + self, + func: Callable[["RolloutWorker", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: """Calls the given function with this rollout worker instance. + Useful for when the RolloutWorker class has been converted into a + ActorHandle and the user needs to execute some functionality (e.g. + add a property) on the underlying policy object. + Args: - func: The function to call with this RolloutWorker as first - argument. + func: The function to call, with this RolloutWorker as first + argument, followed by args, and kwargs. args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. Returns: The return value of the function call. 
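
A rough usage sketch for the extended `apply()` signature above (illustrative only, not part of this patch; assumes `ray[rllib]` and gym are installed, and uses "CartPole-v0" purely as an example env):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})

def tag_worker(worker, name, value=None):
    # Runs inside the (possibly remote) RolloutWorker process.
    setattr(worker, name, value)
    return worker.worker_index

# Local worker: plain call. Remote workers: via the actor handle.
trainer.workers.local_worker().apply(tag_worker, "tag", value="local")
print(ray.get([
    w.apply.remote(tag_worker, "tag", value="remote")
    for w in trainer.workers.remote_workers()
]))
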
""" - return func(self, *args) + return func(self, *args, **kwargs) def setup_torch_data_parallel(self, url: str, world_rank: int, world_size: int, backend: str) -> None: diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 341355c67..ddcd21c41 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -159,13 +159,15 @@ class WorkerSet: if self.remote_workers() or from_worker is not None: weights = (from_worker or self.local_worker()).get_weights(policies) + # Put weights only once into object store and use same object + # ref to synch to all workers. weights_ref = ray.put(weights) # Sync to all remote workers in this WorkerSet. for to_worker in self.remote_workers(): to_worker.set_weights.remote(weights_ref) - # If from_worker is provided, also sync to this WorkerSet's local - # worker. + # If `from_worker` is provided, also sync to this WorkerSet's + # local worker. if from_worker is not None and self.local_worker() is not None: self.local_worker().set_weights(weights) diff --git a/rllib/execution/buffers/multi_agent_replay_buffer.py b/rllib/execution/buffers/multi_agent_replay_buffer.py index 70997409c..4e7d7d8f8 100644 --- a/rllib/execution/buffers/multi_agent_replay_buffer.py +++ b/rllib/execution/buffers/multi_agent_replay_buffer.py @@ -1,11 +1,11 @@ import collections import platform -from typing import Dict, Any +from typing import Any, Dict import numpy as np import ray from ray.rllib import SampleBatch -from ray.rllib.execution import PrioritizedReplayBuffer +from ray.rllib.execution import PrioritizedReplayBuffer, ReplayBuffer from ray.rllib.execution.buffers.replay_buffer import logger, _ALL_POLICIES from ray.rllib.policy.rnn_sequencing import \ timeslice_along_seq_lens_with_overlap @@ -54,7 +54,7 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker): `self.replay_batch_size` will be set to the number of sequences sampled (B). prioritized_replay_alpha (float): Alpha parameter for a prioritized - replay buffer. + replay buffer. Use 0.0 for no prioritization. prioritized_replay_beta (float): Beta parameter for a prioritized replay buffer. prioritized_replay_eps (float): Epsilon parameter for a prioritized @@ -108,8 +108,11 @@ class MultiAgentReplayBuffer(ParallelIteratorWorker): ParallelIteratorWorker.__init__(self, gen_replay, False) def new_buffer(): - return PrioritizedReplayBuffer( - self.capacity, alpha=prioritized_replay_alpha) + if prioritized_replay_alpha == 0.0: + return ReplayBuffer(self.capacity) + else: + return PrioritizedReplayBuffer( + self.capacity, alpha=prioritized_replay_alpha) self.replay_buffers = collections.defaultdict(new_buffer) diff --git a/rllib/execution/buffers/replay_buffer.py b/rllib/execution/buffers/replay_buffer.py index f4dc9ee58..fc77d8369 100644 --- a/rllib/execution/buffers/replay_buffer.py +++ b/rllib/execution/buffers/replay_buffer.py @@ -52,7 +52,7 @@ class ReplayBuffer: def __init__(self, capacity: int = 10000, size: Optional[int] = DEPRECATED_VALUE): - """Initializes a Replaybuffer instance. + """Initializes a ReplayBuffer instance. Args: capacity: Max number of timesteps to store in the FIFO @@ -84,6 +84,7 @@ class ReplayBuffer: self._est_size_bytes = 0 def __len__(self) -> int: + """Returns the number of items currently stored in this buffer.""" return len(self._storage) @DeveloperAPI @@ -147,7 +148,7 @@ class ReplayBuffer: """Returns the stats of this buffer. 
Args: - debug: If true, adds sample eviction statistics to the returned + debug: If True, adds sample eviction statistics to the returned stats dict. Returns: @@ -253,7 +254,11 @@ class PrioritizedReplayBuffer(ReplayBuffer): @DeveloperAPI @override(ReplayBuffer) def sample(self, num_items: int, beta: float) -> SampleBatchType: - """Sample a batch of experiences and return priority weights, indices. + """Sample `num_items` items from this buffer, including prio. weights. + + If less than `num_items` records are in this buffer, some samples in + the results may be repeated to fulfil the batch size (`num_items`) + request. Args: num_items: Number of items to sample from this buffer. @@ -272,11 +277,11 @@ class PrioritizedReplayBuffer(ReplayBuffer): weights = [] batch_indexes = [] p_min = self._it_min.min() / self._it_sum.sum() - max_weight = (p_min * len(self._storage))**(-beta) + max_weight = (p_min * len(self))**(-beta) for idx in idxes: p_sample = self._it_sum[idx] / self._it_sum.sum() - weight = (p_sample * len(self._storage))**(-beta) + weight = (p_sample * len(self))**(-beta) count = self._storage[idx].count # If zero-padded, count will not be the actual batch size of the # data. diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index 3f448d8d1..8fd57e438 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -55,13 +55,6 @@ def train_one_step(trainer, train_batch) -> Dict: trainer._counters[NUM_ENV_STEPS_TRAINED] += train_batch.count trainer._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps() - # Update weights - after learning on the local worker - on all remote - # workers. - if workers.remote_workers(): - with trainer._timers[WORKER_UPDATE_TIMER]: - weights = ray.put(workers.local_worker().get_weights(policies)) - for e in workers.remote_workers(): - e.set_weights.remote(weights) return info diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index c01f3f051..3419c62da 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -4,8 +4,9 @@ import gym from gym.spaces import Box import logging import numpy as np +import platform import tree # pip install dm_tree -from typing import Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.catalog import ModelCatalog @@ -21,7 +22,7 @@ from ray.rllib.utils.from_config import from_config from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \ get_dummy_batch_for_space, unbatch from ray.rllib.utils.typing import AgentID, ModelGradients, ModelWeights, \ - TensorType, TensorStructType, TrainerConfigDict, Tuple, Union + T, TensorType, TensorStructType, TrainerConfigDict, Tuple, Union tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -638,6 +639,27 @@ class Policy(metaclass=ABCMeta): self.set_weights(state["weights"]) self.global_timestep = state["global_timestep"] + @ExperimentalAPI + def apply(self, + func: Callable[["Policy", Optional[Any], Optional[Any]], T], + *args, **kwargs) -> T: + """Calls the given function with this Policy instance. + + Useful for when the Policy class has been converted into a ActorHandle + and the user needs to execute some functionality (e.g. add a property) + on the underlying policy object. + + Args: + func: The function to call, with this Policy as first + argument, followed by args, and kwargs. + args: Optional additional args to pass to the function call. 
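
For reference, a standalone numpy sketch of the importance-sampling weights computed in `PrioritizedReplayBuffer.sample()` above (illustration only; the real buffer uses segment trees rather than dense arrays):

import numpy as np

priorities = np.array([2.0, 1.0, 0.5, 0.25])   # example per-item priorities
probs = priorities / priorities.sum()           # P(i): sampling probabilities
N, beta = len(probs), 0.4

max_weight = (probs.min() * N) ** (-beta)       # weight of the rarest item
weights = (probs * N) ** (-beta) / max_weight   # normalized to (0, 1]
print(weights)  # the most frequently sampled items get the smallest weights
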
+ kwargs: Optional additional kwargs to pass to the function call. + + Returns: + The return value of the function call. + """ + return func(self, *args, **kwargs) + @DeveloperAPI def on_global_var_update(self, global_vars: Dict[str, TensorType]) -> None: """Called on an update to global vars. @@ -697,6 +719,15 @@ class Policy(metaclass=ABCMeta): """ return None + def get_host(self) -> str: + """Returns the computer's network name. + + Returns: + The computer's networks name or an empty string, if the network + name could not be determined. + """ + return platform.node() + def _create_exploration(self) -> Exploration: """Creates the Policy's Exploration object. diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py index 0e4c36570..41884e017 100644 --- a/rllib/policy/rnn_sequencing.py +++ b/rllib/policy/rnn_sequencing.py @@ -294,7 +294,7 @@ def chop_into_sequences( f = np.array(f) length = len(seq_lens) * max_seq_len - if f.dtype == np.object or f.dtype.type is np.str_: + if f.dtype == object or f.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index a17710ca3..87315d98f 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -620,7 +620,7 @@ class SampleBatch(dict): or path[0] == SampleBatch.SEQ_LENS: return # Generate zero-filled primer of len=max_seq_len. - if value.dtype == np.object or value.dtype.type is np.str_: + if value.dtype == object or value.dtype.type is np.str_: f_pad = [None] * length else: # Make sure type doesn't change. @@ -651,13 +651,13 @@ class SampleBatch(dict): return self - # Experimental method. + @ExperimentalAPI def to_device(self, device, framework="torch"): """TODO: transfer batch to given device as framework tensor.""" if framework == "torch": assert torch is not None for k, v in self.items(): - if isinstance(v, np.ndarray) and v.dtype != np.object: + if isinstance(v, np.ndarray) and v.dtype != object: self[k] = torch.from_numpy(v).to(device) else: raise NotImplementedError diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index a022147e4..7631ee799 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -1170,7 +1170,7 @@ class EntropyCoeffSchedule: @DeveloperAPI class DirectStepOptimizer: - """Typesafe method for indicating apply gradients can directly step the + """Typesafe method for indicating `apply_gradients` can directly step the optimizers with in-place gradients. """ _instance = None diff --git a/rllib/utils/debug.py b/rllib/utils/debug.py index 90d475cdf..02080e6ba 100644 --- a/rllib/utils/debug.py +++ b/rllib/utils/debug.py @@ -42,7 +42,7 @@ def _summarize(obj): if obj.size == 0: return _StringValue("np.ndarray({}, dtype={})".format( obj.shape, obj.dtype)) - elif obj.dtype == np.object or obj.dtype.type is np.str_: + elif obj.dtype == object or obj.dtype.type is np.str_: return _StringValue("np.ndarray({}, dtype={}, head={})".format( obj.shape, obj.dtype, _summarize(obj[0]))) else: diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index df1f608b7..dd1f6e12c 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -214,7 +214,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): "ERROR: x ({}) is not the same as y ({})!".format(x, y) # String/byte comparisons. elif hasattr(x, "dtype") and \ - (x.dtype == np.object or str(x.dtype).startswith("