Mirror of https://github.com/vale981/ray, synced 2025-03-05 10:01:43 -05:00
[CI] Replace YAPF disables with Black disables (#21982)
parent: dcd96ca348
commit: 31ed9e5d02

55 changed files with 114 additions and 114 deletions
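Note for readers: yapf and Black use different magic comments to exempt a region from formatting. yapf reads # yapf: disable / # yapf: enable, while Black reads # fmt: off / # fmt: on; this commit swaps the former for the latter throughout, leaving the guarded code itself untouched. A minimal sketch of the Black directives (illustrative, not taken from this diff):

# fmt: off
IDENTITY = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
# fmt: on

# Outside a fmt: off/on region Black normalizes layout, so hand alignment
# like the matrix above survives only inside the guarded block.
x = IDENTITY[0][0]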
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __data_setup_begin__
 
@@ -8,7 +8,7 @@ in the documentation.
 """
 import ray
 
-# yapf: disable
+# fmt: off
 
 # __runtime_env_conda_def_start__
 runtime_env = {
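This and most of the following files are documentation snippets: the # __..._begin__ / # __..._end__ comments are sentinel markers that the docs build uses to pull a delimited region of code into a page, which is why the file docstrings warn that comments sit directly after code blocks and must not be moved by a formatter. A rough sketch of that kind of sentinel extraction (the helper below is hypothetical, not Ray's actual doc tooling):

import pathlib
import re

def extract_snippet(path, name):
    # Return the text between the __<name>_begin__ and __<name>_end__
    # sentinel comments in the given source file.
    text = pathlib.Path(path).read_text()
    match = re.search(
        rf"# __{name}_begin__\n(.*?)# __{name}_end__", text, re.DOTALL
    )
    return match.group(1) if match else ""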
@@ -7,7 +7,7 @@ but we put comments right after code blocks to prevent large white spaces
 in the documentation.
 """
 
-# yapf: disable
+# fmt: off
 # __tf_model_start__
 
@@ -28,9 +28,9 @@ def create_keras_model():
         metrics=[keras.metrics.categorical_accuracy])
     return model
 # __tf_model_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __ray_start__
 import ray
 import numpy as np
@@ -65,17 +65,17 @@ class Network(object):
         # Note that for simplicity this does not handle the optimizer state.
         self.model.set_weights(weights)
 # __ray_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __actor_start__
 NetworkActor = Network.remote()
 result_object_ref = NetworkActor.train.remote()
 ray.get(result_object_ref)
 # __actor_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __weight_average_start__
 NetworkActor2 = Network.remote()
 NetworkActor2.train.remote()
@@ -6,7 +6,7 @@ It ignores yapf because yapf doesn't allow comments right after code blocks,
 but we put comments right after code blocks to prevent large white spaces
 in the documentation.
 """
-# yapf: disable
+# fmt: off
 # __torch_model_start__
 import argparse
@@ -35,9 +35,9 @@ class Model(nn.Module):
 
 
 # __torch_model_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __torch_helper_start__
 from filelock import FileLock
 from torchvision import datasets, transforms
@@ -112,9 +112,9 @@ def dataset_creator(use_cuda, data_dir):
 
 
 # __torch_helper_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __torch_net_start__
 import torch.optim as optim
@@ -155,9 +155,9 @@ args = parser.parse_args()
 net = Network(data_dir=args.data_dir)
 net.train()
 # __torch_net_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __torch_ray_start__
 import ray
@@ -167,18 +167,18 @@ RemoteNetwork = ray.remote(Network)
 # Use the below instead of `ray.remote(network)` to leverage the GPU.
 # RemoteNetwork = ray.remote(num_gpus=1)(Network)
 # __torch_ray_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __torch_actor_start__
 NetworkActor = RemoteNetwork.remote()
 NetworkActor2 = RemoteNetwork.remote()
 
 ray.get([NetworkActor.train.remote(), NetworkActor2.train.remote()])
 # __torch_actor_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __weight_average_start__
 weights = ray.get(
     [NetworkActor.get_weights.remote(),
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __serve_example_begin__
 import requests
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 # __doc_import_begin__
 from typing import List
 import time
@@ -10,7 +10,7 @@ from starlette.requests import Request
 import ray
 from ray import serve
 # __doc_import_end__
-# yapf: enable
+# fmt: on
 
 
 # __doc_define_servable_begin__
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 import ray
 # __doc_import_begin__
 from ray import serve
@@ -11,7 +11,7 @@ import torch
 from torchvision import transforms
 from torchvision.models import resnet18
 # __doc_import_end__
-# yapf: enable
+# fmt: on
 
 
 # __doc_define_servable_begin__
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 import ray
 # __doc_import_begin__
 from ray import serve
@@ -14,7 +14,7 @@ from sklearn.datasets import load_iris
 from sklearn.ensemble import GradientBoostingClassifier
 from sklearn.metrics import mean_squared_error
 # __doc_import_end__
-# yapf: enable
+# fmt: on
 
 # __doc_train_model_begin__
 # Load data
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 import ray
 # __doc_import_begin__
 from ray import serve
@@ -8,7 +8,7 @@ import tempfile
 import numpy as np
 import requests
 # __doc_import_end__
-# yapf: enable
+# fmt: on
 
 # __doc_train_model_begin__
 TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), "mnist_model.h5")
@@ -289,7 +289,7 @@ def test_add_min_workers_nodes():
     # Formatting is disabled to prevent Black from erroring while formatting
     # this file. See https://github.com/ray-project/ray/issues/21313 for more
     # information.
-    # yapf: disable
+    # fmt: off
     assert _add_min_workers_nodes([],
                                   {},
                                   types, None, None, None) == \
@@ -336,7 +336,7 @@ def test_add_min_workers_nodes():
         }, {
             "gpubla": 10
         })
-    # yapf: enable
+    # fmt: on
 
 
 def test_get_nodes_to_launch_with_min_workers():
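The two hunks above show why the guard exists in tests: the assertion is hand-aligned, and the in-file comment notes that Black errors outright on this file (ray-project/ray#21313), so the region stays disabled rather than reformatted. A standalone sketch of the same pattern (names illustrative, not from this diff):

# fmt: off
def test_lookup_is_sorted():
    # The hand alignment below is preserved because the whole function
    # sits inside a fmt: off/on region.
    table = {"a": 1,
             "b": 2}
    assert sorted(table) == ["a",
                             "b"]
# fmt: on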
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __tf_setup_begin__
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __torch_setup_begin__
 import torch
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __import_begin__
 from functools import partial
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __import_lightning_begin__
 import math
@@ -28,7 +28,7 @@ parser.add_argument(
 
 
 # Below comments are for documentation purposes only.
-# yapf: disable
+# fmt: off
 # __trainable_example_begin__
 class TrainMNIST(tune.Trainable):
     def setup(self, config):
@@ -57,7 +57,7 @@ class TrainMNIST(tune.Trainable):
 
 
 # __trainable_example_end__
-# yapf: enable
+# fmt: on
 
 if __name__ == "__main__":
     args = parser.parse_args()
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # __tutorial_imports_begin__
 import argparse
@@ -1,6 +1,6 @@
 import os
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 # (Optional/Auto-filled) training is terminated. Filled only if not provided.
 DONE = "done"
@@ -60,7 +60,7 @@ TIME_TOTAL_S = "time_total_s"
 # (Auto-filled) The index of this training iteration.
 TRAINING_ITERATION = "training_iteration"
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
 
@@ -1,5 +1,5 @@
 # flake8: noqa
-# yapf: disable
+# fmt: off
 
 # External PyTorch tutorial (https://github.com/pytorch/tutorials/pull/1066)
 # If this script fails, fix it and submit a PR to pytorch/tutorials.
@@ -1,7 +1,7 @@
 # flake8: noqa
 # Original Code: https://github.com/pytorch/examples/blob/master/mnist/main.py
 
-# yapf: disable
+# fmt: off
 # __tutorial_imports_begin__
 import numpy as np
 import torch
@@ -14,10 +14,10 @@ import torch.nn.functional as F
 from ray import tune
 from ray.tune.schedulers import ASHAScheduler
 # __tutorial_imports_end__
-# yapf: enable
+# fmt: on
 
 
-# yapf: disable
+# fmt: off
 # __model_def_begin__
 class ConvNet(nn.Module):
     def __init__(self):
@@ -33,9 +33,9 @@ class ConvNet(nn.Module):
         x = self.fc(x)
         return F.log_softmax(x, dim=1)
 # __model_def_end__
-# yapf: enable
+# fmt: on
 
-# yapf: disable
+# fmt: off
 # __train_def_begin__
 
 # Change these values if you want the training to run quicker or slower.
@@ -111,7 +111,7 @@ def train_mnist(config):
     # This saves the model to the trial directory
     torch.save(model.state_dict(), "./model.pth")
 # __train_func_end__
-# yapf: enable
+# fmt: on
 
 # __eval_func_begin__
 search_space = {
@@ -145,14 +145,14 @@ analysis = tune.run(
 dfs = analysis.trial_dataframes
 # __run_scheduler_end__
 
-# yapf: disable
+# fmt: off
 # __plot_scheduler_begin__
 # Plot by epoch
 ax = None  # This plots everything on the same plot
 for d in dfs.values():
     ax = d.mean_accuracy.plot(ax=ax, legend=False)
 # __plot_scheduler_end__
-# yapf: enable
+# fmt: on
 
 # __run_searchalg_begin__
 from hyperopt import hp
@@ -5,7 +5,7 @@ It ignores yapf because yapf doesn't allow comments right after code blocks,
 but we put comments right after code blocks to prevent large white spaces
 in the documentation.
 """
-# yapf: disable
+# fmt: off
 
 # __torch_operator_start__
 import torch
@@ -5,7 +5,7 @@ but we put comments right after code blocks to prevent large white spaces
 in the documentation.
 """
 
-# yapf: disable
+# fmt: off
 # __torch_train_example__
 import argparse
 import numpy as np
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 """
 This file holds code for a Distributed Pytorch + Tune page in the docs.
 
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload, Sequence, List
 
 from ray._raylet import ObjectRef
@@ -1,4 +1,4 @@
-# yapf: disable
+# fmt: off
 from typing import Callable, Generic, Optional, TypeVar, Union, overload, Any
 from types import FunctionType
@@ -27,7 +27,7 @@ from ray.util.iter import LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # Should use a critic as a baseline (otherwise don't use value baseline;
@@ -68,7 +68,7 @@ DEFAULT_CONFIG = with_common_config({
     "_disable_execution_plan_api": True,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class A3CTrainer(Trainer):
@@ -35,7 +35,7 @@ Result = namedtuple(
     ],
 )
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     "action_noise_std": 0.0,
@@ -59,7 +59,7 @@ DEFAULT_CONFIG = with_common_config({
     },
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 @ray.remote
@@ -9,7 +9,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # No remote workers by default.
@@ -26,7 +26,7 @@ DEFAULT_CONFIG = with_common_config({
     "timesteps_per_iteration": 100,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class BanditLinTSTrainer(Trainer):
@@ -26,7 +26,7 @@ tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 CQL_DEFAULT_CONFIG = merge_dicts(
     SAC_CONFIG, {
@@ -55,7 +55,7 @@ CQL_DEFAULT_CONFIG = merge_dicts(
     },
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class CQLTrainer(SACTrainer):
@@ -11,7 +11,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === Twin Delayed DDPG (TD3) and Soft Actor-Critic (SAC) tricks ===
@@ -175,7 +175,7 @@ DEFAULT_CONFIG = with_common_config({
     "min_time_s_per_reporting": 1,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class DDPGTrainer(SimpleQTrainer):
@@ -47,7 +47,7 @@ from ray.tune.trainable import Trainable
 from ray.tune.utils.placement_groups import PlacementGroupFactory
 from ray.util.iter import LocalIterator
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 APEX_DEFAULT_CONFIG = merge_dicts(
     # See also the options in dqn.py, which are also supported.
@@ -92,7 +92,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
     },
 )
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 # Update worker weights as they finish generating experiences.
@@ -38,7 +38,7 @@ from ray.util.iter import LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = Trainer.merge_trainer_configs(
     SIMPLEQ_DEFAULT_CONFIG,
@@ -106,7 +106,7 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
     _allow_unknown_configs=True,
 )
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 def calculate_rr_weights(config: TrainerConfigDict) -> List[float]:
@@ -11,7 +11,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
     DQN_DEFAULT_CONFIG,  # See keys in impala.py, which are also supported.
@@ -70,7 +70,7 @@ R2D2_DEFAULT_CONFIG = Trainer.merge_trainer_configs(
     _allow_unknown_configs=True,
 )
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 # Build an R2D2 trainer, which uses the framework specific Policy
@@ -31,7 +31,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === Exploration Settings ===
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = with_common_config({
     "min_time_s_per_reporting": 1,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class SimpleQTrainer(Trainer):
@@ -17,7 +17,7 @@ from ray.rllib.utils.typing import SampleBatchType, TrainerConfigDict
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # PlaNET Model LR
@@ -78,7 +78,7 @@ DEFAULT_CONFIG = with_common_config({
     }
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class EpisodicBuffer(object):
@@ -33,7 +33,7 @@ Result = namedtuple(
     ],
 )
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     "action_noise_std": 0.01,
@@ -58,7 +58,7 @@ DEFAULT_CONFIG = with_common_config({
     },
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 @ray.remote
@@ -25,7 +25,7 @@ from ray.tune.utils.placement_groups import PlacementGroupFactory
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # V-trace params (see vtrace_tf/torch.py).
@@ -127,7 +127,7 @@ DEFAULT_CONFIG = with_common_config({
     "num_data_loader_buffers": DEPRECATED_VALUE,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 def make_learner_thread(local_worker, config):
@@ -27,7 +27,7 @@ from ray.util.iter import from_actors, LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # If true, use the Generalized Advantage Estimator (GAE)
@@ -80,7 +80,7 @@ DEFAULT_CONFIG = with_common_config({
     "vf_share_layers": DEPRECATED_VALUE,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 # @mluo: TODO
@@ -5,7 +5,7 @@ from ray.rllib.agents.marwil.marwil import (
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
     MARWIL_CONFIG, {
@@ -19,7 +19,7 @@ BC_DEFAULT_CONFIG = MARWILTrainer.merge_trainer_configs(
     "input_evaluation": [],
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class BCTrainer(MARWILTrainer):
@@ -14,7 +14,7 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === Input settings ===
@@ -73,7 +73,7 @@ DEFAULT_CONFIG = with_common_config({
     "num_workers": 0,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class MARWILTrainer(Trainer):
@@ -35,7 +35,7 @@ from ray.util.iter import from_actors, LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Adds the following updates to the (base) `Trainer` config in
@@ -115,7 +115,7 @@ DEFAULT_CONFIG = with_common_config({
     "vf_share_layers": DEPRECATED_VALUE,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 # Select Metric Keys for MAML Stats Tracing
 METRICS_KEYS = ["episode_reward_mean", "episode_reward_min", "episode_reward_max"]
@@ -1,6 +1,6 @@
 from ray.rllib.agents.trainer import with_common_config
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Add the following (PG-specific) updates to the (base) `Trainer` config in
@@ -20,4 +20,4 @@ DEFAULT_CONFIG = with_common_config({
 })
 
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
@@ -25,7 +25,7 @@ from ray.rllib.execution.common import (
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import PartialTrainerConfigDict, TrainerConfigDict
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Adds the following updates to the `IMPALATrainer` config in
@@ -85,7 +85,7 @@ DEFAULT_CONFIG = impala.ImpalaTrainer.merge_trainer_configs(
 )
 
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class UpdateTargetAndKL:
@@ -43,7 +43,7 @@ from ray.util.iter import LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Adds the following updates to the `PPOTrainer` config in
@@ -93,7 +93,7 @@ DEFAULT_CONFIG = Trainer.merge_trainer_configs(
 )
 
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class DDPPOTrainer(PPOTrainer):
@@ -34,7 +34,7 @@ from ray.util.iter import LocalIterator
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Adds the following updates to the (base) `Trainer` config in
@@ -101,7 +101,7 @@ DEFAULT_CONFIG = with_common_config({
 })
 
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class UpdateKL:
@@ -18,7 +18,7 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === QMix ===
@@ -107,7 +107,7 @@ DEFAULT_CONFIG = with_common_config({
     "framework": "torch",
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class QMixTrainer(SimpleQTrainer):
@@ -26,7 +26,7 @@ OPTIMIZER_SHARED_CONFIGS = [
     "learning_starts",
 ]
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 
 # Adds the following updates to the (base) `Trainer` config in
@@ -173,7 +173,7 @@ DEFAULT_CONFIG = with_common_config({
     "_use_beta_distribution": False,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 class SACTrainer(DQNTrainer):
@@ -46,7 +46,7 @@ ALL_SLATEQ_STRATEGIES = [
     "QL",
 ]
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === Model ===
@@ -144,7 +144,7 @@ DEFAULT_CONFIG = with_common_config({
     "double_q": True,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 def calculate_round_robin_weights(config: TrainerConfigDict) -> List[float]:
@@ -117,7 +117,7 @@ logger = logging.getLogger(__name__)
 # times in a row since that would indicate a persistent cluster issue.
 MAX_WORKER_FAILURE_RETRIES = 3
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 COMMON_CONFIG: TrainerConfigDict = {
     # === Settings for Rollout Worker processes ===
@@ -650,7 +650,7 @@ COMMON_CONFIG: TrainerConfigDict = {
     "collect_metrics_timeout": DEPRECATED_VALUE,
 }
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 @DeveloperAPI
@@ -48,7 +48,7 @@ class AlphaZeroDefaultCallbacks(DefaultCallbacks):
         episode.user_data["initial_state"] = state
 
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # Size of batches collected from each worker
@@ -121,7 +121,7 @@ DEFAULT_CONFIG = with_common_config({
 
 
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 def alpha_zero_loss(policy, model, dist_class, train_batch):
@@ -25,7 +25,7 @@ from ray.rllib.utils.typing import TrainerConfigDict
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 DEFAULT_CONFIG = with_common_config({
     # === Framework to run the algorithm ===
@@ -123,7 +123,7 @@ DEFAULT_CONFIG = with_common_config({
     "min_time_s_per_reporting": 0,
 })
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 def before_learn_on_batch(multi_agent_batch, policies, train_batch_size):
@@ -5,7 +5,7 @@ from ray.rllib.utils.annotations import override
 from ray.rllib.utils.typing import TrainerConfigDict
 
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 class RandomAgent(Trainer):
     """Trainer that produces random actions and never learns."""
rllib/env/multi_agent_env.py (vendored, 4 changes)
@@ -233,7 +233,7 @@ class MultiAgentEnv(gym.Env):
         # By default, do nothing.
         pass
 
-    # yapf: disable
+    # fmt: off
     # __grouping_doc_begin__
     @ExperimentalAPI
     def with_agent_groups(
@@ -279,7 +279,7 @@ class MultiAgentEnv(gym.Env):
         return GroupAgentsWrapper(self, groups, obs_space, act_space)
 
     # __grouping_doc_end__
-    # yapf: enable
+    # fmt: on
 
     @PublicAPI
     def to_base_env(
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 class SampleCollector(metaclass=ABCMeta):
     """Collects samples for all policies and agents from a multi-agent env.
@@ -48,7 +48,7 @@ torch, _ = try_import_torch()
 
 logger = logging.getLogger(__name__)
 
-# yapf: disable
+# fmt: off
 # __sphinx_doc_begin__
 MODEL_DEFAULTS: ModelConfigDict = {
     # Experimental flag.
@@ -188,7 +188,7 @@ MODEL_DEFAULTS: ModelConfigDict = {
     "lstm_use_prev_action_reward": DEPRECATED_VALUE,
 }
 # __sphinx_doc_end__
-# yapf: enable
+# fmt: on
 
 
 @PublicAPI
@@ -80,7 +80,7 @@ class Exploration:
         """
         pass
 
-    # yapf: disable
+    # fmt: off
     # __sphinx_doc_begin_get_exploration_action__
 
     @DeveloperAPI
@@ -112,7 +112,7 @@ class Exploration:
         pass
 
     # __sphinx_doc_end_get_exploration_action__
-    # yapf: enable
+    # fmt: on
 
     @DeveloperAPI
     def on_episode_start(