[CI] Replace YAPF disables with Black disables (#21982)

Author: Balaji Veeramani, 2022-02-08 16:29:25 -08:00 (committed by GitHub)
parent dcd96ca348
commit 31ed9e5d02
55 changed files with 114 additions and 114 deletions
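For reference, Black's formatting suppressors map one-to-one onto the YAPF directives removed here: "# yapf: disable" becomes "# fmt: off" and "# yapf: enable" becomes "# fmt: on". The sketch below shows the general pattern this commit applies around hand-formatted blocks and Sphinx doc anchors; the anchor names and config values are illustrative, not taken from any one changed file.

# fmt: off
# __example_config_begin__
# Hand-aligned dict kept exactly as written; Black leaves everything
# between "fmt: off" and "fmt: on" untouched, just as YAPF skipped the
# region between "yapf: disable" and "yapf: enable".
EXAMPLE_CONFIG = {
    "lr":     1e-3,   # learning rate
    "gamma":  0.99,   # discount factor
}
# __example_config_end__
# fmt: on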

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __data_setup_begin__

View file

@@ -8,7 +8,7 @@ in the documentation.
"""
import ray
# yapf: disable
# fmt: off
# __runtime_env_conda_def_start__
runtime_env = {

View file

@@ -7,7 +7,7 @@ but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# fmt: off
# __tf_model_start__
@@ -28,9 +28,9 @@ def create_keras_model():
metrics=[keras.metrics.categorical_accuracy])
return model
# __tf_model_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __ray_start__
import ray
import numpy as np
@@ -65,17 +65,17 @@ class Network(object):
# Note that for simplicity this does not handle the optimizer state.
self.model.set_weights(weights)
# __ray_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __actor_start__
NetworkActor = Network.remote()
result_object_ref = NetworkActor.train.remote()
ray.get(result_object_ref)
# __actor_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __weight_average_start__
NetworkActor2 = Network.remote()
NetworkActor2.train.remote()

View file

@@ -6,7 +6,7 @@ It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# fmt: off
# __torch_model_start__
import argparse
@@ -35,9 +35,9 @@ class Model(nn.Module):
# __torch_model_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __torch_helper_start__
from filelock import FileLock
from torchvision import datasets, transforms
@@ -112,9 +112,9 @@ def dataset_creator(use_cuda, data_dir):
# __torch_helper_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __torch_net_start__
import torch.optim as optim
@@ -155,9 +155,9 @@ args = parser.parse_args()
net = Network(data_dir=args.data_dir)
net.train()
# __torch_net_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __torch_ray_start__
import ray
@@ -167,18 +167,18 @@ RemoteNetwork = ray.remote(Network)
# Use the below instead of `ray.remote(network)` to leverage the GPU.
# RemoteNetwork = ray.remote(num_gpus=1)(Network)
# __torch_ray_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __torch_actor_start__
NetworkActor = RemoteNetwork.remote()
NetworkActor2 = RemoteNetwork.remote()
ray.get([NetworkActor.train.remote(), NetworkActor2.train.remote()])
# __torch_actor_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __weight_average_start__
weights = ray.get(
[NetworkActor.get_weights.remote(),

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __serve_example_begin__
import requests

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
# __doc_import_begin__
from typing import List
import time
@@ -10,7 +10,7 @@ from starlette.requests import Request
import ray
from ray import serve
# __doc_import_end__
# yapf: enable
# fmt: on
# __doc_define_servable_begin__

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
import ray
# __doc_import_begin__
from ray import serve
@@ -11,7 +11,7 @@ import torch
from torchvision import transforms
from torchvision.models import resnet18
# __doc_import_end__
# yapf: enable
# fmt: on
# __doc_define_servable_begin__

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
import ray
# __doc_import_begin__
from ray import serve
@@ -14,7 +14,7 @@ from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import mean_squared_error
# __doc_import_end__
# yapf: enable
# fmt: on
# __doc_train_model_begin__
# Load data

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
import ray
# __doc_import_begin__
from ray import serve
@@ -8,7 +8,7 @@ import tempfile
import numpy as np
import requests
# __doc_import_end__
# yapf: enable
# fmt: on
# __doc_train_model_begin__
TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), "mnist_model.h5")

View file

@@ -289,7 +289,7 @@ def test_add_min_workers_nodes():
# Formatting is disabled to prevent Black from erroring while formatting
# this file. See https://github.com/ray-project/ray/issues/21313 for more
# information.
# yapf: disable
# fmt: off
assert _add_min_workers_nodes([],
{},
types, None, None, None) == \
@@ -336,7 +336,7 @@ def test_add_min_workers_nodes():
}, {
"gpubla": 10
})
# yapf: enable
# fmt: on
def test_get_nodes_to_launch_with_min_workers():

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __tf_setup_begin__

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __torch_setup_begin__
import torch

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __import_begin__
from functools import partial

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# __import_lightning_begin__
import math

View file

@@ -28,7 +28,7 @@ parser.add_argument(
# Below comments are for documentation purposes only.
# yapf: disable
# fmt: off
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
def setup(self, config):
@@ -57,7 +57,7 @@ class TrainMNIST(tune.Trainable):
# __trainable_example_end__
# yapf: enable
# fmt: on
if __name__ == "__main__":
args = parser.parse_args()

View file

@@ -1,7 +1,7 @@
#!/usr/bin/env python
# flake8: noqa
# yapf: disable
# fmt: off
# __tutorial_imports_begin__
import argparse

View file

@@ -1,6 +1,6 @@
import os
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# (Optional/Auto-filled) training is terminated. Filled only if not provided.
DONE = "done"
@@ -60,7 +60,7 @@ TIME_TOTAL_S = "time_total_s"
# (Auto-filled) The index of this training iteration.
TRAINING_ITERATION = "training_iteration"
# __sphinx_doc_end__
# yapf: enable
# fmt: on
DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)

View file

@@ -1,5 +1,5 @@
# flake8: noqa
# yapf: disable
# fmt: off
# External PyTorch tutorial (https://github.com/pytorch/tutorials/pull/1066)
# If this script fails, fix it and submit a PR to pytorch/tutorials.

View file

@@ -1,7 +1,7 @@
# flake8: noqa
# Original Code: https://github.com/pytorch/examples/blob/master/mnist/main.py
# yapf: disable
# fmt: off
# __tutorial_imports_begin__
import numpy as np
import torch
@@ -14,10 +14,10 @@ import torch.nn.functional as F
from ray import tune
from ray.tune.schedulers import ASHAScheduler
# __tutorial_imports_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __model_def_begin__
class ConvNet(nn.Module):
def __init__(self):
@@ -33,9 +33,9 @@ class ConvNet(nn.Module):
x = self.fc(x)
return F.log_softmax(x, dim=1)
# __model_def_end__
# yapf: enable
# fmt: on
# yapf: disable
# fmt: off
# __train_def_begin__
# Change these values if you want the training to run quicker or slower.
@@ -111,7 +111,7 @@ def train_mnist(config):
# This saves the model to the trial directory
torch.save(model.state_dict(), "./model.pth")
# __train_func_end__
# yapf: enable
# fmt: on
# __eval_func_begin__
search_space = {
@@ -145,14 +145,14 @@ analysis = tune.run(
dfs = analysis.trial_dataframes
# __run_scheduler_end__
# yapf: disable
# fmt: off
# __plot_scheduler_begin__
# Plot by epoch
ax = None # This plots everything on the same plot
for d in dfs.values():
ax = d.mean_accuracy.plot(ax=ax, legend=False)
# __plot_scheduler_end__
# yapf: enable
# fmt: on
# __run_searchalg_begin__
from hyperopt import hp

View file

@@ -5,7 +5,7 @@ It ignores yapf because yapf doesn't allow comments right after code blocks,
but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# fmt: off
# __torch_operator_start__
import torch

View file

@@ -5,7 +5,7 @@ but we put comments right after code blocks to prevent large white spaces
in the documentation.
"""
# yapf: disable
# fmt: off
# __torch_train_example__
import argparse
import numpy as np

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
"""
This file holds code for a Distributed Pytorch + Tune page in the docs.

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload, Sequence, List
from ray._raylet import ObjectRef

View file

@@ -1,4 +1,4 @@
# yapf: disable
# fmt: off
from typing import Callable, Generic, Optional, TypeVar, Union, overload, Any
from types import FunctionType

View file

@@ -27,7 +27,7 @@ from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# Should use a critic as a baseline (otherwise don't use value baseline;
@@ -68,7 +68,7 @@ DEFAULT_CONFIG = with_common_config({
"_disable_execution_plan_api": True,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class A3CTrainer(Trainer):

View file

@@ -35,7 +35,7 @@ Result = namedtuple(
],
)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
"action_noise_std": 0.0,
@@ -59,7 +59,7 @@ DEFAULT_CONFIG = with_common_config({
},
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
@ray.remote

View file

@@ -9,7 +9,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# No remote workers by default.
@@ -26,7 +26,7 @@ DEFAULT_CONFIG = with_common_config({
"timesteps_per_iteration": 100,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class BanditLinTSTrainer(Trainer):

View file

@@ -26,7 +26,7 @@ tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
CQL_DEFAULT_CONFIG = merge_dicts(
SAC_CONFIG, {
@@ -55,7 +55,7 @@ CQL_DEFAULT_CONFIG = merge_dicts(
},
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class CQLTrainer(SACTrainer):

View file

@@ -11,7 +11,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
@@ -175,7 +175,7 @@ DEFAULT_CONFIG = with_common_config({
"min_time_s_per_reporting": 1,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class DDPGTrainer(SimpleQTrainer):

View file

@@ -47,7 +47,7 @@ from ray.tune.trainable import Trainable
from ray.tune.utils.placement_groups import PlacementGroupFactory
from ray.util.iter import LocalIterator
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
APEX_DEFAULT_CONFIG = merge_dicts(
# See also the options in dqn.py, which are also supported.
@@ -92,7 +92,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
},
)
# __sphinx_doc_end__
# yapf: enable
# fmt: on
# Update worker weights as they finish generating experiences.

View file

@@ -38,7 +38,7 @@ from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = Trainer.merge_trainer_configs(
SIMPLEQ_DEFAULT_CONFIG,
@@ -106,7 +106,7 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable
# fmt: on
def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:

View file

@@ -11,7 +11,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
DQN_DEFAULT_CONFIG, # See keys in impala.py, which are also supported.
@@ -70,7 +70,7 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
_allow_unknown_configs=True,
)
# __sphinx_doc_end__
# yapf: enable
# fmt: on
# Build an R2D2 trainer, which uses the framework specific Policy

View file

@@ -31,7 +31,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Exploration Settings ===
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = with_common_config({
"min_time_s_per_reporting": 1,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class SimpleQTrainer(Trainer):

View file

@@ -17,7 +17,7 @@ from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# PlaNET Model LR
@@ -78,7 +78,7 @@ DEFAULT_CONFIG = with_common_config({
}
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class EpisodicBuffer(object):

View file

@@ -33,7 +33,7 @@ Result = namedtuple(
],
)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
"action_noise_std": 0.01,
@@ -58,7 +58,7 @@ DEFAULT_CONFIG = with_common_config({
},
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
@ray.remote

View file

@@ -25,7 +25,7 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# V-trace params (see vtrace_tf/torch.py).
@@ -127,7 +127,7 @@ DEFAULT_CONFIG = with_common_config({
"num_data_loader_buffers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
def make_learner_thread(local_worker, config):

View file

@@ -27,7 +27,7 @@ from ray.util.iter import from_actors, LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# If true, use the Generalized Advantage Estimator (GAE)
@@ -80,7 +80,7 @@ DEFAULT_CONFIG = with_common_config({
"vf_share_layers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
# @mluo: TODO

View file

@@ -5,7 +5,7 @@ from ray.rllib.agents.marwil.marwil import (
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
MARWIL_CONFIG, {
@@ -19,7 +19,7 @@ BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
"input_evaluation": [],
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class BCTrainer(MARWILTrainer):

View file

@@ -14,7 +14,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Input settings ===
@@ -73,7 +73,7 @@ DEFAULT_CONFIG = with_common_config({
"num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class MARWILTrainer(Trainer):

View file

@@ -35,7 +35,7 @@ from ray.util.iter import from_actors, LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
@@ -115,7 +115,7 @@ DEFAULT_CONFIG = with_common_config({
"vf_share_layers": DEPRECATED_VALUE,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
# Select Metric Keys for MAML Stats Tracing
METRICS_KEYS = ["episode_reward_mean", "episode_reward_min", "episode_reward_max"]

View file

@@ -1,6 +1,6 @@
from ray.rllib.agents.trainer import with_common_config
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Add the following (PG-specific) updates to the (base) `Trainer` config in
@@ -20,4 +20,4 @@ DEFAULT_CONFIG = with_common_config({
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on

View file

@@ -25,7 +25,7 @@ from ray.rllib.execution.common import (
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the `IMPALATrainer` config in
@@ -85,7 +85,7 @@ DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
)
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class UpdateTargetAndKL:

View file

@@ -43,7 +43,7 @@ from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the `PPOTrainer` config in
@@ -93,7 +93,7 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
)
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class DDPPOTrainer(PPOTrainer):

View file

@@ -34,7 +34,7 @@ from ray.util.iter import LocalIterator
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
@@ -101,7 +101,7 @@ DEFAULT_CONFIG = with_common_config({
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class UpdateKL:

View file

@@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === QMix ===
@@ -107,7 +107,7 @@ DEFAULT_CONFIG = with_common_config({
"framework": "torch",
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class QMixTrainer(SimpleQTrainer):

View file

@@ -26,7 +26,7 @@ OPTIMIZER_SHARED_CONFIGS = [
"learning_starts",
]
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
# Adds the following updates to the (base) `Trainer` config in
@@ -173,7 +173,7 @@ DEFAULT_CONFIG = with_common_config({
"_use_beta_distribution": False,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
class SACTrainer(DQNTrainer):

View file

@@ -46,7 +46,7 @@ ALL_SLATEQ_STRATEGIES = [
"QL",
]
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Model ===
@@ -144,7 +144,7 @@ DEFAULT_CONFIG = with_common_config({
"double_q": True,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:

View file

@@ -117,7 +117,7 @@ logger = logging.getLogger(__name__)
# times in a row since that would indicate a persistent cluster issue.
MAX_WORKER_FAILURE_RETRIES = 3
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
COMMON_CONFIG: TrainerConfigDict = {
# === Settings for Rollout Worker processes ===
@@ -650,7 +650,7 @@ COMMON_CONFIG: TrainerConfigDict = {
"collect_metrics_timeout": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
# fmt: on
@DeveloperAPI

View file

@@ -48,7 +48,7 @@ class AlphaZeroDefaultCallbacks(DefaultCallbacks):
episode.user_data["initial_state"] = state
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# Size of batches collected from each worker
@@ -121,7 +121,7 @@ DEFAULT_CONFIG = with_common_config({
# __sphinx_doc_end__
# yapf: enable
# fmt: on
def alpha_zero_loss(policy, model, dist_class, train_batch):

View file

@@ -25,7 +25,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
# === Framework to run the algorithm ===
@@ -123,7 +123,7 @@ DEFAULT_CONFIG = with_common_config({
"min_time_s_per_reporting": 0,
})
# __sphinx_doc_end__
# yapf: enable
# fmt: on
def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):

View file

@@ -5,7 +5,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TrainerConfigDict
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
class RandomAgent(Trainer):
"""Trainer that produces random actions and never learns."""

View file

@@ -233,7 +233,7 @@ class MultiAgentEnv(gym.Env):
# By default, do nothing.
pass
# yapf: disable
# fmt: off
# __grouping_doc_begin__
@ExperimentalAPI
def with_agent_groups(
@@ -279,7 +279,7 @@ class MultiAgentEnv(gym.Env):
return GroupAgentsWrapper(self, groups, obs_space, act_space)
# __grouping_doc_end__
# yapf: enable
# fmt: on
@PublicAPI
def to_base_env(

View file

@@ -13,7 +13,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
class SampleCollector(metaclass=ABCMeta):
"""Collects samples for all policies and agents from a multi-agent env.

View file

@@ -48,7 +48,7 @@ torch, _ = try_import_torch()
logger = logging.getLogger(__name__)
# yapf: disable
# fmt: off
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
# Experimental flag.
@@ -188,7 +188,7 @@ MODEL_DEFAULTS: ModelConfigDict = {
"lstm_use_prev_action_reward": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
# fmt: on
@PublicAPI

View file

@@ -80,7 +80,7 @@ class Exploration:
"""
pass
# yapf: disable
# fmt: off
# __sphinx_doc_begin_get_exploration_action__
@DeveloperAPI
@@ -112,7 +112,7 @@ class Exploration:
pass
# __sphinx_doc_end_get_exploration_action__
# yapf: enable
# fmt: on
@DeveloperAPI
def on_episode_start(