[docs] Fix the remaining style violations in docstrings and add lint rule (#27033)

Eric Liang 2022-07-27 22:24:20 -07:00 committed by GitHub
parent 0dbb18a87d
commit a4434fac7f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 125 additions and 121 deletions

View file

@ -38,13 +38,14 @@ check_python_command_exist() {
check_docstyle() {
echo "Checking docstyle..."
violations=$(git ls-files | grep '.py$' | xargs grep -E '^[ a-z_]+ \([a-zA-Z]+\): ' || true)
violations=$(git ls-files | grep '.py$' | xargs grep -E '^[ ]+[a-z_]+ ?\([a-zA-Z]+\): ' | grep -v 'str(' | grep -v noqa || true)
if [[ -n "$violations" ]]; then
echo
echo "=== Found Ray docstyle violations ==="
echo "$violations"
echo
echo "Per the Google pydoc style, omit types from pydoc args as they are redundant: https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#code-style "
echo "If this is a false positive, you can add a '# noqa' comment to the line to ignore."
exit 1
fi
return 0
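
For reference, the tightened pattern above only matches indented `name(Type):` entries inside docstring Args sections, and the extra `grep -v` filters skip lines containing `str(` and lines already marked `# noqa`. A small self-contained sketch of what it flags (hypothetical docstring lines, not taken from this commit):

import re

# Python translation of the grep pattern added above.
DOCSTYLE_VIOLATION = re.compile(r"^[ ]+[a-z_]+ ?\([a-zA-Z]+\): ")

lines = [
    "        host(str): Host address of dashboard aiohttp server.",  # old style: flagged
    "        host: Host address of dashboard aiohttp server.",       # new style: clean
    "        port (int): Port number of the server.  # noqa",        # suppressed by the noqa filter
]
for line in lines:
    flagged = (
        DOCSTYLE_VIOLATION.match(line)
        and "str(" not in line
        and "noqa" not in line
    )
    print("VIOLATION" if flagged else "ok", "|", line.strip())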

View file

@ -31,11 +31,11 @@ class Dashboard:
which polls said API for display purposes.
Args:
host(str): Host address of dashboard aiohttp server.
port(int): Port number of dashboard aiohttp server.
port_retries(int): The retry times to select a valid port.
gcs_address(str): GCS address of the cluster
log_dir(str): Log directory of dashboard.
host: Host address of dashboard aiohttp server.
port: Port number of dashboard aiohttp server.
port_retries: The retry times to select a valid port.
gcs_address: GCS address of the cluster
log_dir: Log directory of dashboard.
"""
def __init__(

View file

@ -309,14 +309,14 @@ class RuntimeEnvAgent(
Create runtime env with retry times. This function won't raise exceptions.
Args:
runtime_env(RuntimeEnv): The instance of RuntimeEnv class.
serialized_runtime_env(str): The serialized runtime env.
serialized_allocated_resource_instances(str): The serialized allocated
runtime_env: The instance of RuntimeEnv class.
serialized_runtime_env: The serialized runtime env.
serialized_allocated_resource_instances: The serialized allocated
resource instances.
setup_timeout_seconds(int): The timeout of runtime environment creation.
setup_timeout_seconds: The timeout of runtime environment creation.
Returns:
a tuple which contains result(bool), runtime env context(str), error
a tuple which contains result (bool), runtime env context (str), error
message(str).
"""

View file

@ -33,10 +33,10 @@ def create_url_with_offset(*, url: str, offset: int, size: int) -> str:
Example) file://path/to/file?offset=""&size=""
Args:
url(str): url to the object stored in the external storage.
offset(int): Offset from the beginning of the file to
url: url to the object stored in the external storage.
offset: Offset from the beginning of the file to
the first bytes of this object.
size(int): Size of the object that is stored in the url.
size: Size of the object that is stored in the url.
It is used to calculate the last offset.
Returns:
@ -53,7 +53,7 @@ def parse_url_with_offset(url_with_offset: str) -> Tuple[str, int, int]:
is stored in the external storage.
Args:
url_with_offset(str): url created by create_url_with_offset.
url_with_offset: url created by create_url_with_offset.
Returns:
named tuple of base_url, offset, and size.
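
The two helpers above describe a simple query-string scheme (the `file://path/to/file?offset=""&size=""` example). A minimal sketch of that round trip, assuming the format shown in the docstrings rather than the actual Ray implementation:

from collections import namedtuple
from urllib.parse import parse_qs, urlencode, urlparse

ParsedUrl = namedtuple("ParsedUrl", ["base_url", "offset", "size"])

def create_url_with_offset(*, url: str, offset: int, size: int) -> str:
    # e.g. file:///tmp/spill/chunk-0?offset=1024&size=512
    return f"{url}?{urlencode({'offset': offset, 'size': size})}"

def parse_url_with_offset(url_with_offset: str) -> ParsedUrl:
    parsed = urlparse(url_with_offset)
    query = parse_qs(parsed.query)
    base_url = parsed._replace(query="").geturl()
    return ParsedUrl(base_url, int(query["offset"][0]), int(query["size"][0]))

url = create_url_with_offset(url="file:///tmp/spill/chunk-0", offset=1024, size=512)
print(parse_url_with_offset(url))  # ParsedUrl(base_url='file:///tmp/spill/chunk-0', offset=1024, size=512)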
@ -110,10 +110,10 @@ class ExternalStorage(metaclass=abc.ABCMeta):
"""Fuse all given objects into a given file handle.
Args:
f(IO): File handle to fusion all given object refs.
object_refs(list): Object references to fusion to a single file.
owner_addresses(list): Owner addresses for the provided objects.
url(str): url where the object ref is stored
f: File handle to fuse all given object refs into.
object_refs: Object references to fuse into a single file.
owner_addresses: Owner addresses for the provided objects.
url: url where the object ref is stored
in the external storage.
Return:
@ -161,9 +161,9 @@ class ExternalStorage(metaclass=abc.ABCMeta):
"""Check whether or not the obtained_data_size is as expected.
Args:
metadata_len(int): Actual metadata length of the object.
buffer_len(int): Actual buffer length of the object.
obtained_data_size(int): Data size specified in the
metadata_len: Actual metadata length of the object.
buffer_len: Actual buffer length of the object.
obtained_data_size: Data size specified in the
url_with_offset.
Raises:
@ -188,7 +188,7 @@ class ExternalStorage(metaclass=abc.ABCMeta):
Args:
object_refs: The list of the refs of the objects to be spilled.
owner_addresses(list): Owner addresses for the provided objects.
owner_addresses: Owner addresses for the provided objects.
Returns:
A list of internal URLs with object offset.
"""
@ -442,9 +442,9 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
the directory.
Args:
uri(str): Storage URI used for smart open.
prefix(str): Prefix of objects that are stored.
override_transport_params(dict): Overriding the default value of
uri: Storage URI used for smart open.
prefix: Prefix of objects that are stored.
override_transport_params: Overriding the default value of
transport_params for smart-open library.
Raises:

View file

@ -201,8 +201,8 @@ class PrometheusServiceDiscoveryWriter(threading.Thread):
https://prometheus.io/docs/guides/file-sd/ for more details.
Args:
gcs_address(str): Gcs address for this cluster.
temp_dir(str): Temporary directory used by
gcs_address: Gcs address for this cluster.
temp_dir: Temporary directory used by
Ray to store logs and metadata.
"""

View file

@ -103,10 +103,10 @@ class RayParams:
monitor the log files for all processes on this node and push their
contents to Redis.
autoscaling_config: path to autoscaling config file.
metrics_agent_port(int): The port to bind metrics agent.
metrics_export_port(int): The port at which metrics are exposed
metrics_agent_port: The port to bind metrics agent.
metrics_export_port: The port at which metrics are exposed
through a Prometheus endpoint.
no_monitor(bool): If True, the ray autoscaler monitor for this cluster
no_monitor: If True, the ray autoscaler monitor for this cluster
will not be started.
_system_config: Configuration for overriding RayConfig
defaults. Used to set system configuration and for experimental Ray

View file

@ -43,18 +43,18 @@ def setup_component_logger(
The only exception is workers. They use the different logging config.
Args:
logging_level(str | int): Logging level in string or logging enum.
logging_format(str): Logging format string.
log_dir(str): Log directory path. If empty, logs will go to
logging_level: Logging level in string or logging enum.
logging_format: Logging format string.
log_dir: Log directory path. If empty, logs will go to
stderr.
filename(str): Name of the file to write logs. If empty, logs will go
filename: Name of the file to write logs. If empty, logs will go
to stderr.
max_bytes(int): Same argument as RotatingFileHandler's maxBytes.
backup_count(int): Same argument as RotatingFileHandler's backupCount.
logger_name(str, optional): used to create or get the correspoding
max_bytes: Same argument as RotatingFileHandler's maxBytes.
backup_count: Same argument as RotatingFileHandler's backupCount.
logger_name: used to create or get the corresponding
logger in getLogger call. It will get the root logger by default.
Returns:
logger (logging.Logger): the created or modified logger.
the created or modified logger.
"""
logger = logging.getLogger(logger_name)
if type(logging_level) is str:
@ -97,7 +97,7 @@ class StandardStreamInterceptor:
Args:
logger: Python logger that will receive messages streamed to
the standard out/err and delegate writes.
intercept_stdout(bool): True if the class intercepts stdout. False
intercept_stdout: True if the class intercepts stdout. False
if stderr is intercepted.
"""
@ -247,9 +247,9 @@ def setup_and_get_worker_interceptor_logger(
Args:
args: args received from default_worker.py.
max_bytes(int): maxBytes argument of RotatingFileHandler.
backup_count(int): backupCount argument of RotatingFileHandler.
is_for_stdout(bool): True if logger will be used to intercept stdout.
max_bytes: maxBytes argument of RotatingFileHandler.
backup_count: backupCount argument of RotatingFileHandler.
is_for_stdout: True if logger will be used to intercept stdout.
False otherwise.
"""
file_extension = "out" if is_for_stdout else "err"

View file

@ -33,6 +33,7 @@ class RuntimeEnvPlugin(ABC):
Args:
runtime_env_dict: the user-supplied runtime environment dict.
Raises:
ValueError: if the validation fails.
"""

View file

@ -117,7 +117,7 @@ def propagate_jemalloc_env_var(
Params:
jemalloc_path: The path to the jemalloc shared library.
jemalloc_conf: `,` separated string of jemalloc config.
jemalloc_comps List(str): The list of Ray components
jemalloc_comps: The list of Ray components
that we will profile.
process_type: The process type that needs jemalloc
env var for memory profiling. If it doesn't match one of
@ -1390,7 +1390,7 @@ def start_gcs_server(
config: Optional configuration that will
override defaults in RayConfig.
gcs_server_port: Port number of the gcs server.
metrics_agent_port(int): The port where metrics agent is bound to.
metrics_agent_port: The port where metrics agent is bound to.
node_ip_address: IP Address of a node where gcs server starts.
Returns:
@ -1470,7 +1470,7 @@ def start_raylet(
redis_address: The address of the primary Redis server.
gcs_address: The address of GCS server.
node_ip_address: The IP address of this node.
node_manager_port(int): The port to use for the node manager. If it's
node_manager_port: The port to use for the node manager. If it's
0, a random port will be used.
raylet_name: The name of the raylet socket to create.
plasma_store_name: The name of the plasma store socket to connect
@ -1482,7 +1482,7 @@ def start_raylet(
storage: The persistent storage URI.
temp_dir: The path of the temporary directory Ray will use.
session_dir: The path of this session.
resource_dir(str): The path of resource of this session .
resource_dir: The path of the resource directory of this session.
log_dir: The path of the dir where log files are created.
resource_spec: Resources for this raylet.
object_manager_port: The port to use for the object manager. If this is
@ -1973,7 +1973,7 @@ def start_monitor(
Args:
redis_address: The address that the Redis server is listening on.
gcs_address: The address of GCS server.
logs_dir(str): The path to the log directory.
logs_dir: The path to the log directory.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If

View file

@ -477,7 +477,7 @@ def wait_until_succeeded_without_exception(
Args:
func: A function to run.
exceptions(tuple): Exceptions that are supposed to occur.
exceptions: Exceptions that are supposed to occur.
args: arguments to pass for a given func
timeout_ms: Maximum timeout in milliseconds.
retry_interval_ms: Retry interval in milliseconds.

View file

@ -852,7 +852,7 @@ def generate_write_data(
Params:
usage_stats: The usage stats that were reported.
error(str): The error message of failed reports.
error: The error message of failed reports.
Returns:
UsageStatsToWrite

View file

@ -1138,7 +1138,7 @@ def init(
_temp_dir: If provided, specifies the root temporary
directory for the Ray process. Defaults to an OS-specific
conventional location, e.g., "/tmp/ray".
_metrics_export_port(int): Port number Ray exposes system metrics
_metrics_export_port: Port number Ray exposes system metrics
through a Prometheus endpoint. It is currently under active
development, and the API is subject to change.
_system_config: Configuration for overriding

View file

@ -199,17 +199,17 @@ class NodeProvider:
"""Returns the CommandRunner class used to perform SSH commands.
Args:
log_prefix(str): stores "NodeUpdater: {}: ".format(<node_id>). Used
log_prefix: stores "NodeUpdater: {}: ".format(<node_id>). Used
to print progress in the CommandRunner.
node_id(str): the node ID.
auth_config(dict): the authentication configs from the autoscaler
node_id: the node ID.
auth_config: the authentication configs from the autoscaler
yaml file.
cluster_name(str): the name of the cluster.
process_runner(module): the module to use to run the commands
cluster_name: the name of the cluster.
process_runner: the module to use to run the commands
in the CommandRunner. E.g., subprocess.
use_internal_ip(bool): whether the node_id belongs to an internal ip
use_internal_ip: whether the node_id belongs to an internal ip
or external ip.
docker_config(dict): If set, the docker information of the docker
docker_config: If set, the docker information of the docker
container that commands should be run on.
"""
common_args = {

View file

@ -34,7 +34,7 @@ class RuntimeEnvConfig(dict):
timeout logic, except `-1`, `setup_timeout_seconds` cannot be
less than or equal to 0. The default value of `setup_timeout_seconds`
is 600 seconds.
eager_install(bool): Indicates whether to install the runtime environment
eager_install: Indicates whether to install the runtime environment
on the cluster at `ray.init()` time, before the workers are leased.
This flag is set to `True` by default.
"""

View file

@ -33,7 +33,7 @@ def get_deployment(name: str):
"""Dynamically fetch a handle to a Deployment object.
Args:
name(str): name of the deployment. This must have already been
name: name of the deployment. This must have already been
deployed.
Returns:

View file

@ -30,7 +30,7 @@ def start_metrics_pusher(
is garbage collected or when the Serve application shuts down.
Args:
interval_s(float): the push interval.
interval_s: the push interval.
collection_callback: a callable that returns the metric data points to
be sent to the controller. The collection callback should take
no argument and returns a dictionary of str_key -> float_value.
@ -99,10 +99,10 @@ class InMemoryMetricsStore:
"""Push new data points to the store.
Args:
data_points(dict): dictionary containing the metrics values. The
data_points: dictionary containing the metrics values. The
key should be a string that uniquely identifies this time series
and to be used to perform aggregation.
timestamp(float): the unix epoch timestamp the metrics are
timestamp: the unix epoch timestamp the metrics are
collected at.
"""
for name, value in data_points.items():
@ -128,11 +128,11 @@ class InMemoryMetricsStore:
"""Perform a window average operation for metric `key`
Args:
key(str): the metric name.
window_start_timestamp_s(float): the unix epoch timestamp for the
key: the metric name.
window_start_timestamp_s: the unix epoch timestamp for the
start of the window. The computed average will use all datapoints
from this timestamp until now.
do_compact(bool): whether or not to delete the datapoints that's
do_compact: whether or not to delete the datapoints that's
before `window_start_timestamp_s` to save memory. Default is
true.
Returns:
@ -152,11 +152,11 @@ class InMemoryMetricsStore:
"""Perform a max operation for metric `key`.
Args:
key(str): the metric name.
window_start_timestamp_s(float): the unix epoch timestamp for the
key: the metric name.
window_start_timestamp_s: the unix epoch timestamp for the
start of the window. The computed average will use all datapoints
from this timestamp until now.
do_compact(bool): whether or not to delete the datapoints that's
do_compact: whether or not to delete the datapoints that's
before `window_start_timestamp_s` to save memory. Default is
true.
Returns:
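
Taken together, the InMemoryMetricsStore hunks above describe per-key time series plus windowed aggregation. A simplified model of that behavior (a sketch, not the Serve implementation):

import time
from collections import defaultdict

class TinyMetricsStore:
    """Per-key lists of (timestamp, value), with a windowed average."""

    def __init__(self):
        self.data = defaultdict(list)

    def add_metrics_point(self, data_points: dict, timestamp: float) -> None:
        for key, value in data_points.items():
            self.data[key].append((timestamp, value))

    def window_average(self, key: str, window_start_timestamp_s: float, do_compact: bool = True):
        points = [(t, v) for t, v in self.data[key] if t >= window_start_timestamp_s]
        if do_compact:
            self.data[key] = points  # drop data points older than the window
        return sum(v for _, v in points) / len(points) if points else None

store = TinyMetricsStore()
now = time.time()
store.add_metrics_point({"queue_len": 3.0}, timestamp=now - 5)
store.add_metrics_point({"queue_len": 7.0}, timestamp=now)
print(store.window_average("queue_len", window_start_timestamp_s=now - 10))  # 5.0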

View file

@ -389,7 +389,7 @@ class ActorReplicaWrapper:
- replica __init__() failed.
SUCCEEDED:
- replica __init__() and reconfigure() succeeded.
version (DeploymentVersion):
version:
None:
- replica reconfigure() haven't returned OR
- replica __init__() failed.
@ -851,9 +851,9 @@ class ReplicaStateContainer:
"""Get the total count of replicas of the given states.
Args:
exclude_version(DeploymentVersion): version to exclude. If not
exclude_version: version to exclude. If not
specified, all versions are considered.
version(DeploymentVersion): version to filter to. If not specified,
version: version to filter to. If not specified,
all versions are considered.
states: states to consider. If not specified, all replicas
are considered.

View file

@ -262,7 +262,7 @@ def set_socket_reuse_port(sock: socket.socket) -> bool:
"""Mutate a socket object to allow multiple process listening on the same port.
Returns:
success(bool): whether the setting was successful.
success: whether the setting was successful.
"""
try:
# These two socket options will allow multiple processes to bind to the
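
The helper described above boils down to setting SO_REUSEADDR/SO_REUSEPORT and reporting whether that worked. A hedged sketch of the idea (the real helper has additional platform fallbacks):

import socket

def set_socket_reuse_port_sketch(sock: socket.socket) -> bool:
    """Best-effort SO_REUSEPORT so several processes can listen on one port."""
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            return True
        return False
    except OSError:
        return False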

View file

@ -64,10 +64,10 @@ class LongPollClient:
"""The asynchronous long polling client.
Args:
host_actor(ray.ActorHandle): handle to actor embedding LongPollHost.
key_listeners(Dict[str, AsyncCallable]): a dictionary mapping keys to
host_actor: handle to actor embedding LongPollHost.
key_listeners: a dictionary mapping keys to
callbacks to be called on state update for the corresponding keys.
call_in_event_loop(AbstractEventLoop): an asyncio event loop
call_in_event_loop: an asyncio event loop
to post the callback into.
"""

View file

@ -214,7 +214,7 @@ class Router:
"""Router process incoming queries: assign a replica.
Args:
controller_handle(ActorHandle): The controller handle.
controller_handle: The controller handle.
"""
self._event_loop = event_loop
self._replica_set = ReplicaSet(deployment_name, event_loop)

View file

@ -73,14 +73,14 @@ def start(
for HTTP proxy. You can pass in a dictionary or HTTPOptions object
with fields:
- host(str, None): Host for HTTP servers to listen on. Defaults to
- host: Host for HTTP servers to listen on. Defaults to
"127.0.0.1". To expose Serve publicly, you probably want to set
this to "0.0.0.0".
- port(int): Port for HTTP server. Defaults to 8000.
- root_path(str): Root path to mount the serve application
- port: Port for HTTP server. Defaults to 8000.
- root_path: Root path to mount the serve application
(for example, "/serve"). All deployment routes will be prefixed
with this path. Defaults to "".
- middlewares(list): A list of Starlette middlewares that will be
- middlewares: A list of Starlette middlewares that will be
applied to the HTTP servers in the cluster. Defaults to [].
- location(str, serve.config.DeploymentMode): The deployment
location of HTTP servers:
@ -90,7 +90,7 @@ def start(
on. This is the default.
- "EveryNode": start one HTTP server per node.
- "NoServer" or None: disable HTTP server.
- num_cpus (int): The number of CPU cores to reserve for each
- num_cpus: The number of CPU cores to reserve for each
internal Serve HTTP proxy actor. Defaults to 0.
dedicated_cpu: Whether to reserve a CPU core for the internal
Serve controller actor. Defaults to False.
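
A hedged example of the options documented above, using the field names from this docstring and the HTTPOptions class as it existed around this release (exact import path and defaults may differ across Serve versions):

from ray import serve
from ray.serve.config import DeploymentMode, HTTPOptions  # assumed import path

serve.start(
    http_options=HTTPOptions(
        host="0.0.0.0",                    # expose publicly instead of the 127.0.0.1 default
        port=8000,
        root_path="/serve",                # prefix applied to every deployment route
        location=DeploymentMode.HeadOnly,  # or EveryNode / NoServer
    ),
    dedicated_cpu=False,
)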
@ -396,7 +396,7 @@ def get_deployment(name: str) -> Deployment:
>>> MyDeployment.options(num_replicas=10).deploy() # doctest: +SKIP
Args:
name(str): name of the deployment. This must have already been
name: name of the deployment. This must have already been
deployed.
Returns:

View file

@ -440,7 +440,7 @@ class ServeController:
"""Get the current information about a deployment.
Args:
name(str): the name of the deployment.
name: the name of the deployment.
Returns:
DeploymentRoute's protobuf serialized bytes
@ -467,7 +467,7 @@ class ServeController:
"""Gets the current information about all deployments.
Args:
include_deleted(bool): Whether to include information about
include_deleted: Whether to include information about
deployments that have been deleted.
Returns:
@ -492,7 +492,7 @@ class ServeController:
"""Gets the current information about all deployments.
Args:
include_deleted(bool): Whether to include information about
include_deleted: Whether to include information about
deployments that have been deleted.
Returns:

View file

@ -167,7 +167,7 @@ class RayServeHandle:
"""Set options for this handle.
Args:
method_name(str): The method to invoke.
method_name: The method to invoke.
"""
new_options_dict = self.handle_options.__dict__.copy()
user_modified_options_dict = {

View file

@ -21,7 +21,9 @@ def test_gpu_ids(shutdown_only):
def get_gpu_ids(num_gpus_per_worker):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == num_gpus_per_worker
assert os.environ["CUDA_VISIBLE_DEVICES"] == ",".join([str(i) for i in gpu_ids])
assert os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids] # noqa
)
for gpu_id in gpu_ids:
assert gpu_id in range(num_gpus)
return gpu_ids
@ -59,7 +61,7 @@ def test_gpu_ids(shutdown_only):
gpu_ids = ray.get_gpu_ids()
assert len(gpu_ids) == 0
assert os.environ["CUDA_VISIBLE_DEVICES"] == ",".join(
[str(i) for i in gpu_ids]
[str(i) for i in gpu_ids] # noqa
)
# Set self.x to make sure that we got here.
self.x = 1
@ -431,7 +433,7 @@ def test_many_custom_resources(shutdown_only):
else:
num_custom_resources = 10000
total_resources = {
str(i): np.random.randint(1, 7) for i in range(num_custom_resources)
str(i): np.random.randint(1, 7) for i in range(num_custom_resources) # noqa
}
ray.init(num_cpus=5, resources=total_resources)

View file

@ -124,9 +124,9 @@ class PlacementGroupFactory:
tuner.fit()
Args:
bundles(List[Dict]): A list of bundles which
bundles: A list of bundles which
represent the resources requirements.
strategy(str): The strategy to create the placement group.
strategy: The strategy to create the placement group.
- "PACK": Packs Bundles into as few nodes as possible.
- "SPREAD": Places Bundles across distinct nodes as even as possible.

View file

@ -349,7 +349,7 @@ class HyperbandSuite(unittest.TestCase):
"""Default statistics for HyperBand."""
sched = HyperBandScheduler()
res = {
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)} # noqa
for s in range(sched._s_max_1)
}
res["max_trials"] = sum(v["n"] for v in res.values())

View file

@ -781,7 +781,7 @@ def _get_comm_key_from_devices(devices):
then the key would be "0,1,2,3".
Args:
devices(list): a list of GPU device indices
devices: a list of GPU device indices
Returns:
str: a string represents the key to query the communicator cache.
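
As the docstring above says, the communicator cache key is just the device indices joined with commas, for example:

devices = [0, 1, 2, 3]
comm_key = ",".join(str(d) for d in devices)
assert comm_key == "0,1,2,3"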

View file

@ -275,7 +275,7 @@ def get_tensor_device_list(tensors):
"""Returns the gpu devices of the list of input tensors.
Args:
tensors(list): a list of tensors, each locates on a GPU.
tensors: a list of tensors, each locates on a GPU.
Returns:
list: the list of GPU devices.

View file

@ -58,7 +58,7 @@ class Metric:
>>> counter = Counter("name").set_default_tags({"a": "b"})
Args:
default_tags(dict): Default tags that are
default_tags: Default tags that are
used for every record method.
Returns:
@ -81,7 +81,7 @@ class Metric:
Tags passed in will take precedence over the metric's default tags.
Args:
value(float): The value to be recorded as a metric point.
value: The value to be recorded as a metric point.
"""
assert self._metric is not None
if isinstance(self._metric, CythonCount) and not _internal:
@ -159,9 +159,9 @@ class Counter(Metric):
https://prometheus.io/docs/concepts/metric_types/#counter
Args:
name(str): Name of the metric.
description(str): Description of the metric.
tag_keys(tuple): Tag keys of the metric.
name: Name of the metric.
description: Description of the metric.
tag_keys: Tag keys of the metric.
"""
def __init__(
@ -201,9 +201,9 @@ class Count(Counter):
This class is DEPRECATED, please use ray.util.metrics.Counter instead.
Args:
name(str): Name of the metric.
description(str): Description of the metric.
tag_keys(tuple): Tag keys of the metric.
name: Name of the metric.
description: Description of the metric.
tag_keys: Tag keys of the metric.
"""
def __init__(
@ -227,10 +227,10 @@ class Histogram(Metric):
https://prometheus.io/docs/concepts/metric_types/#histogram
Args:
name(str): Name of the metric.
description(str): Description of the metric.
boundaries(list): Boundaries of histogram buckets.
tag_keys(tuple): Tag keys of the metric.
name: Name of the metric.
description: Description of the metric.
boundaries: Boundaries of histogram buckets.
tag_keys: Tag keys of the metric.
"""
def __init__(
@ -301,9 +301,9 @@ class Gauge(Metric):
https://prometheus.io/docs/concepts/metric_types/#gauge
Args:
name(str): Name of the metric.
description(str): Description of the metric.
tag_keys(tuple): Tag keys of the metric.
name: Name of the metric.
description: Description of the metric.
tag_keys: Tag keys of the metric.
"""
def __init__(
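
A short usage sketch of the three metric types documented above, assuming the public ray.util.metrics API (names and arguments as in these docstrings):

from ray.util.metrics import Counter, Gauge, Histogram

requests = Counter(
    "num_requests",
    description="Total requests served.",
    tag_keys=("route",),
).set_default_tags({"route": "/"})
requests.inc()  # records a +1 data point with the default tags

latency_ms = Histogram(
    "request_latency_ms",
    description="End-to-end request latency.",
    boundaries=[10, 50, 100, 500],  # histogram bucket edges
    tag_keys=("route",),
)
latency_ms.observe(42.0, tags={"route": "/"})

queue_size = Gauge("queue_size", description="Current queue length.")
queue_size.set(3)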

View file

@ -83,14 +83,14 @@ class SlimConv2d(nn.Module):
"""Creates a standard Conv2d layer, similar to torch.nn.Conv2d
Args:
in_channels(int): Number of input channels
in_channels: Number of input channels
out_channels: Number of output channels
kernel (Union[int, Tuple[int, int]]): If int, the kernel is
kernel: If int, the kernel is
a tuple(x,x). Elsewise, the tuple can be specified
stride (Union[int, Tuple[int, int]]): Controls the stride
stride: Controls the stride
for the cross-correlation. If int, the stride is a
tuple(x,x). Elsewise, the tuple can be specified
padding (Union[int, Tuple[int, int]]): Controls the amount
padding: Controls the amount
of implicit zero-paddings during the conv operation
initializer: Initializer function for kernel weights
activation_fn: Activation function at the end of layer
@ -140,7 +140,7 @@ class SlimFC(nn.Module):
"""Creates a standard FC layer, similar to torch.nn.Linear
Args:
in_size(int): Input size for FC Layer
in_size: Input size for FC Layer
out_size: Output size for FC Layer
initializer: Initializer function for FC layer weights
activation_fn: Activation function at the end of layer
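
For intuition, the wrappers described above roughly correspond to padding plus a convolution (or a linear layer) followed by an activation. A hedged PyTorch sketch with assumed sizes (initializers and other details differ in RLlib):

import torch.nn as nn

# Rough equivalent of SlimConv2d(in_channels=3, out_channels=16, kernel=3, stride=1, padding=1):
conv = nn.Sequential(
    nn.ZeroPad2d(1),
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1),
    nn.ReLU(),
)

# Rough equivalent of SlimFC(in_size=256, out_size=64):
fc = nn.Sequential(nn.Linear(256, 64), nn.ReLU())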

View file

@ -13,7 +13,7 @@ def flatten_space(space: gym.Space) -> List[gym.Space]:
Primitive components are any non Tuple/Dict spaces.
Args:
space (gym.Space): The gym.Space to flatten. This may be any
space: The gym.Space to flatten. This may be any
supported type (including nested Tuples and Dicts).
Returns:
@ -43,7 +43,7 @@ def get_base_struct_from_space(space):
"""Returns a Tuple/Dict Space as native (equally structured) py tuple/dict.
Args:
space (gym.Space): The Space to get the python struct for.
space: The Space to get the python struct for.
Returns:
Union[dict,tuple,gym.Space]: The struct equivalent to the given Space.
@ -83,13 +83,13 @@ def get_dummy_batch_for_space(
as an additional batch dimension has to be added as dim=0.
Args:
space (gym.Space): The space to get a dummy batch for.
batch_size(int): The required batch size (B). Note that this can also
space: The space to get a dummy batch for.
batch_size: The required batch size (B). Note that this can also
be 0 (only if `time_size` is None!), which will result in a
non-batched sample for the given space (no batch dim).
fill_value (Union[float, int, str]): The value to fill the batch with
fill_value: The value to fill the batch with
or "random" for random values.
time_size (Optional[int]): If not None, add an optional time axis
time_size: If not None, add an optional time axis
of `time_size` size to the returned batch.
time_major: If True AND `time_size` is not None, return batch
as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size`