[RLlib; Docs overhaul] Docstring cleanup: Offline. (#19808)
This commit is contained in:
parent 7a2e9e00e8
commit ea2bea7e30
11 changed files with 216 additions and 133 deletions
@@ -664,11 +664,12 @@ class RolloutWorker(ParallelIteratorWorker):
                     "will discard all sampler outputs and keep only metrics.")
                 sample_async = True
             elif method == "is":
-                ise = ImportanceSamplingEstimator.create(self.io_context)
+                ise = ImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(ise)
             elif method == "wis":
-                wise = WeightedImportanceSamplingEstimator.create(
-                    self.io_context)
+                wise = WeightedImportanceSamplingEstimator.\
+                    create_from_io_context(self.io_context)
                 self.reward_estimators.append(wise)
             else:
                 raise ValueError(
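The renamed factory above is normally reached through the Trainer config rather than called directly. A rough sketch of the config that exercises this code path (the `input` path is illustrative and the `input_evaluation` key is assumed from this RLlib version's offline API):

config = {
    # Read experiences back from previously written JSON files (illustrative path).
    "input": "/tmp/demo-out",
    # For each listed method, RolloutWorker builds the matching estimator via
    # <EstimatorClass>.create_from_io_context(self.io_context), as in the hunk above.
    "input_evaluation": ["is", "wis"],
}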
@@ -17,11 +17,11 @@ class D4RLReader(InputReader):

     @PublicAPI
     def __init__(self, inputs: str, ioctx: IOContext = None):
-        """Initialize a D4RLReader.
+        """Initializes a D4RLReader instance.

         Args:
-            inputs (str): String corresponding to D4RL environment name
-            ioctx (IOContext): Current IO context object.
+            inputs: String corresponding to the D4RL environment name.
+            ioctx: Current IO context object.
         """
         import d4rl
         self.env = gym.make(inputs)
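A minimal usage sketch for the reader above, assuming the optional `d4rl` package is installed and that `hopper-medium-v0` is one of its registered environment names:

from ray.rllib.offline.d4rl_reader import D4RLReader

# The string is handed straight to gym.make() inside the reader.
reader = D4RLReader("hopper-medium-v0")
batch = reader.next()  # a SampleBatch of offline D4RL transitions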
@@ -16,15 +16,16 @@ logger = logging.getLogger(__name__)

 @PublicAPI
 class InputReader(metaclass=ABCMeta):
-    """Input object for loading experiences in policy evaluation."""
+    """API for collecting and returning experiences during policy evaluation.
+    """

     @abstractmethod
     @PublicAPI
-    def next(self):
-        """Returns the next batch of experiences read.
+    def next(self) -> SampleBatchType:
+        """Returns the next batch of read experiences.

         Returns:
-            Union[SampleBatch, MultiAgentBatch]: The experience read.
+            The experience read (SampleBatch or MultiAgentBatch).
         """
         raise NotImplementedError

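A minimal sketch of a custom InputReader satisfying the abstract `next()` contract (the batch contents are made up for illustration):

from ray.rllib.offline.input_reader import InputReader
from ray.rllib.policy.sample_batch import SampleBatch


class ConstantReader(InputReader):
    """Toy reader that returns the same small batch on every call."""

    def next(self) -> SampleBatch:
        return SampleBatch({
            "obs": [[0.0], [1.0]],
            "actions": [0, 1],
            "rewards": [1.0, 0.5],
            "dones": [False, True],
        })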
@@ -40,7 +41,7 @@ class InputReader(metaclass=ABCMeta):
         reader repeatedly to feed the TensorFlow queue.

         Args:
-            queue_size (int): Max elements to allow in the TF queue.
+            queue_size: Max elements to allow in the TF queue.

         Example:
             >>> class MyModel(rllib.model.Model):
@@ -56,7 +57,7 @@ class InputReader(metaclass=ABCMeta):
         You can find a runnable version of this in examples/custom_loss.py.

         Returns:
-            dict of Tensors, one for each column of the read SampleBatch.
+            Dict of Tensors, one for each column of the read SampleBatch.
         """

         if hasattr(self, "_queue_runner"):
@@ -1,37 +1,53 @@
 import os
+from typing import Any, Optional, TYPE_CHECKING

 from ray.rllib.utils.annotations import PublicAPI
-from typing import Any
+from ray.rllib.utils.typing import TrainerConfigDict
+
+if TYPE_CHECKING:
+    from ray.rllib.evaluation.sampler import SamplerInput


 @PublicAPI
 class IOContext:
-    """Attributes to pass to input / output class constructors.
+    """Class containing attributes to pass to input/output class constructors.

-    RLlib auto-sets these attributes when constructing input / output classes.
-    Attributes:
-        log_dir (str): Default logging directory.
-        config (dict): Configuration of the agent.
-        worker_index (int): When there are multiple workers created, this
-            uniquely identifies the current worker.
-        worker (RolloutWorker): RolloutWorker object reference.
-        input_config (dict): The input configuration for custom input.
+    RLlib auto-sets these attributes when constructing input/output classes,
+    such as InputReaders and OutputWriters.
     """

     @PublicAPI
     def __init__(self,
-                 log_dir: str = None,
-                 config: dict = None,
+                 log_dir: Optional[str] = None,
+                 config: Optional[TrainerConfigDict] = None,
                  worker_index: int = 0,
-                 worker: Any = None):
+                 worker: Optional[Any] = None):
+        """Initializes a IOContext object.
+
+        Args:
+            log_dir: The logging directory to read from/write to.
+            config: The Trainer's main config dict.
+            worker_index (int): When there are multiple workers created, this
+                uniquely identifies the current worker. 0 for the local
+                worker, >0 for any of the remote workers.
+            worker (RolloutWorker): The RolloutWorker object reference.
+        """
         self.log_dir = log_dir or os.getcwd()
         self.config = config or {}
         self.worker_index = worker_index
         self.worker = worker

     @PublicAPI
-    def default_sampler_input(self) -> Any:
+    def default_sampler_input(self) -> Optional["SamplerInput"]:
+        """Returns the RolloutWorker's SamplerInput object, if any.
+
+        Returns None if the RolloutWorker has no SamplerInput. Note that local
+        workers in case there are also one or more remote workers by default
+        do not create a SamplerInput object.
+
+        Returns:
+            The RolloutWorkers' SamplerInput object or None if none exists.
+        """
         return self.worker.sampler

     @PublicAPI
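A small sketch of building an IOContext by hand and handing it to a reader (RLlib normally constructs this object itself inside each RolloutWorker; the paths are illustrative):

from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import JsonReader

ioctx = IOContext(
    log_dir="/tmp/rllib-logs",             # where output writers would write
    config={"postprocess_inputs": False},  # minimal stand-in for a Trainer config
    worker_index=0,                        # 0 = local worker, >0 = remote workers
)
reader = JsonReader("/tmp/demo-out/*.json", ioctx)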
@@ -14,7 +14,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)

         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)

         # calculate importance ratios
         p = []
@@ -5,7 +5,7 @@ import os
 from pathlib import Path
 import random
 import re
-from typing import List, Optional
+from typing import List, Optional, Union
 from urllib.parse import urlparse
 import zipfile

@@ -32,17 +32,20 @@ WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)]
 class JsonReader(InputReader):
     """Reader object that loads experiences from JSON file chunks.

-    The input files will be read from in an random order."""
+    The input files will be read from in random order.
+    """

     @PublicAPI
-    def __init__(self, inputs: List[str], ioctx: IOContext = None):
-        """Initialize a JsonReader.
+    def __init__(self,
+                 inputs: Union[str, List[str]],
+                 ioctx: Optional[IOContext] = None):
+        """Initializes a JsonReader instance.

         Args:
-            inputs (str|list): Either a glob expression for files, e.g.,
-                "/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
+            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
+                or a list of single file paths or URIs, e.g.,
                 ["s3://bucket/file.json", "s3://bucket/file2.json"].
-            ioctx (IOContext): Current IO context object.
+            ioctx: Current IO context object or None.
         """

         self.ioctx = ioctx or IOContext()
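A usage sketch of the widened `inputs` argument (assuming JSON sample files already exist at these illustrative paths):

from ray.rllib.offline.json_reader import JsonReader

# Glob expression form:
reader = JsonReader("/tmp/demo-out/*.json")

# Explicit list form (local paths and/or URIs):
reader = JsonReader([
    "/tmp/demo-out/output-0.json",
    "/tmp/demo-out/output-1.json",
])

batch = reader.next()  # one (possibly postprocessed) batch per call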
@@ -72,8 +75,8 @@ class JsonReader(InputReader):
                 self.files = []
                 for i in inputs:
                     self.files.extend(glob.glob(i))
-        elif type(inputs) is list:
-            self.files = inputs
+        elif isinstance(inputs, (list, tuple)):
+            self.files = list(inputs)
         else:
             raise ValueError(
                 "type of inputs must be list or str, not {}".format(inputs))
@@ -98,6 +101,26 @@ class JsonReader(InputReader):

         return self._postprocess_if_needed(batch)

+    def read_all_files(self) -> SampleBatchType:
+        """Reads through all files and yields one SampleBatchType per line.
+
+        When reaching the end of the last file, will start from the beginning
+        again.
+
+        Yields:
+            One SampleBatch or MultiAgentBatch per line in all input files.
+        """
+        for path in self.files:
+            file = self._try_open_file(path)
+            while True:
+                line = file.readline()
+                if not line:
+                    break
+                batch = self._try_parse(line)
+                if batch is None:
+                    break
+                yield batch
+
     def _postprocess_if_needed(self,
                                batch: SampleBatchType) -> SampleBatchType:
         if not self.ioctx.config.get("postprocess_inputs"):
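Because `read_all_files()` is a generator, taking a bounded number of line-batches from disk is a one-liner (sketch, reusing the `reader` from the previous example):

import itertools

# Inspect (up to) the first 100 line-batches stored in the input files.
for batch in itertools.islice(reader.read_all_files(), 100):
    print(batch.count, "transitions in this line's batch")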
@@ -182,18 +205,6 @@ class JsonReader(InputReader):
                     self.ioctx.worker.policy_map[pid].action_space_struct)
         return batch

-    def read_all_files(self):
-        for path in self.files:
-            file = self._try_open_file(path)
-            while True:
-                line = file.readline()
-                if not line:
-                    break
-                batch = self._try_parse(line)
-                if batch is None:
-                    break
-                yield batch
-
     def _next_line(self) -> str:
         if not self.cur_file:
             self.cur_file = self._next_file()
@@ -34,15 +34,14 @@ class JsonWriter(OutputWriter):
                  ioctx: IOContext = None,
                  max_file_size: int = 64 * 1024 * 1024,
                  compress_columns: List[str] = frozenset(["obs", "new_obs"])):
-        """Initialize a JsonWriter.
+        """Initializes a JsonWriter instance.

         Args:
-            path (str): a path/URI of the output directory to save files in.
-            ioctx (IOContext): current IO context object.
-            max_file_size (int): max size of single files before rolling over.
-            compress_columns (list): list of sample batch columns to compress.
+            path: a path/URI of the output directory to save files in.
+            ioctx: current IO context object.
+            max_file_size: max size of single files before rolling over.
+            compress_columns: list of sample batch columns to compress.
         """
-
         self.ioctx = ioctx or IOContext()
         self.max_file_size = max_file_size
         self.compress_columns = compress_columns
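A short writing sketch for the class above (illustrative output directory; the batch columns are made up):

from ray.rllib.offline.json_writer import JsonWriter
from ray.rllib.policy.sample_batch import SampleBatch

writer = JsonWriter(
    "/tmp/demo-out",                      # output directory for the JSON files
    max_file_size=8 * 1024 * 1024,        # roll over to a new file after ~8 MB
    compress_columns=["obs", "new_obs"],  # columns to compress before writing
)
writer.write(SampleBatch({
    "obs": [[0.0], [1.0]],
    "new_obs": [[1.0], [2.0]],
    "actions": [0, 1],
    "rewards": [1.0, 0.5],
    "dones": [False, True],
}))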
@@ -7,6 +7,7 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.policy import Policy
 from ray.rllib.utils.annotations import DeveloperAPI
 from ray.rllib.offline.io_context import IOContext
+from ray.rllib.utils.annotations import Deprecated
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.typing import TensorType, SampleBatchType
 from typing import List
@@ -23,19 +24,30 @@ class OffPolicyEstimator:

     @DeveloperAPI
     def __init__(self, policy: Policy, gamma: float):
-        """Creates an off-policy estimator.
+        """Initializes an OffPolicyEstimator instance.

         Args:
-            policy (Policy): Policy to evaluate.
-            gamma (float): Discount of the MDP.
+            policy: Policy to evaluate.
+            gamma: Discount factor of the environment.
         """
         self.policy = policy
         self.gamma = gamma
         self.new_estimates = []

     @classmethod
-    def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
-        """Create an off-policy estimator from a IOContext."""
+    def create_from_io_context(cls, ioctx: IOContext) -> "OffPolicyEstimator":
+        """Creates an off-policy estimator from an IOContext object.
+
+        Extracts Policy and gamma (discount factor) information from the
+        IOContext.
+
+        Args:
+            ioctx: The IOContext object to create the OffPolicyEstimator
+                from.
+
+        Returns:
+            The OffPolicyEstimator object created from the IOContext object.
+        """
         gamma = ioctx.worker.policy_config["gamma"]
         # Grab a reference to the current model
         keys = list(ioctx.worker.policy_map.keys())
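A hedged sketch of the renamed factory, assuming `worker` is an already-built RolloutWorker (the factory pulls the Policy and gamma out of the worker referenced by the IOContext):

from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator

# Via the IOContext that RLlib attached to the worker:
estimator = ImportanceSamplingEstimator.create_from_io_context(worker.io_context)

# Or constructed directly from a policy and a discount factor:
estimator = ImportanceSamplingEstimator(policy=worker.get_policy(), gamma=0.99)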
@@ -47,18 +59,36 @@ class OffPolicyEstimator:
         return cls(policy, gamma)

     @DeveloperAPI
-    def estimate(self, batch: SampleBatchType):
-        """Returns an estimate for the given batch of experiences.
+    def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
+        """Returns an off policy estimate for the given batch of experiences.

-        The batch will only contain data from one episode, but it may only be
-        a fragment of an episode.
+        The batch will at most only contain data from one episode,
+        but it may also only be a fragment of an episode.
+
+        Args:
+            batch: The batch to calculate the off policy estimate (OPE) on.
+
+        Returns:
+            The off-policy estimates (OPE) calculated on the given batch.
         """
         raise NotImplementedError

     @DeveloperAPI
-    def action_prob(self, batch: SampleBatchType) -> np.ndarray:
-        """Returns the probs for the batch actions for the current policy."""
+    def action_log_likelihood(self, batch: SampleBatchType) -> TensorType:
+        """Returns log likelihoods for actions in given batch for policy.
+
+        Computes likelihoods by passing the observations through the current
+        policy's `compute_log_likelihoods()` method.
+
+        Args:
+            batch: The SampleBatch or MultiAgentBatch to calculate action
+                log likelihoods from. This batch/batches must contain OBS
+                and ACTIONS keys.
+
+        Returns:
+            The log likelihoods of the actions in the batch, given the
+            observations and the policy.
+        """
         num_state_inputs = 0
         for k in batch.keys():
             if k.startswith("state_in_"):
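The IS/WIS estimators use this helper to form per-step importance ratios. A small sketch of that pattern, assuming `estimator` and a SampleBatch `batch` with OBS, ACTIONS and `action_prob` columns already exist (note that, per the code above, the method returns the exponentiated log-likelihoods, i.e. action probabilities):

import numpy as np

old_prob = np.asarray(batch["action_prob"])                    # behavior policy
new_prob = np.asarray(estimator.action_log_likelihood(batch))  # target policy
ratios = new_prob / old_prob                                   # per-step importance ratios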
@@ -66,7 +96,7 @@ class OffPolicyEstimator:
         state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]
         log_likelihoods: TensorType = self.policy.compute_log_likelihoods(
             actions=batch[SampleBatch.ACTIONS],
-            obs_batch=batch[SampleBatch.CUR_OBS],
+            obs_batch=batch[SampleBatch.OBS],
             state_batches=[batch[k] for k in state_keys],
             prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
             prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
@@ -76,12 +106,29 @@ class OffPolicyEstimator:
         return np.exp(log_likelihoods)

     @DeveloperAPI
-    def process(self, batch: SampleBatchType):
+    def process(self, batch: SampleBatchType) -> None:
+        """Computes off policy estimates (OPE) on batch and stores results.
+
+        Thus-far collected results can be retrieved then by calling
+        `self.get_metrics` (which flushes the internal results storage).
+
+        Args:
+            batch: The batch to process (call `self.estimate()` on) and
+                store results (OPEs) for.
+        """
         self.new_estimates.append(self.estimate(batch))

     @DeveloperAPI
-    def check_can_estimate_for(self, batch: SampleBatchType):
-        """Returns whether we can support OPE for this batch."""
+    def check_can_estimate_for(self, batch: SampleBatchType) -> None:
+        """Checks if we support off policy estimation (OPE) on given batch.
+
+        Args:
+            batch: The batch to check.
+
+        Raises:
+            ValueError: In case `action_prob` key is not in batch OR batch
+                is a MultiAgentBatch.
+        """

         if isinstance(batch, MultiAgentBatch):
             raise ValueError(
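Putting `process()` and `get_metrics()` together, an offline evaluation loop looks roughly like this (sketch; `estimator` and `reader` are assumed from the earlier examples):

for _ in range(10):
    batch = reader.next()
    estimator.process(batch)       # estimate() the batch and store the result
metrics = estimator.get_metrics()  # list of OffPolicyEstimate objects; clears storage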
|
@ -98,11 +145,19 @@ class OffPolicyEstimator:
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
def get_metrics(self) -> List[OffPolicyEstimate]:
|
def get_metrics(self) -> List[OffPolicyEstimate]:
|
||||||
"""Return a list of new episode metric estimates since the last call.
|
"""Returns list of new episode metric estimates since the last call.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of OffPolicyEstimate objects.
|
List of OffPolicyEstimate objects.
|
||||||
"""
|
"""
|
||||||
out = self.new_estimates
|
out = self.new_estimates
|
||||||
self.new_estimates = []
|
self.new_estimates = []
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@Deprecated(new="OffPolicyEstimator.create_from_io_context", error=False)
|
||||||
|
def create(self, *args, **kwargs):
|
||||||
|
return self.create_from_io_context(*args, **kwargs)
|
||||||
|
|
||||||
|
@Deprecated(new="OffPolicyEstimator.action_log_likelihood", error=False)
|
||||||
|
def action_prob(self, *args, **kwargs):
|
||||||
|
return self.action_log_likelihood(*args, **kwargs)
|
||||||
|
|
|
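The shims appended at the end of the class keep the old instance-method spelling resolving while steering callers to the new name; a migration sketch (assuming an existing `est` estimator and `batch`):

# Old spelling: still resolves, but logs a deprecation warning.
probs = est.action_prob(batch)

# New spelling:
probs = est.action_log_likelihood(batch)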
@@ -1,15 +1,14 @@
-from ray.rllib.utils.annotations import override
-from ray.rllib.utils.annotations import PublicAPI
+from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType


 @PublicAPI
 class OutputWriter:
-    """Writer object for saving experiences from policy evaluation."""
+    """Writer API for saving experiences from policy evaluation."""

     @PublicAPI
     def write(self, sample_batch: SampleBatchType):
-        """Save a batch of experiences.
+        """Saves a batch of experiences.

         Args:
             sample_batch: SampleBatch or MultiAgentBatch to save.
@@ -22,4 +21,5 @@ class NoopOutput(OutputWriter):

     @override(OutputWriter)
     def write(self, sample_batch: SampleBatchType):
+        # Do nothing.
         pass
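A minimal sketch of a custom OutputWriter in the same style as NoopOutput (the counting logic is made up for illustration):

from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import SampleBatchType


class CountingOutput(OutputWriter):
    """Toy writer that only tracks how many transitions it was asked to save."""

    def __init__(self):
        self.num_steps = 0

    @override(OutputWriter)
    def write(self, sample_batch: SampleBatchType):
        self.num_steps += sample_batch.count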
@@ -18,11 +18,11 @@ class ShuffledInput(InputReader):

     @DeveloperAPI
     def __init__(self, child: InputReader, n: int = 0):
-        """Initialize a MixedInput.
+        """Initializes a ShuffledInput instance.

         Args:
-            child (InputReader): child input reader to shuffle.
-            n (int): if positive, shuffle input over this many batches.
+            child: child input reader to shuffle.
+            n: If positive, shuffle input over this many batches.
         """
         self.n = n
         self.child = child
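A sketch of wrapping a reader in ShuffledInput to decorrelate consecutive batches (the path is illustrative):

from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.shuffled_input import ShuffledInput

# Shuffle over a window of 20 child batches; n=0 would disable shuffling.
shuffled = ShuffledInput(JsonReader("/tmp/demo-out/*.json"), n=20)
batch = shuffled.next()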
@@ -20,7 +20,7 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
         self.check_can_estimate_for(batch)

         rewards, old_prob = batch["rewards"], batch["action_prob"]
-        new_prob = self.action_prob(batch)
+        new_prob = self.action_log_likelihood(batch)

         # calculate importance ratios
         p = []
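For intuition on how the two estimators touched by this commit differ, a plain-NumPy toy of per-decision IS vs. weighted IS built from the same ratios as above (illustrative numbers, not RLlib code):

import numpy as np

rewards = np.array([1.0, 0.0, 1.0])
old_prob = np.array([0.5, 0.5, 0.5])   # behavior policy action probabilities
new_prob = np.array([0.8, 0.4, 0.9])   # target policy action probabilities
gamma = 0.99

# Cumulative importance ratios p_t = prod_{k<=t} new/old.
p = np.cumprod(new_prob / old_prob)

# Ordinary IS: scale each discounted reward by p_t.
v_is = sum((gamma ** t) * p[t] * rewards[t] for t in range(len(rewards)))

# Weighted IS: normalize each p_t by its mean over episodes (the running mean
# that WeightedImportanceSamplingEstimator tracks incrementally). With a single
# episode the mean equals p itself, so WIS reduces to the plain discounted return.
p_mean = p.copy()
v_wis = sum((gamma ** t) * (p[t] / p_mean[t]) * rewards[t]
            for t in range(len(rewards)))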