[RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808)
parent 7a2e9e00e8
commit ea2bea7e30
11 changed files with 216 additions and 133 deletions
@@ -664,11 +664,12 @@ class RolloutWorker(ParallelIteratorWorker):
                     "will discard all sampler outputs and keep only metrics.")
                 sample_async = True
             elif method == "is":
-                ise = ImportanceSamplingEstimator.create(self.io_context)
+                ise = ImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(ise)
             elif method == "wis":
-                wise = WeightedImportanceSamplingEstimator.create(
-                    self.io_context)
+                wise = WeightedImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(wise)
             else:
                 raise ValueError(

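This hunk switches the RolloutWorker's reward estimators over to the new `create_from_io_context()` classmethod (the old `create()` survives as a deprecated alias further below). As a hedged sketch only: the `input` and `input_evaluation` config keys below are assumed from RLlib's offline-data API of this period and are not part of this diff.

    # Hedged sketch -- config keys are assumptions, not taken from this diff.
    config = {
        "input": "/tmp/offline-data",       # hypothetical path to recorded JSON batches
        "input_evaluation": ["is", "wis"],  # would build the IS and WIS estimators above
    }
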
@@ -17,11 +17,11 @@ class D4RLReader(InputReader):
 
     @PublicAPI
     def __init__(self, inputs: str, ioctx: IOContext = None):
-        """Initialize a D4RLReader.
+        """Initializes a D4RLReader instance.
 
         Args:
-            inputs (str): String corresponding to D4RL environment name
-            ioctx (IOContext): Current IO context object.
+            inputs: String corresponding to the D4RL environment name.
+            ioctx: Current IO context object.
         """
         import d4rl
         self.env = gym.make(inputs)

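The D4RLReader constructor itself is unchanged; only the docstring style is updated. A minimal, hedged usage sketch (the module path and the env name "maze2d-umaze-v1" are assumptions for illustration; the `d4rl` package must be installed):

    from ray.rllib.offline.d4rl_reader import D4RLReader  # assumed module path

    reader = D4RLReader(inputs="maze2d-umaze-v1")  # hypothetical D4RL env name
    batch = reader.next()  # one SampleBatch of offline experiences
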
@@ -16,15 +16,16 @@ logger = logging.getLogger(__name__)
 
 @PublicAPI
 class InputReader(metaclass=ABCMeta):
-    """Input object for loading experiences in policy evaluation."""
+    """API for collecting and returning experiences during policy evaluation.
+    """
 
     @abstractmethod
     @PublicAPI
-    def next(self):
-        """Returns the next batch of experiences read.
+    def next(self) -> SampleBatchType:
+        """Returns the next batch of read experiences.
 
         Returns:
-            Union[SampleBatch, MultiAgentBatch]: The experience read.
+            The experience read (SampleBatch or MultiAgentBatch).
         """
         raise NotImplementedError
 

@@ -40,7 +41,7 @@ class InputReader(metaclass=ABCMeta):
         reader repeatedly to feed the TensorFlow queue.
 
         Args:
-            queue_size (int): Max elements to allow in the TF queue.
+            queue_size: Max elements to allow in the TF queue.
 
         Example:
             >>> class MyModel(rllib.model.Model):

@@ -56,7 +57,7 @@ class InputReader(metaclass=ABCMeta):
         You can find a runnable version of this in examples/custom_loss.py.
 
         Returns:
-            dict of Tensors, one for each column of the read SampleBatch.
+            Dict of Tensors, one for each column of the read SampleBatch.
         """
 
         if hasattr(self, "_queue_runner"):

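Per the updated docstrings, `next()` is the single abstract method a custom InputReader must implement, returning a SampleBatch or MultiAgentBatch per call. A minimal, hedged sketch (the module path and the dummy batch contents are illustrative assumptions):

    from ray.rllib.offline.input_reader import InputReader  # assumed module path
    from ray.rllib.policy.sample_batch import SampleBatch

    class ConstantReader(InputReader):
        """Toy reader returning the same tiny batch on every call."""

        def next(self):
            # Illustrative dummy data only.
            return SampleBatch({"obs": [[0.0]], "actions": [0], "rewards": [1.0]})
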
@@ -1,37 +1,53 @@
 import os
+from typing import Any, Optional, TYPE_CHECKING
 
 from ray.rllib.utils.annotations import PublicAPI
-from typing import Any
+from ray.rllib.utils.typing import TrainerConfigDict
+
+if TYPE_CHECKING:
+    from ray.rllib.evaluation.sampler import SamplerInput
 
 
 @PublicAPI
 class IOContext:
-    """Attributes to pass to input / output class constructors.
+    """Class containing attributes to pass to input/output class constructors.
 
-    RLlib auto-sets these attributes when constructing input / output classes.
-
-    Attributes:
-        log_dir (str): Default logging directory.
-        config (dict): Configuration of the agent.
-        worker_index (int): When there are multiple workers created, this
-            uniquely identifies the current worker.
-        worker (RolloutWorker): RolloutWorker object reference.
-        input_config (dict): The input configuration for custom input.
+    RLlib auto-sets these attributes when constructing input/output classes,
+    such as InputReaders and OutputWriters.
     """
 
     @PublicAPI
     def __init__(self,
-                 log_dir: str = None,
-                 config: dict = None,
+                 log_dir: Optional[str] = None,
+                 config: Optional[TrainerConfigDict] = None,
                  worker_index: int = 0,
-                 worker: Any = None):
+                 worker: Optional[Any] = None):
+        """Initializes a IOContext object.
+
+        Args:
+            log_dir: The logging directory to read from/write to.
+            config: The Trainer's main config dict.
+            worker_index (int): When there are multiple workers created, this
+                uniquely identifies the current worker. 0 for the local
+                worker, >0 for any of the remote workers.
+            worker (RolloutWorker): The RolloutWorker object reference.
+        """
         self.log_dir = log_dir or os.getcwd()
         self.config = config or {}
         self.worker_index = worker_index
         self.worker = worker
 
     @PublicAPI
-    def default_sampler_input(self) -> Any:
+    def default_sampler_input(self) -> Optional["SamplerInput"]:
+        """Returns the RolloutWorker's SamplerInput object, if any.
+
+        Returns None if the RolloutWorker has no SamplerInput. Note that local
+        workers in case there are also one or more remote workers by default
+        do not create a SamplerInput object.
+
+        Returns:
+            The RolloutWorkers' SamplerInput object or None if none exists.
+        """
         return self.worker.sampler
 
     @PublicAPI

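The new signature makes `log_dir` and `config` explicit Optionals and documents the `worker_index` semantics (0 for the local worker, >0 for remote workers). A small, hedged construction sketch; building an IOContext by hand without a RolloutWorker is for illustration only:

    from ray.rllib.offline.io_context import IOContext  # import path shown in this diff

    ioctx = IOContext(
        log_dir="/tmp/rllib-offline",          # hypothetical logging directory
        config={"postprocess_inputs": False},  # partial Trainer config, for illustration
        worker_index=0,                        # 0 == local worker, >0 == remote workers
        worker=None,                           # normally a RolloutWorker reference
    )
    assert ioctx.config.get("postprocess_inputs") is False
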
@@ -14,7 +14,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)
 
         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)
 
         # calculate importance ratios
         p = []

@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import random
 import re
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 import zipfile
 

@@ -32,17 +32,20 @@ WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
 class JsonReader(InputReader):
     """Reader object that loads experiences from JSON file chunks.
 
-    The input files will be read from in an random order."""
+    The input files will be read from in random order.
+    """
 
     @PublicAPI
-    def __init__(self, inputs: List[str], ioctx: IOContext = None):
-        """Initialize a JsonReader.
+    def __init__(self,
+                 inputs: Union[str, List[str]],
+                 ioctx: Optional[IOContext] = None):
+        """Initializes a JsonReader instance.
 
         Args:
-            inputs (str|list): Either a glob expression for files, e.g.,
-                "/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
+            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
+                or a list of single file paths or URIs, e.g.,
                 ["s3://bucket/file.json", "s3://bucket/file2.json"].
-            ioctx (IOContext): Current IO context object.
+            ioctx: Current IO context object or None.
         """
 
         self.ioctx = ioctx or IOContext()

@@ -72,8 +75,8 @@ class JsonReader(InputReader):
             self.files = []
             for i in inputs:
                 self.files.extend(glob.glob(i))
-        elif type(inputs) is list:
-            self.files = inputs
+        elif isinstance(inputs, (list, tuple)):
+            self.files = list(inputs)
         else:
             raise ValueError(
                 "type of inputs must be list or str, not {}".format(inputs))

@@ -98,6 +101,26 @@ class JsonReader(InputReader):
 
         return self._postprocess_if_needed(batch)
 
+    def read_all_files(self) -> SampleBatchType:
+        """Reads through all files and yields one SampleBatchType per line.
+
+        When reaching the end of the last file, will start from the beginning
+        again.
+
+        Yields:
+            One SampleBatch or MultiAgentBatch per line in all input files.
+        """
+        for path in self.files:
+            file = self._try_open_file(path)
+            while True:
+                line = file.readline()
+                if not line:
+                    break
+                batch = self._try_parse(line)
+                if batch is None:
+                    break
+                yield batch
+
     def _postprocess_if_needed(self,
                                batch: SampleBatchType) -> SampleBatchType:
         if not self.ioctx.config.get("postprocess_inputs"):

@@ -182,18 +205,6 @@ class JsonReader(InputReader):
                     self.ioctx.worker.policy_map[pid].action_space_struct)
         return batch
 
-    def read_all_files(self):
-        for path in self.files:
-            file = self._try_open_file(path)
-            while True:
-                line = file.readline()
-                if not line:
-                    break
-                batch = self._try_parse(line)
-                if batch is None:
-                    break
-                yield batch
-
     def _next_line(self) -> str:
         if not self.cur_file:
             self.cur_file = self._next_file()

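With `inputs` widened to `Union[str, List[str]]` and `read_all_files()` moved up and documented as a generator, typical reader usage looks roughly like the hedged sketch below (module path and glob are assumptions):

    from ray.rllib.offline.json_reader import JsonReader  # assumed module path

    reader = JsonReader(inputs="/tmp/offline-data/**/*.json")  # hypothetical glob
    one_batch = reader.next()  # a single (possibly postprocessed) batch
    for batch in reader.read_all_files():  # one SampleBatch/MultiAgentBatch per line
        print(batch.count)
        break
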
@@ -34,15 +34,14 @@ class JsonWriter(OutputWriter):
                  ioctx: IOContext = None,
                  max_file_size: int = 64 * 1024 * 1024,
                  compress_columns: List[str] = frozenset(["obs", "new_obs"])):
-        """Initialize a JsonWriter.
+        """Initializes a JsonWriter instance.
 
         Args:
-            path (str): a path/URI of the output directory to save files in.
-            ioctx (IOContext): current IO context object.
-            max_file_size (int): max size of single files before rolling over.
-            compress_columns (list): list of sample batch columns to compress.
+            path: a path/URI of the output directory to save files in.
+            ioctx: current IO context object.
+            max_file_size: max size of single files before rolling over.
+            compress_columns: list of sample batch columns to compress.
         """
-
         self.ioctx = ioctx or IOContext()
         self.max_file_size = max_file_size
         self.compress_columns = compress_columns

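The JsonWriter arguments and defaults (64 MiB rollover, obs/new_obs compression) are untouched; only the docstring loses its redundant type annotations. A hedged usage sketch (module path and output directory are assumptions):

    from ray.rllib.offline.json_writer import JsonWriter  # assumed module path
    from ray.rllib.policy.sample_batch import SampleBatch

    writer = JsonWriter(
        path="/tmp/offline-out",             # hypothetical output directory
        max_file_size=64 * 1024 * 1024,      # roll over to a new file at ~64 MiB
        compress_columns=["obs", "new_obs"],
    )
    writer.write(SampleBatch({"obs": [[0.0]], "actions": [0], "rewards": [1.0]}))
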
@@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.policy import Policy
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.offline.io_context import IOContext
+from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import TensorType, SampleBatchType
 from typing import List

@@ -23,19 +24,30 @@ class OffPolicyEstimator:
 
     @DeveloperAPI
    def __init__(self, policy: Policy, gamma: float):
-        """Creates an off-policy estimator.
+        """Initializes an OffPolicyEstimator instance.
 
         Args:
-            policy (Policy): Policy to evaluate.
-            gamma (float): Discount of the MDP.
+            policy: Policy to evaluate.
+            gamma: Discount factor of the environment.
         """
         self.policy = policy
         self.gamma = gamma
         self.new_estimates = []
 
     @classmethod
-    def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
-        """Create an off-policy estimator from a IOContext."""
+    def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
+        """Creates an off-policy estimator from an IOContext object.
+
+        Extracts Policy and gamma (discount factor) information from the
+        IOContext.
+
+        Args:
+            ioctx: The IOContext object to create the OffPolicyEstimator
+                from.
+
+        Returns:
+            The OffPolicyEstimator object created from the IOContext object.
+        """
         gamma = ioctx.worker.policy_config["gamma"]
         # Grab a reference to the current model
         keys = list(ioctx.worker.policy_map.keys())

@@ -47,18 +59,36 @@ class OffPolicyEstimator:
         return cls(policy, gamma)
 
     @DeveloperAPI
-    def estimate(self, batch: SampleBatchType):
-        """Returns an estimate for the given batch of experiences.
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
+        """Returns an off policy estimate for the given batch of experiences.
 
-        The batch will only contain data from one episode, but it may only be
-        a fragment of an episode.
+        The batch will at most only contain data from one episode,
+        but it may also only be a fragment of an episode.
+
+        Args:
+            batch: The batch to calculate the off policy estimate (OPE) on.
+
+        Returns:
+            The off-policy estimates (OPE) calculated on the given batch.
         """
         raise NotImplementedError
 
     @DeveloperAPI
-    def action_prob(self, batch: SampleBatchType) -> np.ndarray:
-        """Returns the probs for the batch actions for the current policy."""
+    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
+        """Returns log likelihoods for actions in given batch for policy.
+
+        Computes likelihoods by passing the observations through the current
+        policy's `compute_log_likelihoods()` method.
+
+        Args:
+            batch: The SampleBatch or MultiAgentBatch to calculate action
+                log likelihoods from. This batch/batches must contain OBS
+                and ACTIONS keys.
+
+        Returns:
+            The log likelihoods of the actions in the batch, given the
+            observations and the policy.
+        """
         num_state_inputs = 0
         for k in batch.keys():
             if k.startswith("state_in_"):

@@ -66,7 +96,7 @@ class OffPolicyEstimator:
         state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
         log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
             actions=batch[SampleBatch.ACTIONS],
-            obs_batch=batch[SampleBatch.CUR_OBS],
+            obs_batch=batch[SampleBatch.OBS],
             state_batches=[batch[k] for k in state_keys],
             prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
             prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),

@@ -76,12 +106,29 @@ class OffPolicyEstimator:
         return np.exp(log_likelihoods)
 
     @DeveloperAPI
-    def process(self, batch: SampleBatchType):
+    def process(self, batch: SampleBatchType) -> None:
+        """Computes off policy estimates (OPE) on batch and stores results.
+
+        Thus-far collected results can be retrieved then by calling
+        `self.get_metrics` (which flushes the internal results storage).
+
+        Args:
+            batch: The batch to process (call `self.estimate()` on) and
+                store results (OPEs) for.
+        """
         self.new_estimates.append(self.estimate(batch))
 
     @DeveloperAPI
-    def check_can_estimate_for(self, batch: SampleBatchType):
-        """Returns whether we can support OPE for this batch."""
+    def check_can_estimate_for(self, batch: SampleBatchType) -> None:
+        """Checks if we support off policy estimation (OPE) on given batch.
+
+        Args:
+            batch: The batch to check.
+
+        Raises:
+            ValueError: In case `action_prob` key is not in batch OR batch
+                is a MultiAgentBatch.
+        """
 
         if isinstance(batch, MultiAgentBatch):
             raise ValueError(

@@ -98,11 +145,19 @@ class OffPolicyEstimator:
 
     @DeveloperAPI
     def get_metrics(self) -> List[OffPolicyEstimate]:
-        """Return a list of new episode metric estimates since the last call.
+        """Returns list of new episode metric estimates since the last call.
 
         Returns:
-            list of OffPolicyEstimate objects.
+            List of OffPolicyEstimate objects.
         """
         out = self.new_estimates
         self.new_estimates = []
         return out
+
+    @Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
+    def create(self, *args, **kwargs):
+        return self.create_from_io_context(*args, **kwargs)
+
+    @Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
+    def action_prob(self, *args, **kwargs):
+        return self.action_log_likelihood(*args, **kwargs)

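Taken together, these hunks rename `create` to `create_from_io_context` and `action_prob` to `action_log_likelihood`, keep both old names as deprecated aliases, and spell out the `estimate()`/`process()`/`get_metrics()` contract. A hedged end-to-end sketch; the module path is an assumption and `ioctx`/`reader` are assumed to exist as in the earlier sketches:

    from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator  # assumed path

    # `ioctx` must come from a RolloutWorker with a policy_map and policy_config,
    # as in the RolloutWorker hunk at the top of this diff.
    estimator = ImportanceSamplingEstimator.create_from_io_context(ioctx)
    for batch in reader.read_all_files():  # e.g. the JsonReader sketched earlier
        estimator.process(batch)           # calls estimate() and stores the OPE result
    results = estimator.get_metrics()      # list of OffPolicyEstimate; storage is flushed
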
@@ -1,15 +1,14 @@
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType
 
 
 @PublicAPI
 class OutputWriter:
-    """Writer object for saving experiences from policy evaluation."""
+    """Writer API for saving experiences from policy evaluation."""
 
     @PublicAPI
     def write(self, sample_batch: SampleBatchType):
-        """Save a batch of experiences.
+        """Saves a batch of experiences.
 
         Args:
             sample_batch: SampleBatch or MultiAgentBatch to save.

@@ -22,4 +21,5 @@ class NoopOutput(OutputWriter):
 
     @override(OutputWriter)
     def write(self, sample_batch: SampleBatchType):
+        # Do nothing.
         pass

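OutputWriter mirrors InputReader: a single `write()` method taking a SampleBatchType, with NoopOutput as the do-nothing default. A hedged sketch of a custom writer (module path and counting logic are illustrative assumptions):

    from ray.rllib.offline.output_writer import OutputWriter  # assumed module path

    class CountingWriter(OutputWriter):
        """Toy writer that only counts how many rows it was asked to save."""

        def __init__(self):
            self.rows_written = 0

        def write(self, sample_batch):
            # Illustrative only: record the batch size instead of persisting it.
            self.rows_written += sample_batch.count
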
@@ -18,11 +18,11 @@ class ShuffledInput(InputReader):
 
     @DeveloperAPI
     def __init__(self, child: InputReader, n: int = 0):
-        """Initialize a MixedInput.
+        """Initializes a ShuffledInput instance.
 
         Args:
-            child (InputReader): child input reader to shuffle.
-            n (int): if positive, shuffle input over this many batches.
+            child: child input reader to shuffle.
+            n: If positive, shuffle input over this many batches.
         """
         self.n = n
         self.child = child

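ShuffledInput wraps a child InputReader and, for n > 0, shuffles over a window of that many batches. A hedged composition sketch reusing the JsonReader example above (module paths and glob are assumptions):

    from ray.rllib.offline.shuffled_input import ShuffledInput  # assumed module path
    from ray.rllib.offline.json_reader import JsonReader        # assumed module path

    child = JsonReader(inputs="/tmp/offline-data/**/*.json")  # hypothetical glob
    shuffled = ShuffledInput(child, n=100)  # shuffle over a window of 100 batches
    batch = shuffled.next()
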
@@ -20,7 +20,7 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)
 
         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)
 
         # calculate importance ratios
         p = []