From ea2bea7e309cd60457aa0e027321be5f10fa0fe5 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 1 Nov 2021 10:59:53 +0100 Subject: [PATCH] [RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808) --- rllib/evaluation/rollout_worker.py | 7 +- rllib/offline/d4rl_reader.py | 6 +- rllib/offline/input_reader.py | 13 ++-- rllib/offline/io_context.py | 46 +++++++---- rllib/offline/is_estimator.py | 2 +- rllib/offline/json_reader.py | 53 ++++++++----- rllib/offline/json_writer.py | 11 ++- rllib/offline/off_policy_estimator.py | 89 +++++++++++++++++---- rllib/offline/output_writer.py | 8 +- rllib/offline/shuffled_input.py | 6 +- rllib/offline/wis_estimator.py | 108 +++++++++++++------------- 11 files changed, 216 insertions(+), 133 deletions(-) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 2ee532096..8fd1ebcf6 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -664,11 +664,12 @@ class RolloutWorker(ParallelIteratorWorker): "will discard all sampler outputs and keep only metrics.") sample_async = True elif method == "is": - ise = ImportanceSamplingEstimator.create(self.io_context) + ise = ImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(ise) elif method == "wis": - wise = WeightedImportanceSamplingEstimator.create( - self.io_context) + wise = WeightedImportanceSamplingEstimator.\ + create_from_io_context(self.io_context) self.reward_estimators.append(wise) else: raise ValueError( diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py index d191d65c6..dae9cccb0 100644 --- a/rllib/offline/d4rl_reader.py +++ b/rllib/offline/d4rl_reader.py @@ -17,11 +17,11 @@ class D4RLReader(InputReader): @PublicAPI def __init__(self, inputs: str, ioctx: IOContext = None): - """Initialize a D4RLReader. + """Initializes a D4RLReader instance. Args: - inputs (str): String corresponding to D4RL environment name - ioctx (IOContext): Current IO context object. + inputs: String corresponding to the D4RL environment name. + ioctx: Current IO context object. """ import d4rl self.env = gym.make(inputs) diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index 3b05e4772..12ac65474 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -16,15 +16,16 @@ logger = logging.getLogger(__name__) @PublicAPI class InputReader(metaclass=ABCMeta): - """Input object for loading experiences in policy evaluation.""" + """API for collecting and returning experiences during policy evaluation. + """ @abstractmethod @PublicAPI - def next(self): - """Returns the next batch of experiences read. + def next(self) -> SampleBatchType: + """Returns the next batch of read experiences. Returns: - Union[SampleBatch, MultiAgentBatch]: The experience read. + The experience read (SampleBatch or MultiAgentBatch). """ raise NotImplementedError @@ -40,7 +41,7 @@ class InputReader(metaclass=ABCMeta): reader repeatedly to feed the TensorFlow queue. Args: - queue_size (int): Max elements to allow in the TF queue. + queue_size: Max elements to allow in the TF queue. Example: >>> class MyModel(rllib.model.Model): @@ -56,7 +57,7 @@ class InputReader(metaclass=ABCMeta): You can find a runnable version of this in examples/custom_loss.py. Returns: - dict of Tensors, one for each column of the read SampleBatch. + Dict of Tensors, one for each column of the read SampleBatch. 
""" if hasattr(self, "_queue_runner"): diff --git a/rllib/offline/io_context.py b/rllib/offline/io_context.py index f13103b7f..c74db614c 100644 --- a/rllib/offline/io_context.py +++ b/rllib/offline/io_context.py @@ -1,37 +1,53 @@ import os +from typing import Any, Optional, TYPE_CHECKING from ray.rllib.utils.annotations import PublicAPI -from typing import Any +from ray.rllib.utils.typing import TrainerConfigDict + +if TYPE_CHECKING: + from ray.rllib.evaluation.sampler import SamplerInput @PublicAPI class IOContext: - """Attributes to pass to input / output class constructors. + """Class containing attributes to pass to input/output class constructors. - RLlib auto-sets these attributes when constructing input / output classes. - - Attributes: - log_dir (str): Default logging directory. - config (dict): Configuration of the agent. - worker_index (int): When there are multiple workers created, this - uniquely identifies the current worker. - worker (RolloutWorker): RolloutWorker object reference. - input_config (dict): The input configuration for custom input. + RLlib auto-sets these attributes when constructing input/output classes, + such as InputReaders and OutputWriters. """ @PublicAPI def __init__(self, - log_dir: str = None, - config: dict = None, + log_dir: Optional[str] = None, + config: Optional[TrainerConfigDict] = None, worker_index: int = 0, - worker: Any = None): + worker: Optional[Any] = None): + """Initializes a IOContext object. + + Args: + log_dir: The logging directory to read from/write to. + config: The Trainer's main config dict. + worker_index (int): When there are multiple workers created, this + uniquely identifies the current worker. 0 for the local + worker, >0 for any of the remote workers. + worker (RolloutWorker): The RolloutWorker object reference. + """ self.log_dir = log_dir or os.getcwd() self.config = config or {} self.worker_index = worker_index self.worker = worker @PublicAPI - def default_sampler_input(self) -> Any: + def default_sampler_input(self) -> Optional["SamplerInput"]: + """Returns the RolloutWorker's SamplerInput object, if any. + + Returns None if the RolloutWorker has no SamplerInput. Note that local + workers in case there are also one or more remote workers by default + do not create a SamplerInput object. + + Returns: + The RolloutWorkers' SamplerInput object or None if none exists. + """ return self.worker.sampler @PublicAPI diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index 119eb2e1c..242c5f291 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -14,7 +14,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator): self.check_can_estimate_for(batch) rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) + new_prob = self.action_log_likelihood(batch) # calculate importance ratios p = [] diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 2177ea27e..01da73d87 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -5,7 +5,7 @@ import os from pathlib import Path import random import re -from typing import List, Optional +from typing import List, Optional, Union from urllib.parse import urlparse import zipfile @@ -32,17 +32,20 @@ WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)] class JsonReader(InputReader): """Reader object that loads experiences from JSON file chunks. 
- The input files will be read from in an random order.""" + The input files will be read from in random order. + """ @PublicAPI - def __init__(self, inputs: List[str], ioctx: IOContext = None): - """Initialize a JsonReader. + def __init__(self, + inputs: Union[str, List[str]], + ioctx: Optional[IOContext] = None): + """Initializes a JsonReader instance. Args: - inputs (str|list): Either a glob expression for files, e.g., - "/tmp/**/*.json", or a list of single file paths or URIs, e.g., + inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`, + or a list of single file paths or URIs, e.g., ["s3://bucket/file.json", "s3://bucket/file2.json"]. - ioctx (IOContext): Current IO context object. + ioctx: Current IO context object or None. """ self.ioctx = ioctx or IOContext() @@ -72,8 +75,8 @@ class JsonReader(InputReader): self.files = [] for i in inputs: self.files.extend(glob.glob(i)) - elif type(inputs) is list: - self.files = inputs + elif isinstance(inputs, (list, tuple)): + self.files = list(inputs) else: raise ValueError( "type of inputs must be list or str, not {}".format(inputs)) @@ -98,6 +101,26 @@ class JsonReader(InputReader): return self._postprocess_if_needed(batch) + def read_all_files(self) -> SampleBatchType: + """Reads through all files and yields one SampleBatchType per line. + + When reaching the end of the last file, will start from the beginning + again. + + Yields: + One SampleBatch or MultiAgentBatch per line in all input files. + """ + for path in self.files: + file = self._try_open_file(path) + while True: + line = file.readline() + if not line: + break + batch = self._try_parse(line) + if batch is None: + break + yield batch + def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: if not self.ioctx.config.get("postprocess_inputs"): @@ -182,18 +205,6 @@ class JsonReader(InputReader): self.ioctx.worker.policy_map[pid].action_space_struct) return batch - def read_all_files(self): - for path in self.files: - file = self._try_open_file(path) - while True: - line = file.readline() - if not line: - break - batch = self._try_parse(line) - if batch is None: - break - yield batch - def _next_line(self) -> str: if not self.cur_file: self.cur_file = self._next_file() diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index d3c849684..77777872d 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -34,15 +34,14 @@ class JsonWriter(OutputWriter): ioctx: IOContext = None, max_file_size: int = 64 * 1024 * 1024, compress_columns: List[str] = frozenset(["obs", "new_obs"])): - """Initialize a JsonWriter. + """Initializes a JsonWriter instance. Args: - path (str): a path/URI of the output directory to save files in. - ioctx (IOContext): current IO context object. - max_file_size (int): max size of single files before rolling over. - compress_columns (list): list of sample batch columns to compress. + path: a path/URI of the output directory to save files in. + ioctx: current IO context object. + max_file_size: max size of single files before rolling over. + compress_columns: list of sample batch columns to compress. 
""" - self.ioctx = ioctx or IOContext() self.max_file_size = max_file_size self.compress_columns = compress_columns diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 871ad67e0..e04f82386 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.offline.io_context import IOContext +from ray.rllib.utils.annotations import Deprecated from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import TensorType, SampleBatchType from typing import List @@ -23,19 +24,30 @@ class OffPolicyEstimator: @DeveloperAPI def __init__(self, policy: Policy, gamma: float): - """Creates an off-policy estimator. + """Initializes an OffPolicyEstimator instance. Args: - policy (Policy): Policy to evaluate. - gamma (float): Discount of the MDP. + policy: Policy to evaluate. + gamma: Discount factor of the environment. """ self.policy = policy self.gamma = gamma self.new_estimates = [] @classmethod - def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": - """Create an off-policy estimator from a IOContext.""" + def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator": + """Creates an off-policy estimator from an IOContext object. + + Extracts Policy and gamma (discount factor) information from the + IOContext. + + Args: + ioctx: The IOContext object to create the OffPolicyEstimator + from. + + Returns: + The OffPolicyEstimator object created from the IOContext object. + """ gamma = ioctx.worker.policy_config["gamma"] # Grab a reference to the current model keys = list(ioctx.worker.policy_map.keys()) @@ -47,18 +59,36 @@ class OffPolicyEstimator: return cls(policy, gamma) @DeveloperAPI - def estimate(self, batch: SampleBatchType): - """Returns an estimate for the given batch of experiences. + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + """Returns an off policy estimate for the given batch of experiences. - The batch will only contain data from one episode, but it may only be - a fragment of an episode. + The batch will at most only contain data from one episode, + but it may also only be a fragment of an episode. + + Args: + batch: The batch to calculate the off policy estimate (OPE) on. + + Returns: + The off-policy estimates (OPE) calculated on the given batch. """ raise NotImplementedError @DeveloperAPI - def action_prob(self, batch: SampleBatchType) -> np.ndarray: - """Returns the probs for the batch actions for the current policy.""" + def action_log_likelihood(self, batch: SampleBatchType) -> TensorType: + """Returns log likelihoods for actions in given batch for policy. + Computes likelihoods by passing the observations through the current + policy's `compute_log_likelihoods()` method. + + Args: + batch: The SampleBatch or MultiAgentBatch to calculate action + log likelihoods from. This batch/batches must contain OBS + and ACTIONS keys. + + Returns: + The log likelihoods of the actions in the batch, given the + observations and the policy. 
+ """ num_state_inputs = 0 for k in batch.keys(): if k.startswith("state_in_"): @@ -66,7 +96,7 @@ class OffPolicyEstimator: state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] log_likelihoods: TensorType = self.policy.compute_log_likelihoods( actions=batch[SampleBatch.ACTIONS], - obs_batch=batch[SampleBatch.CUR_OBS], + obs_batch=batch[SampleBatch.OBS], state_batches=[batch[k] for k in state_keys], prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), @@ -76,12 +106,29 @@ class OffPolicyEstimator: return np.exp(log_likelihoods) @DeveloperAPI - def process(self, batch: SampleBatchType): + def process(self, batch: SampleBatchType) -> None: + """Computes off policy estimates (OPE) on batch and stores results. + + Thus-far collected results can be retrieved then by calling + `self.get_metrics` (which flushes the internal results storage). + + Args: + batch: The batch to process (call `self.estimate()` on) and + store results (OPEs) for. + """ self.new_estimates.append(self.estimate(batch)) @DeveloperAPI - def check_can_estimate_for(self, batch: SampleBatchType): - """Returns whether we can support OPE for this batch.""" + def check_can_estimate_for(self, batch: SampleBatchType) -> None: + """Checks if we support off policy estimation (OPE) on given batch. + + Args: + batch: The batch to check. + + Raises: + ValueError: In case `action_prob` key is not in batch OR batch + is a MultiAgentBatch. + """ if isinstance(batch, MultiAgentBatch): raise ValueError( @@ -98,11 +145,19 @@ class OffPolicyEstimator: @DeveloperAPI def get_metrics(self) -> List[OffPolicyEstimate]: - """Return a list of new episode metric estimates since the last call. + """Returns list of new episode metric estimates since the last call. Returns: - list of OffPolicyEstimate objects. + List of OffPolicyEstimate objects. """ out = self.new_estimates self.new_estimates = [] return out + + @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False) + def create(self, *args, **kwargs): + return self.create_from_io_context(*args, **kwargs) + + @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False) + def action_prob(self, *args, **kwargs): + return self.action_log_likelihood(*args, **kwargs) diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index 8d168dfb4..2389c3d74 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,15 +1,14 @@ -from ray.rllib.utils.annotations import override -from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.typing import SampleBatchType @PublicAPI class OutputWriter: - """Writer object for saving experiences from policy evaluation.""" + """Writer API for saving experiences from policy evaluation.""" @PublicAPI def write(self, sample_batch: SampleBatchType): - """Save a batch of experiences. + """Saves a batch of experiences. Args: sample_batch: SampleBatch or MultiAgentBatch to save. @@ -22,4 +21,5 @@ class NoopOutput(OutputWriter): @override(OutputWriter) def write(self, sample_batch: SampleBatchType): + # Do nothing. pass diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index 24522c87a..a7c261018 100644 --- a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -18,11 +18,11 @@ class ShuffledInput(InputReader): @DeveloperAPI def __init__(self, child: InputReader, n: int = 0): - """Initialize a MixedInput. 
+ """Initializes a ShuffledInput instance. Args: - child (InputReader): child input reader to shuffle. - n (int): if positive, shuffle input over this many batches. + child: child input reader to shuffle. + n: If positive, shuffle input over this many batches. """ self.n = n self.child = child diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index 74eb342a4..00bbf3145 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -1,54 +1,54 @@ -from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ - OffPolicyEstimate -from ray.rllib.policy import Policy -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import SampleBatchType - - -class WeightedImportanceSamplingEstimator(OffPolicyEstimator): - """The weighted step-wise IS estimator. - - Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" - - def __init__(self, policy: Policy, gamma: float): - super().__init__(policy, gamma) - self.filter_values = [] - self.filter_counts = [] - - @override(OffPolicyEstimator) - def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: - self.check_can_estimate_for(batch) - - rewards, old_prob = batch["rewards"], batch["action_prob"] - new_prob = self.action_prob(batch) - - # calculate importance ratios - p = [] - for t in range(batch.count): - if t == 0: - pt_prev = 1.0 - else: - pt_prev = p[t - 1] - p.append(pt_prev * new_prob[t] / old_prob[t]) - for t, v in enumerate(p): - if t >= len(self.filter_values): - self.filter_values.append(v) - self.filter_counts.append(1.0) - else: - self.filter_values[t] += v - self.filter_counts[t] += 1.0 - - # calculate stepwise weighted IS estimate - V_prev, V_step_WIS = 0.0, 0.0 - for t in range(batch.count): - V_prev += rewards[t] * self.gamma**t - w_t = self.filter_values[t] / self.filter_counts[t] - V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t - - estimation = OffPolicyEstimate( - "wis", { - "V_prev": V_prev, - "V_step_WIS": V_step_WIS, - "V_gain_est": V_step_WIS / max(1e-8, V_prev), - }) - return estimation +from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ + OffPolicyEstimate +from ray.rllib.policy import Policy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import SampleBatchType + + +class WeightedImportanceSamplingEstimator(OffPolicyEstimator): + """The weighted step-wise IS estimator. 
+ + Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" + + def __init__(self, policy: Policy, gamma: float): + super().__init__(policy, gamma) + self.filter_values = [] + self.filter_counts = [] + + @override(OffPolicyEstimator) + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: + self.check_can_estimate_for(batch) + + rewards, old_prob = batch["rewards"], batch["action_prob"] + new_prob = self.action_log_likelihood(batch) + + # calculate importance ratios + p = [] + for t in range(batch.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = p[t - 1] + p.append(pt_prev * new_prob[t] / old_prob[t]) + for t, v in enumerate(p): + if t >= len(self.filter_values): + self.filter_values.append(v) + self.filter_counts.append(1.0) + else: + self.filter_values[t] += v + self.filter_counts[t] += 1.0 + + # calculate stepwise weighted IS estimate + V_prev, V_step_WIS = 0.0, 0.0 + for t in range(batch.count): + V_prev += rewards[t] * self.gamma**t + w_t = self.filter_values[t] / self.filter_counts[t] + V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t + + estimation = OffPolicyEstimate( + "wis", { + "V_prev": V_prev, + "V_step_WIS": V_step_WIS, + "V_gain_est": V_step_WIS / max(1e-8, V_prev), + }) + return estimation
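
A few illustrative sketches follow for reviewers who want to see the renamed APIs in context; they are not part of the patch itself.

First, a minimal sketch of how the two off-policy estimators touched above can be wired up and queried after this change. `create_from_io_context()`, `process()` and `get_metrics()` are taken directly from the diff, as is the `action_prob()` to `action_log_likelihood()` rename (the helper still returns exp(log-likelihoods), i.e. action probabilities, and the `estimate()` implementations call it internally). The `ioctx` and `batch` arguments are assumed to come from an existing RolloutWorker setup, so the functions below are a sketch rather than something runnable without RLlib and recorded data.

from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator


def build_estimators(ioctx):
    # Same pattern RolloutWorker now uses for the "is" and "wis" methods.
    # Old (deprecated, still forwarded by the @Deprecated shims):
    #   ImportanceSamplingEstimator.create(ioctx)
    ise = ImportanceSamplingEstimator.create_from_io_context(ioctx)
    wise = WeightedImportanceSamplingEstimator.create_from_io_context(ioctx)
    return [ise, wise]


def score_batch(estimators, batch):
    # Feed one offline SampleBatch to each estimator and collect the results.
    # Internally, estimate() calls the renamed action_log_likelihood() helper.
    results = []
    for est in estimators:
        est.process(batch)                  # computes and stores an OffPolicyEstimate
        results.extend(est.get_metrics())   # flushes the stored estimates
    return results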
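
Second, since wis_estimator.py is rewritten line for line here (the only visible semantic change being the switch to `action_log_likelihood`), a self-contained NumPy sketch of the same step-wise weighted importance sampling computation may help when checking the math without constructing a Policy. The toy rewards and probabilities are made up; the formulas mirror the loops in `estimate()` above, with the persistent filter_values/filter_counts state approximated by pretending one earlier episode has already been processed.

import numpy as np

# Toy stand-ins for batch["rewards"], batch["action_prob"] (behavior policy)
# and the probabilities the evaluated policy assigns to the logged actions.
rewards = np.array([1.0, 0.0, 1.0, 1.0])
old_prob = np.array([0.8, 0.5, 0.9, 0.6])   # behavior policy
new_prob = np.array([0.6, 0.7, 0.9, 0.8])   # evaluated (target) policy
gamma = 0.99

# Cumulative per-step importance ratios: p[t] = prod_{i<=t} new_prob[i] / old_prob[i].
p = np.cumprod(new_prob / old_prob)

# In the estimator, filter_values/filter_counts accumulate across episodes.
# Pretend one earlier episode with ratios prev_p was already processed,
# otherwise the weights would trivially cancel for a single episode.
prev_p = np.array([0.9, 0.8, 0.85, 0.7])
filter_values = prev_p + p
filter_counts = np.full_like(p, 2.0)
w = filter_values / filter_counts

discounts = gamma ** np.arange(len(rewards))
V_prev = float(np.sum(rewards * discounts))              # behavior-policy return
V_step_WIS = float(np.sum(p / w * rewards * discounts))  # step-wise weighted IS estimate
V_gain_est = V_step_WIS / max(1e-8, V_prev)

print({"V_prev": V_prev, "V_step_WIS": V_step_WIS, "V_gain_est": V_gain_est})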
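
Third, a hedged round-trip sketch for the JSON I/O classes whose docstrings change above. The constructor signatures (`JsonWriter(path, ...)`, `JsonReader(inputs, ioctx)`, `IOContext(log_dir, config, worker_index, worker)`) come from the diff; the output directory and the toy batch contents are made up, and reading back is shown as a separate, later step so that file flushing and rollover details can be glossed over.

import numpy as np

from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.policy.sample_batch import SampleBatch

# Write a small fake batch to a (hypothetical) local output directory.
writer = JsonWriter("/tmp/rllib-offline")
writer.write(
    SampleBatch({
        "obs": np.zeros((2, 4), dtype=np.float32),
        "new_obs": np.zeros((2, 4), dtype=np.float32),
        "actions": np.array([0, 1]),
        "rewards": np.array([1.0, 0.0]),
        "dones": np.array([False, True]),
        "action_prob": np.array([0.9, 0.8]),
    }))

# Later (e.g. in a separate evaluation run), read the stored experiences back.
ioctx = IOContext(log_dir="/tmp/rllib-offline", config={}, worker_index=0, worker=None)
reader = JsonReader("/tmp/rllib-offline/*.json", ioctx)
batch = reader.next()              # one SampleBatch (or MultiAgentBatch) per call

for b in reader.read_all_files():  # or iterate line by line over every input file
    print(b.count)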
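
Finally, ShuffledInput's implementation is not part of this diff; only its docstring is ("if positive, shuffle input over this many batches"). To make that one-liner concrete, here is a stand-alone toy approximation of the idea, i.e. buffer n batches from a child InputReader and serve them back in random order. It is illustrative only and deliberately not RLlib's actual class.

import random


class TinyShuffledInput:
    """Toy illustration: buffers `n` batches from `child` and returns them shuffled."""

    def __init__(self, child, n=0):
        self.child = child   # any object with a .next() -> batch method
        self.n = n
        self.buffer = []

    def next(self):
        if self.n <= 0:
            # n == 0: plain pass-through, no shuffling at all.
            return self.child.next()
        if not self.buffer:
            # Refill: pull n batches from the child, then hand them out
            # one by one in random order before refilling again.
            self.buffer = [self.child.next() for _ in range(self.n)]
            random.shuffle(self.buffer)
        return self.buffer.pop()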