[RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808)

Sven Mika 2021-11-01 10:59:53 +01:00 committed by GitHub
parent 7a2e9e00e8
commit ea2bea7e30
11 changed files with 216 additions and 133 deletions


@@ -664,11 +664,12 @@ class RolloutWorker(ParallelIteratorWorker):
"will discard all sampler outputs and keep only metrics.")
sample_async = True
elif method == "is":
ise = ImportanceSamplingEstimator.create(self.io_context)
ise = ImportanceSamplingEstimator.\
create_from_io_context(self.io_context)
self.reward_estimators.append(ise)
elif method == "wis":
wise = WeightedImportanceSamplingEstimator.create(
self.io_context)
wise = WeightedImportanceSamplingEstimator.\
create_from_io_context(self.io_context)
self.reward_estimators.append(wise)
else:
raise ValueError(

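For context on where these estimators come into play: the "is" and "wis" entries of the `input_evaluation` config setting trigger the two branches above. A minimal, hedged sketch (not part of this commit; trainer class and dataset path are placeholders):

from ray.rllib.agents.dqn import DQNTrainer

# Train/evaluate purely from recorded JSON experiences and let the
# RolloutWorkers build the IS/WIS reward estimators shown above.
trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        "input": "/tmp/cartpole-out",       # placeholder path to recorded experiences
        "input_evaluation": ["is", "wis"],  # -> IS / WIS off-policy estimators
        "explore": False,
    })
results = trainer.train()  # OPE metrics appear in the training results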

@@ -17,11 +17,11 @@ class D4RLReader(InputReader):
@PublicAPI
def __init__(self, inputs: str, ioctx: IOContext = None):
"""Initialize a D4RLReader.
"""Initializes a D4RLReader instance.
Args:
inputs (str): String corresponding to D4RL environment name
ioctx (IOContext): Current IO context object.
inputs: String corresponding to the D4RL environment name.
ioctx: Current IO context object.
"""
import d4rl
self.env = gym.make(inputs)

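A minimal usage sketch for the reader documented above (assumes the optional `d4rl` package is installed; the environment name is a placeholder):

from ray.rllib.offline.d4rl_reader import D4RLReader

reader = D4RLReader(inputs="hopper-medium-v0")  # any D4RL env name with an offline dataset
batch = reader.next()  # one SampleBatch sampled from the D4RL dataset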

@@ -16,15 +16,16 @@ logger = logging.getLogger(__name__)
@PublicAPI
class InputReader(metaclass=ABCMeta):
"""Input object for loading experiences in policy evaluation."""
"""API for collecting and returning experiences during policy evaluation.
"""
@abstractmethod
@PublicAPI
def next(self):
"""Returns the next batch of experiences read.
def next(self) -> SampleBatchType:
"""Returns the next batch of read experiences.
Returns:
Union[SampleBatch, MultiAgentBatch]: The experience read.
The experience read (SampleBatch or MultiAgentBatch).
"""
raise NotImplementedError
@@ -40,7 +41,7 @@ class InputReader(metaclass=ABCMeta):
reader repeatedly to feed the TensorFlow queue.
Args:
queue_size (int): Max elements to allow in the TF queue.
queue_size: Max elements to allow in the TF queue.
Example:
>>> class MyModel(rllib.model.Model):
@@ -56,7 +57,7 @@ class InputReader(metaclass=ABCMeta):
You can find a runnable version of this in examples/custom_loss.py.
Returns:
dict of Tensors, one for each column of the read SampleBatch.
Dict of Tensors, one for each column of the read SampleBatch.
"""
if hasattr(self, "_queue_runner"):

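A hedged sketch of the InputReader API described above: a toy subclass whose `next()` returns random CartPole-like transitions (all shapes and keys here are illustrative assumptions):

import numpy as np

from ray.rllib.offline import InputReader
from ray.rllib.policy.sample_batch import SampleBatch


class RandomReader(InputReader):
    """Toy InputReader that fabricates random transitions for testing."""

    def next(self):
        n = 32  # batch size (arbitrary)
        return SampleBatch({
            "obs": np.random.random((n, 4)).astype(np.float32),
            "actions": np.random.randint(2, size=n),
            "rewards": np.ones(n, dtype=np.float32),
            "new_obs": np.random.random((n, 4)).astype(np.float32),
            "dones": np.zeros(n, dtype=bool),
        })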

@@ -1,37 +1,53 @@
import os
from typing import Any, Optional, TYPE_CHECKING
from ray.rllib.utils.annotations import PublicAPI
from typing import Any
from ray.rllib.utils.typing import TrainerConfigDict
if TYPE_CHECKING:
from ray.rllib.evaluation.sampler import SamplerInput
@PublicAPI
class IOContext:
"""Attributes to pass to input / output class constructors.
"""Class containing attributes to pass to input/output class constructors.
RLlib auto-sets these attributes when constructing input / output classes.
Attributes:
log_dir (str): Default logging directory.
config (dict): Configuration of the agent.
worker_index (int): When there are multiple workers created, this
uniquely identifies the current worker.
worker (RolloutWorker): RolloutWorker object reference.
input_config (dict): The input configuration for custom input.
RLlib auto-sets these attributes when constructing input/output classes,
such as InputReaders and OutputWriters.
"""
@PublicAPI
def __init__(self,
log_dir: str = None,
config: dict = None,
log_dir: Optional[str] = None,
config: Optional[TrainerConfigDict] = None,
worker_index: int = 0,
worker: Any = None):
worker: Optional[Any] = None):
"""Initializes a IOContext object.
Args:
log_dir: The logging directory to read from/write to.
config: The Trainer's main config dict.
worker_index: When there are multiple workers created, this
uniquely identifies the current worker. 0 for the local
worker, >0 for any of the remote workers.
worker: The RolloutWorker object reference.
"""
self.log_dir = log_dir or os.getcwd()
self.config = config or {}
self.worker_index = worker_index
self.worker = worker
@PublicAPI
def default_sampler_input(self) -> Any:
def default_sampler_input(self) -> Optional["SamplerInput"]:
"""Returns the RolloutWorker's SamplerInput object, if any.
Returns None if the RolloutWorker has no SamplerInput. Note that by
default, local workers do not create a SamplerInput object if one or
more remote workers exist.
Returns:
The RolloutWorker's SamplerInput object or None if none exists.
"""
return self.worker.sampler
@PublicAPI

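To illustrate how the auto-constructed IOContext reaches user code: a hedged sketch of a custom `input` creator that receives the IOContext and builds a reader from it (the config keys and path below are placeholders, not part of this commit):

from ray.rllib.offline import JsonReader, ShuffledInput
from ray.rllib.offline.io_context import IOContext


def my_input_creator(ioctx: IOContext):
    # The IOContext exposes the Trainer config (and thus any custom
    # "input_config" dict); the path here is a placeholder.
    paths = ioctx.config.get("input_config", {}).get(
        "paths", "/tmp/demo-out/*.json")
    return ShuffledInput(JsonReader(paths, ioctx), n=100)


# Hypothetical Trainer config fragment using the creator above:
config = {
    "input": my_input_creator,
    "input_config": {"paths": "/tmp/demo-out/*.json"},
}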

@@ -14,7 +14,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_prob(batch)
new_prob = self.action_log_likelihood(batch)
# calculate importance ratios
p = []

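The core quantity built by the loop that follows the snippet above, restated as a hedged standalone sketch (function name is illustrative): the step-wise importance ratio p[t] is the running product of new_prob/old_prob up to step t.

import numpy as np


def stepwise_importance_ratios(new_prob, old_prob):
    # p[t] = prod_{i <= t} new_prob[i] / old_prob[i], exactly what the
    # estimator's incremental loop accumulates.
    return np.cumprod(np.asarray(new_prob) / np.asarray(old_prob))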

@@ -5,7 +5,7 @@ import os
from pathlib import Path
import random
import re
from typing import List, Optional
from typing import List, Optional, Union
from urllib.parse import urlparse
import zipfile
@@ -32,17 +32,20 @@ WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
class JsonReader(InputReader):
"""Reader object that loads experiences from JSON file chunks.
The input files will be read from in an random order."""
The input files will be read in random order.
"""
@PublicAPI
def __init__(self, inputs: List[str], ioctx: IOContext = None):
"""Initialize a JsonReader.
def __init__(self,
inputs: Union[str, List[str]],
ioctx: Optional[IOContext] = None):
"""Initializes a JsonReader instance.
Args:
inputs (str|list): Either a glob expression for files, e.g.,
"/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
or a list of single file paths or URIs, e.g.,
["s3://bucket/file.json", "s3://bucket/file2.json"].
ioctx (IOContext): Current IO context object.
ioctx: Current IO context object or None.
"""
self.ioctx = ioctx or IOContext()
@@ -72,8 +75,8 @@ class JsonReader(InputReader):
self.files = []
for i in inputs:
self.files.extend(glob.glob(i))
elif type(inputs) is list:
self.files = inputs
elif isinstance(inputs, (list, tuple)):
self.files = list(inputs)
else:
raise ValueError(
"type of inputs must be list or str, not {}".format(inputs))
@@ -98,6 +101,26 @@ class JsonReader(InputReader):
return self._postprocess_if_needed(batch)
def read_all_files(self) -> SampleBatchType:
"""Reads through all files and yields one SampleBatchType per line.
The generator stops once the end of the last file has been reached
(it does not wrap around to the first file).
Yields:
One SampleBatch or MultiAgentBatch per line in all input files.
"""
for path in self.files:
file = self._try_open_file(path)
while True:
line = file.readline()
if not line:
break
batch = self._try_parse(line)
if batch is None:
break
yield batch
def _postprocess_if_needed(self,
batch: SampleBatchType) -> SampleBatchType:
if not self.ioctx.config.get("postprocess_inputs"):
@@ -182,18 +205,6 @@ class JsonReader(InputReader):
self.ioctx.worker.policy_map[pid].action_space_struct)
return batch
def read_all_files(self):
for path in self.files:
file = self._try_open_file(path)
while True:
line = file.readline()
if not line:
break
batch = self._try_parse(line)
if batch is None:
break
yield batch
def _next_line(self) -> str:
if not self.cur_file:
self.cur_file = self._next_file()

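A short usage sketch for the reader (the glob path is a placeholder; the files are assumed to have been produced by RLlib's JSON output):

from ray.rllib.offline import JsonReader

reader = JsonReader("/tmp/cartpole-out/*.json")

# One pass over every line (= one batch per line) in all matched files:
for batch in reader.read_all_files():
    print(batch.count)

# Or sample batches indefinitely:
batch = reader.next()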

@@ -34,15 +34,14 @@ class JsonWriter(OutputWriter):
ioctx: IOContext = None,
max_file_size: int = 64 * 1024 * 1024,
compress_columns: List[str] = frozenset(["obs", "new_obs"])):
"""Initialize a JsonWriter.
"""Initializes a JsonWriter instance.
Args:
path (str): a path/URI of the output directory to save files in.
ioctx (IOContext): current IO context object.
max_file_size (int): max size of single files before rolling over.
compress_columns (list): list of sample batch columns to compress.
path: A path/URI of the output directory to save files in.
ioctx: Current IO context object.
max_file_size: Max size of single files before rolling over.
compress_columns: List of sample batch columns to compress.
"""
self.ioctx = ioctx or IOContext()
self.max_file_size = max_file_size
self.compress_columns = compress_columns

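A hedged usage sketch for the writer (output directory and batch contents are placeholders):

import numpy as np

from ray.rllib.offline import JsonWriter
from ray.rllib.policy.sample_batch import SampleBatch

writer = JsonWriter("/tmp/demo-out", max_file_size=64 * 1024 * 1024)
writer.write(SampleBatch({
    "obs": np.random.random((2, 4)).astype(np.float32),
    "actions": np.array([0, 1]),
    "rewards": np.array([1.0, 0.5]),
    "new_obs": np.random.random((2, 4)).astype(np.float32),
    "dones": np.array([False, True]),
}))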

@@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType, SampleBatchType
from typing import List
@@ -23,19 +24,30 @@ class OffPolicyEstimator:
@DeveloperAPI
def __init__(self, policy: Policy, gamma: float):
"""Creates an off-policy estimator.
"""Initializes an OffPolicyEstimator instance.
Args:
policy (Policy): Policy to evaluate.
gamma (float): Discount of the MDP.
policy: Policy to evaluate.
gamma: Discount factor of the environment.
"""
self.policy = policy
self.gamma = gamma
self.new_estimates = []
@classmethod
def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Create an off-policy estimator from a IOContext."""
def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Creates an off-policy estimator from an IOContext object.
Extracts Policy and gamma (discount factor) information from the
IOContext.
Args:
ioctx: The IOContext object to create the OffPolicyEstimator
from.
Returns:
The OffPolicyEstimator object created from the IOContext object.
"""
gamma = ioctx.worker.policy_config["gamma"]
# Grab a reference to the current model
keys = list(ioctx.worker.policy_map.keys())
@@ -47,18 +59,36 @@ class OffPolicyEstimator:
return cls(policy, gamma)
@DeveloperAPI
def estimate(self, batch: SampleBatchType):
"""Returns an estimate for the given batch of experiences.
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
"""Returns an off policy estimate for the given batch of experiences.
The batch will only contain data from one episode, but it may only be
a fragment of an episode.
The batch will contain data from at most one episode, but it may
also be only a fragment of that episode.
Args:
batch: The batch to calculate the off policy estimate (OPE) on.
Returns:
The off-policy estimates (OPE) calculated on the given batch.
"""
raise NotImplementedError
@DeveloperAPI
def action_prob(self, batch: SampleBatchType) -> np.ndarray:
"""Returns the probs for the batch actions for the current policy."""
def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
"""Returns log likelihoods for actions in given batch for policy.
Computes likelihoods by passing the observations through the current
policy's `compute_log_likelihoods()` method.
Args:
batch: The SampleBatch or MultiAgentBatch to calculate action
likelihoods for. The batch must contain the OBS and ACTIONS keys.
Returns:
The likelihoods of the actions in the batch, given the
observations and the policy.
"""
num_state_inputs = 0
for k in batch.keys():
if k.startswith("state_in_"):
@@ -66,7 +96,7 @@ class OffPolicyEstimator:
state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
actions=batch[SampleBatch.ACTIONS],
obs_batch=batch[SampleBatch.CUR_OBS],
obs_batch=batch[SampleBatch.OBS],
state_batches=[batch[k] for k in state_keys],
prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
@@ -76,12 +106,29 @@ class OffPolicyEstimator:
return np.exp(log_likelihoods)
@DeveloperAPI
def process(self, batch: SampleBatchType):
def process(self, batch: SampleBatchType) -> None:
"""Computes off policy estimates (OPE) on batch and stores results.
Results collected so far can then be retrieved by calling
`self.get_metrics` (which flushes the internal results storage).
Args:
batch: The batch to process (call `self.estimate()` on) and
store results (OPEs) for.
"""
self.new_estimates.append(self.estimate(batch))
@DeveloperAPI
def check_can_estimate_for(self, batch: SampleBatchType):
"""Returns whether we can support OPE for this batch."""
def check_can_estimate_for(self, batch: SampleBatchType) -> None:
"""Checks if we support off policy estimation (OPE) on given batch.
Args:
batch: The batch to check.
Raises:
ValueError: If the `action_prob` key is not in the batch or if the
batch is a MultiAgentBatch.
"""
if isinstance(batch, MultiAgentBatch):
raise ValueError(
@@ -98,11 +145,19 @@ class OffPolicyEstimator:
@DeveloperAPI
def get_metrics(self) -> List[OffPolicyEstimate]:
"""Return a list of new episode metric estimates since the last call.
"""Returns list of new episode metric estimates since the last call.
Returns:
list of OffPolicyEstimate objects.
List of OffPolicyEstimate objects.
"""
out = self.new_estimates
self.new_estimates = []
return out
@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
def create(self, *args, **kwargs):
return self.create_from_io_context(*args, **kwargs)
@Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
def action_prob(self, *args, **kwargs):
return self.action_log_likelihood(*args, **kwargs)

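A hedged end-to-end sketch of the OffPolicyEstimator flow documented above (trainer class, dataset path, and gamma are assumptions; `ImportanceSamplingEstimator` stands in for any concrete subclass):

from ray.rllib.agents.pg import PGTrainer
from ray.rllib.offline import JsonReader
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator

# Any policy works here; PG on CartPole is just a stand-in.
trainer = PGTrainer(env="CartPole-v0", config={"num_workers": 0})
policy = trainer.get_policy()

# The recorded batches are assumed to contain the `action_prob` column.
reader = JsonReader("/tmp/cartpole-out/*.json")
estimator = ImportanceSamplingEstimator(policy, gamma=0.99)
for _ in range(10):
    estimator.process(reader.next())  # estimate() + store the result
print(estimator.get_metrics())        # returns and flushes the estimates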

@@ -1,15 +1,14 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType
@PublicAPI
class OutputWriter:
"""Writer object for saving experiences from policy evaluation."""
"""Writer API for saving experiences from policy evaluation."""
@PublicAPI
def write(self, sample_batch: SampleBatchType):
"""Save a batch of experiences.
"""Saves a batch of experiences.
Args:
sample_batch: SampleBatch or MultiAgentBatch to save.
@@ -22,4 +21,5 @@ class NoopOutput(OutputWriter):
@override(OutputWriter)
def write(self, sample_batch: SampleBatchType):
# Do nothing.
pass

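A toy sketch of the OutputWriter API above; the subclass name and behavior are illustrative only:

from ray.rllib.offline import OutputWriter
from ray.rllib.utils.typing import SampleBatchType


class PrintWriter(OutputWriter):
    """Toy writer that only logs how many timesteps each batch contains."""

    def write(self, sample_batch: SampleBatchType):
        print("received batch with {} timesteps".format(sample_batch.count))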

@@ -18,11 +18,11 @@ class ShuffledInput(InputReader):
@DeveloperAPI
def __init__(self, child: InputReader, n: int = 0):
"""Initialize a MixedInput.
"""Initializes a ShuffledInput instance.
Args:
child (InputReader): child input reader to shuffle.
n (int): if positive, shuffle input over this many batches.
child: The child input reader to shuffle.
n: If positive, shuffle input over this many batches.
"""
self.n = n
self.child = child

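A short usage sketch (the path is a placeholder): wrap any child reader to shuffle its output over a window of batches.

from ray.rllib.offline import JsonReader, ShuffledInput

# Shuffle over a window of 100 batches read from the child JsonReader.
reader = ShuffledInput(JsonReader("/tmp/cartpole-out/*.json"), n=100)
batch = reader.next()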

@@ -1,54 +1,54 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
self.filter_values = []
self.filter_counts = []
@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_prob(batch)
# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0
# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma**t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t
estimation = OffPolicyEstimate(
"wis", {
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
})
return estimation
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
"""The weighted step-wise IS estimator.
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
self.filter_values = []
self.filter_counts = []
@override(OffPolicyEstimator)
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
new_prob = self.action_log_likelihood(batch)
# calculate importance ratios
p = []
for t in range(batch.count):
if t == 0:
pt_prev = 1.0
else:
pt_prev = p[t - 1]
p.append(pt_prev * new_prob[t] / old_prob[t])
for t, v in enumerate(p):
if t >= len(self.filter_values):
self.filter_values.append(v)
self.filter_counts.append(1.0)
else:
self.filter_values[t] += v
self.filter_counts[t] += 1.0
# calculate stepwise weighted IS estimate
V_prev, V_step_WIS = 0.0, 0.0
for t in range(batch.count):
V_prev += rewards[t] * self.gamma**t
w_t = self.filter_values[t] / self.filter_counts[t]
V_step_WIS += p[t] / w_t * rewards[t] * self.gamma**t
estimation = OffPolicyEstimate(
"wis", {
"V_prev": V_prev,
"V_step_WIS": V_step_WIS,
"V_gain_est": V_step_WIS / max(1e-8, V_prev),
})
return estimation
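For reference, a compact restatement of what this estimator returns, in the notation of the code above (N = number of episodes processed so far):

V_prev     = sum_t gamma^t * rewards[t]                  (return under the behavior policy)
w_t        = filter_values[t] / filter_counts[t]         (running mean of p[t] over the N episodes)
V_step_WIS = sum_t gamma^t * (p[t] / w_t) * rewards[t]   (weighted step-wise IS estimate)
V_gain_est = V_step_WIS / max(1e-8, V_prev)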