[RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808)

Author: Sven Mika, 2021-11-01 10:59:53 +01:00, committed by GitHub
parent 7a2e9e00e8
commit ea2bea7e30
11 changed files with 216 additions and 133 deletions


@@ -664,11 +664,12 @@ class RolloutWorker(ParallelIteratorWorker):
                     "will discard all sampler outputs and keep only metrics.")
                 sample_async = True
             elif method == "is":
-                ise = ImportanceSamplingEstimator.create(self.io_context)
+                ise = ImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(ise)
             elif method == "wis":
-                wise = WeightedImportanceSamplingEstimator.create(
-                    self.io_context)
+                wise = WeightedImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(wise)
             else:
                 raise ValueError(
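For context, a minimal usage sketch of how these estimators get enabled: the `input_evaluation` values "is" and "wis" are the config keys handled in the code above, while the trainer class, env, and path below are made-up examples, not part of this commit.

```python
from ray.rllib.agents.dqn import DQNTrainer

trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        # Learn from previously collected JSON experiences (example path).
        "input": "/tmp/cartpole-out",
        # Build the IS/WIS off-policy estimators, as wired up above.
        "input_evaluation": ["is", "wis"],
    })
```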


@@ -17,11 +17,11 @@ class D4RLReader(InputReader):
     @PublicAPI
     def __init__(self, inputs: str, ioctx: IOContext = None):
-        """Initialize a D4RLReader.
+        """Initializes a D4RLReader instance.

         Args:
-            inputs (str): String corresponding to D4RL environment name
-            ioctx (IOContext): Current IO context object.
+            inputs: String corresponding to the D4RL environment name.
+            ioctx: Current IO context object.
         """
         import d4rl
         self.env = gym.make(inputs)
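A hedged usage sketch of the reader described by this docstring; it assumes the optional `d4rl` package is installed, and the dataset id "hopper-medium-v0" is only an example.

```python
from ray.rllib.offline.d4rl_reader import D4RLReader

reader = D4RLReader("hopper-medium-v0")
batch = reader.next()  # SampleBatch built from the offline D4RL dataset
```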


@@ -16,15 +16,16 @@ logger = logging.getLogger(__name__)

 @PublicAPI
 class InputReader(metaclass=ABCMeta):
-    """Input object for loading experiences in policy evaluation."""
+    """API for collecting and returning experiences during policy evaluation.
+    """

     @abstractmethod
     @PublicAPI
-    def next(self):
-        """Returns the next batch of experiences read.
+    def next(self) -> SampleBatchType:
+        """Returns the next batch of read experiences.

         Returns:
-            Union[SampleBatch, MultiAgentBatch]: The experience read.
+            The experience read (SampleBatch or MultiAgentBatch).
         """
         raise NotImplementedError

@@ -40,7 +41,7 @@ class InputReader(metaclass=ABCMeta):
         reader repeatedly to feed the TensorFlow queue.

         Args:
-            queue_size (int): Max elements to allow in the TF queue.
+            queue_size: Max elements to allow in the TF queue.

         Example:
             >>> class MyModel(rllib.model.Model):

@@ -56,7 +57,7 @@ class InputReader(metaclass=ABCMeta):
         You can find a runnable version of this in examples/custom_loss.py.

         Returns:
-            dict of Tensors, one for each column of the read SampleBatch.
+            Dict of Tensors, one for each column of the read SampleBatch.
         """
         if hasattr(self, "_queue_runner"):
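To illustrate the `next()` contract documented above, here is a toy (made-up) InputReader subclass that always returns the same one-step batch:

```python
from ray.rllib.offline import InputReader
from ray.rllib.policy.sample_batch import SampleBatch


class ConstantReader(InputReader):
    """Toy reader returning a fixed single-timestep SampleBatch."""

    def next(self):
        return SampleBatch({
            "obs": [[0.0]],
            "actions": [0],
            "rewards": [1.0],
            "dones": [True],
            "new_obs": [[0.0]],
        })
```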


@@ -1,37 +1,53 @@
 import os
+from typing import Any, Optional, TYPE_CHECKING

 from ray.rllib.utils.annotations import PublicAPI
-from typing import Any
+from ray.rllib.utils.typing import TrainerConfigDict
+
+if TYPE_CHECKING:
+    from ray.rllib.evaluation.sampler import SamplerInput


 @PublicAPI
 class IOContext:
-    """Attributes to pass to input / output class constructors.
+    """Class containing attributes to pass to input/output class constructors.

-    RLlib auto-sets these attributes when constructing input / output classes.
-
-    Attributes:
-        log_dir (str): Default logging directory.
-        config (dict): Configuration of the agent.
-        worker_index (int): When there are multiple workers created, this
-            uniquely identifies the current worker.
-        worker (RolloutWorker): RolloutWorker object reference.
-        input_config (dict): The input configuration for custom input.
+    RLlib auto-sets these attributes when constructing input/output classes,
+    such as InputReaders and OutputWriters.
     """

     @PublicAPI
     def __init__(self,
-                 log_dir: str = None,
-                 config: dict = None,
+                 log_dir: Optional[str] = None,
+                 config: Optional[TrainerConfigDict] = None,
                  worker_index: int = 0,
-                 worker: Any = None):
+                 worker: Optional[Any] = None):
+        """Initializes an IOContext object.
+
+        Args:
+            log_dir: The logging directory to read from/write to.
+            config: The Trainer's main config dict.
+            worker_index: When there are multiple workers created, this
+                uniquely identifies the current worker. 0 for the local
+                worker, >0 for any of the remote workers.
+            worker: The RolloutWorker object reference.
+        """
         self.log_dir = log_dir or os.getcwd()
         self.config = config or {}
         self.worker_index = worker_index
         self.worker = worker

     @PublicAPI
-    def default_sampler_input(self) -> Any:
+    def default_sampler_input(self) -> Optional["SamplerInput"]:
+        """Returns the RolloutWorker's SamplerInput object, if any.
+
+        Returns None if the RolloutWorker has no SamplerInput. Note that by
+        default, a local worker does not create a SamplerInput object when
+        one or more remote workers exist.
+
+        Returns:
+            The RolloutWorker's SamplerInput object, or None if none exists.
+        """
         return self.worker.sampler

     @PublicAPI
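A standalone sketch of constructing an IOContext directly (the log dir and config values are made up). Note that `default_sampler_input()` requires an actual RolloutWorker to be set on the context, so it is not called here.

```python
from ray.rllib.offline.io_context import IOContext

ioctx = IOContext(
    log_dir="/tmp/my_logs",
    config={"postprocess_inputs": False},
    worker_index=0,
)
print(ioctx.log_dir, ioctx.worker_index)
```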


@@ -14,7 +14,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)

         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)

         # calculate importance ratios
         p = []
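A numeric sketch of the importance ratios computed just below this change (all values made up): `old_prob` is the behavior policy's logged "action_prob" column, `new_prob` is what the renamed `action_log_likelihood()` provides for the current policy.

```python
import numpy as np

rewards = np.array([1.0, 1.0, 1.0])
old_prob = np.array([0.5, 0.4, 0.6])   # behavior policy action probs
new_prob = np.array([0.6, 0.5, 0.5])   # current policy action probs

p = np.cumprod(new_prob / old_prob)    # cumulative per-step importance ratios
v_step = p * rewards                   # per-step off-policy reward estimates
```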


@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import random
 import re
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 import zipfile
@@ -32,17 +32,20 @@ WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
 class JsonReader(InputReader):
     """Reader object that loads experiences from JSON file chunks.

-    The input files will be read from in an random order."""
+    The input files will be read from in random order.
+    """

     @PublicAPI
-    def __init__(self, inputs: List[str], ioctx: IOContext = None):
-        """Initialize a JsonReader.
+    def __init__(self,
+                 inputs: Union[str, List[str]],
+                 ioctx: Optional[IOContext] = None):
+        """Initializes a JsonReader instance.

         Args:
-            inputs (str|list): Either a glob expression for files, e.g.,
-                "/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
+            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
+                or a list of single file paths or URIs, e.g.,
                 ["s3://bucket/file.json", "s3://bucket/file2.json"].
-            ioctx (IOContext): Current IO context object.
+            ioctx: Current IO context object or None.
         """
         self.ioctx = ioctx or IOContext()
@@ -72,8 +75,8 @@ class JsonReader(InputReader):
             self.files = []
             for i in inputs:
                 self.files.extend(glob.glob(i))
-        elif type(inputs) is list:
-            self.files = inputs
+        elif isinstance(inputs, (list, tuple)):
+            self.files = list(inputs)
         else:
             raise ValueError(
                 "type of inputs must be list or str, not {}".format(inputs))
@@ -98,6 +101,26 @@ class JsonReader(InputReader):

         return self._postprocess_if_needed(batch)

+    def read_all_files(self) -> SampleBatchType:
+        """Reads through all files and yields one SampleBatchType per line.
+
+        When reaching the end of the last file, will start from the beginning
+        again.
+
+        Yields:
+            One SampleBatch or MultiAgentBatch per line in all input files.
+        """
+        for path in self.files:
+            file = self._try_open_file(path)
+            while True:
+                line = file.readline()
+                if not line:
+                    break
+                batch = self._try_parse(line)
+                if batch is None:
+                    break
+                yield batch
+
     def _postprocess_if_needed(self,
                                batch: SampleBatchType) -> SampleBatchType:
         if not self.ioctx.config.get("postprocess_inputs"):
@@ -182,18 +205,6 @@ class JsonReader(InputReader):
                     self.ioctx.worker.policy_map[pid].action_space_struct)
         return batch

-    def read_all_files(self):
-        for path in self.files:
-            file = self._try_open_file(path)
-            while True:
-                line = file.readline()
-                if not line:
-                    break
-                batch = self._try_parse(line)
-                if batch is None:
-                    break
-                yield batch
-
     def _next_line(self) -> str:
         if not self.cur_file:
             self.cur_file = self._next_file()
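A usage sketch of the reader (the glob path is made up; it should point at JSON files previously written by a JsonWriter / the "output" config):

```python
from ray.rllib.offline.json_reader import JsonReader

reader = JsonReader("/tmp/cartpole-out/*.json")
batch = reader.next()               # one (possibly postprocessed) batch

for b in reader.read_all_files():   # or iterate line by line over all files
    print(b.count)
```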


@@ -34,15 +34,14 @@ class JsonWriter(OutputWriter):
                  ioctx: IOContext = None,
                  max_file_size: int = 64 * 1024 * 1024,
                  compress_columns: List[str] = frozenset(["obs", "new_obs"])):
-        """Initialize a JsonWriter.
+        """Initializes a JsonWriter instance.

         Args:
-            path (str): a path/URI of the output directory to save files in.
-            ioctx (IOContext): current IO context object.
-            max_file_size (int): max size of single files before rolling over.
-            compress_columns (list): list of sample batch columns to compress.
+            path: a path/URI of the output directory to save files in.
+            ioctx: current IO context object.
+            max_file_size: max size of single files before rolling over.
+            compress_columns: list of sample batch columns to compress.
         """
         self.ioctx = ioctx or IOContext()
         self.max_file_size = max_file_size
         self.compress_columns = compress_columns
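A usage sketch of the writer (the output path and batch contents are made up):

```python
from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.policy.sample_batch import SampleBatch

writer = JsonWriter("/tmp/demo-out")  # rolls output files over at 64MB
writer.write(SampleBatch({
    "obs": [[0.0]],
    "actions": [0],
    "rewards": [1.0],
    "dones": [True],
    "new_obs": [[0.0]],
}))
```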


@@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.policy import Policy
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.offline.io_context import IOContext
+from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import TensorType, SampleBatchType
 from typing import List
@@ -23,19 +24,30 @@ class OffPolicyEstimator:
     @DeveloperAPI
     def __init__(self, policy: Policy, gamma: float):
-        """Creates an off-policy estimator.
+        """Initializes an OffPolicyEstimator instance.

         Args:
-            policy (Policy): Policy to evaluate.
-            gamma (float): Discount of the MDP.
+            policy: Policy to evaluate.
+            gamma: Discount factor of the environment.
         """
         self.policy = policy
         self.gamma = gamma
         self.new_estimates = []

     @classmethod
-    def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
-        """Create an off-policy estimator from a IOContext."""
+    def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
+        """Creates an off-policy estimator from an IOContext object.
+
+        Extracts Policy and gamma (discount factor) information from the
+        IOContext.
+
+        Args:
+            ioctx: The IOContext object to create the OffPolicyEstimator
+                from.
+
+        Returns:
+            The OffPolicyEstimator object created from the IOContext object.
+        """
         gamma = ioctx.worker.policy_config["gamma"]
         # Grab a reference to the current model
         keys = list(ioctx.worker.policy_map.keys())
@@ -47,18 +59,36 @@ class OffPolicyEstimator:
         return cls(policy, gamma)

     @DeveloperAPI
-    def estimate(self, batch: SampleBatchType):
-        """Returns an estimate for the given batch of experiences.
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
+        """Returns an off-policy estimate for the given batch of experiences.

-        The batch will only contain data from one episode, but it may only be
-        a fragment of an episode.
+        The batch will contain data from at most one episode, and it may be
+        only a fragment of that episode.
+
+        Args:
+            batch: The batch to calculate the off-policy estimate (OPE) on.
+
+        Returns:
+            The off-policy estimate (OPE) calculated on the given batch.
         """
         raise NotImplementedError

     @DeveloperAPI
-    def action_prob(self, batch: SampleBatchType) -> np.ndarray:
-        """Returns the probs for the batch actions for the current policy."""
+    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
+        """Returns log likelihoods for actions in given batch for policy.
+
+        Computes likelihoods by passing the observations through the current
+        policy's `compute_log_likelihoods()` method.
+
+        Args:
+            batch: The SampleBatch or MultiAgentBatch to calculate action
+                log likelihoods from. This batch/batches must contain OBS
+                and ACTIONS keys.
+
+        Returns:
+            The log likelihoods of the actions in the batch, given the
+            observations and the policy.
+        """
         num_state_inputs = 0
         for k in batch.keys():
             if k.startswith("state_in_"):
@@ -66,7 +96,7 @@ class OffPolicyEstimator:
         state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
         log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
             actions=batch[SampleBatch.ACTIONS],
-            obs_batch=batch[SampleBatch.CUR_OBS],
+            obs_batch=batch[SampleBatch.OBS],
             state_batches=[batch[k] for k in state_keys],
             prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
             prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
@@ -76,12 +106,29 @@ class OffPolicyEstimator:
         return np.exp(log_likelihoods)

     @DeveloperAPI
-    def process(self, batch: SampleBatchType):
+    def process(self, batch: SampleBatchType) -> None:
+        """Computes off-policy estimates (OPE) on batch and stores results.
+
+        Collected results can then be retrieved by calling
+        `self.get_metrics` (which flushes the internal results storage).
+
+        Args:
+            batch: The batch to process (call `self.estimate()` on) and
+                store results (OPEs) for.
+        """
         self.new_estimates.append(self.estimate(batch))

     @DeveloperAPI
-    def check_can_estimate_for(self, batch: SampleBatchType):
-        """Returns whether we can support OPE for this batch."""
+    def check_can_estimate_for(self, batch: SampleBatchType) -> None:
+        """Checks if we support off-policy estimation (OPE) on given batch.
+
+        Args:
+            batch: The batch to check.
+
+        Raises:
+            ValueError: In case the `action_prob` key is not in the batch OR
+                the batch is a MultiAgentBatch.
+        """
         if isinstance(batch, MultiAgentBatch):
             raise ValueError(
@@ -98,11 +145,19 @@ class OffPolicyEstimator:

     @DeveloperAPI
     def get_metrics(self) -> List[OffPolicyEstimate]:
-        """Return a list of new episode metric estimates since the last call.
+        """Returns list of new episode metric estimates since the last call.

         Returns:
-            list of OffPolicyEstimate objects.
+            List of OffPolicyEstimate objects.
         """
         out = self.new_estimates
         self.new_estimates = []
         return out
+
+    @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
+    def create(self, *args, **kwargs):
+        return self.create_from_io_context(*args, **kwargs)
+
+    @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
+    def action_prob(self, *args, **kwargs):
+        return self.action_log_likelihood(*args, **kwargs)
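To illustrate the `process()` / `get_metrics()` cycle documented above, here is a made-up toy subclass that skips the usual Policy/IOContext setup and returns a constant estimate (a real estimator would return an `OffPolicyEstimate` and be built via `create_from_io_context()`):

```python
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator
from ray.rllib.policy.sample_batch import SampleBatch


class ConstantEstimator(OffPolicyEstimator):
    def __init__(self):
        # Toy example only: bypass the base class' Policy/gamma setup.
        self.new_estimates = []

    def estimate(self, batch):
        # A real estimator would compute and return an OffPolicyEstimate here.
        return {"v_new": 1.0}


est = ConstantEstimator()
est.process(SampleBatch({"rewards": [1.0], "action_prob": [0.5]}))
print(est.get_metrics())  # -> [{'v_new': 1.0}]; internal storage is flushed
```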


@@ -1,15 +1,14 @@
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType


 @PublicAPI
 class OutputWriter:
-    """Writer object for saving experiences from policy evaluation."""
+    """Writer API for saving experiences from policy evaluation."""

     @PublicAPI
     def write(self, sample_batch: SampleBatchType):
-        """Save a batch of experiences.
+        """Saves a batch of experiences.

         Args:
             sample_batch: SampleBatch or MultiAgentBatch to save.

@@ -22,4 +21,5 @@ class NoopOutput(OutputWriter):

     @override(OutputWriter)
     def write(self, sample_batch: SampleBatchType):
+        # Do nothing.
         pass
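A made-up example of a custom OutputWriter that just logs batch sizes instead of persisting anything (it mirrors NoopOutput above, but with a visible side effect):

```python
from ray.rllib.offline import OutputWriter


class PrintWriter(OutputWriter):
    def write(self, sample_batch):
        print("received a batch with", sample_batch.count, "timesteps")
```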


@@ -18,11 +18,11 @@ class ShuffledInput(InputReader):
     @DeveloperAPI
     def __init__(self, child: InputReader, n: int = 0):
-        """Initialize a MixedInput.
+        """Initializes a ShuffledInput instance.

         Args:
-            child (InputReader): child input reader to shuffle.
-            n (int): if positive, shuffle input over this many batches.
+            child: child input reader to shuffle.
+            n: If positive, shuffle input over this many batches.
         """
         self.n = n
         self.child = child
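A usage sketch (the glob path is made up): buffer `n` batches from the child reader and return them in shuffled order.

```python
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.shuffled_input import ShuffledInput

reader = ShuffledInput(JsonReader("/tmp/cartpole-out/*.json"), n=10)
batch = reader.next()
```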


@@ -20,7 +20,7 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)

         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)

         # calculate importance ratios
         p = []
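A numeric sketch contrasting plain IS with the weighted variant used here (all numbers made up; the actual estimator works per time step with running ratio means rather than whole-episode returns):

```python
import numpy as np

returns = np.array([10.0, 2.0, 5.0])    # per-episode returns in the batch data
ratios = np.array([0.2, 1.5, 0.8])      # per-episode importance ratios

v_is = np.mean(ratios * returns)                    # ordinary IS estimate
v_wis = np.sum(ratios * returns) / np.sum(ratios)   # weighted (self-normalized)
```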