ray/rllib/offline/io_context.py
Julius Frost a88b217d3f
[rllib] Enhancements to Input API for customizing offline datasets (#16957)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
2021-07-10 15:05:25 -07:00

40 lines
1.2 KiB
Python

import os
from typing import Any, Optional

from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class IOContext:
    """Attributes to pass to input / output class constructors.

    RLlib auto-sets these attributes when constructing input / output classes.

    Attributes:
        log_dir (str): Default logging directory. Falls back to the current
            working directory when no (or an empty) value is given.
        config (dict): Configuration of the agent. Falls back to an empty
            dict when no (or an empty) value is given.
        worker_index (int): When there are multiple workers created, this
            uniquely identifies the current worker.
        worker (RolloutWorker): RolloutWorker object reference.
        input_config (dict): The input configuration for custom input
            (read-only view of ``config["input_config"]``).
    """

    @PublicAPI
    def __init__(self,
                 log_dir: Optional[str] = None,
                 config: Optional[dict] = None,
                 worker_index: int = 0,
                 worker: Any = None):
        """Initializes an IOContext.

        Args:
            log_dir: Default logging directory. If falsy, the current
                working directory is used instead.
            config: Configuration of the agent. If falsy, an empty dict
                is used instead.
            worker_index: Index identifying the current worker.
            worker: RolloutWorker object reference.
        """
        self.log_dir = log_dir or os.getcwd()
        self.config = config or {}
        self.worker_index = worker_index
        self.worker = worker

    @PublicAPI
    def default_sampler_input(self) -> Any:
        """Returns the ``sampler`` attribute of the referenced worker.

        Raises:
            AttributeError: If ``worker`` was left at its default of None
                (or otherwise has no ``sampler`` attribute).
        """
        return self.worker.sampler

    # `property` must be the outermost decorator: the annotation decorator
    # should receive the plain function, not the resulting property object
    # (property instances do not accept attribute assignment).
    @property
    @PublicAPI
    def input_config(self):
        """The ``"input_config"`` entry of ``config`` ({} if it is absent)."""
        return self.config.get("input_config", {})