mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[Docs] Update Train user guide to use the new APIs (#26091)
This commit is contained in:
parent
2c5c0f6cee
commit
65ea710e30
7 changed files with 415 additions and 524 deletions
|
@ -29,6 +29,7 @@ tabulate
|
|||
uvicorn==0.16.0
|
||||
werkzeug
|
||||
wandb
|
||||
tensorflow
|
||||
transformers
|
||||
|
||||
# Ray libraries
|
||||
|
|
|
@ -138,32 +138,13 @@ MOCK_MODULES = [
|
|||
"scipy.stats",
|
||||
"setproctitle",
|
||||
"tensorflow_probability",
|
||||
"tensorflow",
|
||||
"tensorflow.contrib",
|
||||
"tensorflow.contrib.all_reduce",
|
||||
"transformers",
|
||||
"transformers.modeling_utils",
|
||||
"transformers.models",
|
||||
"transformers.models.auto",
|
||||
"transformers.pipelines",
|
||||
"transformers.pipelines.table_question_answering",
|
||||
"transformers.trainer",
|
||||
"transformers.training_args",
|
||||
"transformers.trainer_callback",
|
||||
"transformers.utils",
|
||||
"transformers.utils.logging",
|
||||
"transformers.utils.versions",
|
||||
"tree",
|
||||
"tensorflow.contrib.all_reduce.python",
|
||||
"tensorflow.contrib.layers",
|
||||
"tensorflow.contrib.rnn",
|
||||
"tensorflow.contrib.slim",
|
||||
"tensorflow.core",
|
||||
"tensorflow.core.util",
|
||||
"tensorflow.keras.callbacks",
|
||||
"tensorflow.python",
|
||||
"tensorflow.python.client",
|
||||
"tensorflow.python.util",
|
||||
"tree",
|
||||
"wandb",
|
||||
"zoopt",
|
||||
]
|
||||
|
@ -188,8 +169,6 @@ def mock_modules():
|
|||
|
||||
sys.modules["ray._raylet"].ObjectRef = make_typing_mock("ray", "ObjectRef")
|
||||
|
||||
sys.modules["tensorflow"].VERSION = "9.9.9"
|
||||
|
||||
|
||||
# Add doc files from external repositories to be downloaded during build here
|
||||
# (repo, ref, path to get, path to save on disk)
|
||||
|
|
|
@ -126,6 +126,33 @@ Configs
|
|||
|
||||
.. autoclass:: ray.air.config.CheckpointConfig
|
||||
|
||||
|
||||
.. _air-builtin-callbacks:
|
||||
|
||||
Callbacks
|
||||
~~~~~~~~~
|
||||
|
||||
Comet
|
||||
#####
|
||||
|
||||
.. autoclass:: ray.air.callbacks.comet.CometLoggerCallback
|
||||
|
||||
Keras
|
||||
#####
|
||||
|
||||
.. autoclass:: ray.air.callbacks.keras.Callback
|
||||
:members:
|
||||
|
||||
MLflow
|
||||
######
|
||||
|
||||
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
|
||||
|
||||
Weights and Biases
|
||||
##################
|
||||
|
||||
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
|
||||
|
||||
.. _air-session-ref:
|
||||
|
||||
Session
|
||||
|
|
|
@ -4,11 +4,13 @@
|
|||
Ray Train API
|
||||
=============
|
||||
|
||||
.. _train-api-trainer:
|
||||
|
||||
Trainer
|
||||
-------
|
||||
|
||||
.. warning::
|
||||
This Trainer API is deprecated and no longer supported. For an overview of the new :ref:`air` Trainer API,
|
||||
see :ref:`air-trainer-ref`.
|
||||
|
||||
.. autoclass:: ray.train.Trainer
|
||||
:members:
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -8,6 +8,14 @@ External library integrations (tune.integration)
|
|||
:depth: 1
|
||||
|
||||
|
||||
Comet (tune.integration.comet)
|
||||
-------------------------------------------
|
||||
|
||||
:ref:`See also here <tune-comet-ref>`.
|
||||
|
||||
.. autoclass:: ray.air.callbacks.comet.CometLoggerCallback
|
||||
:noindex:
|
||||
|
||||
.. _tune-integration-keras:
|
||||
|
||||
Keras (tune.integration.keras)
|
||||
|
@ -26,6 +34,7 @@ MLflow (tune.integration.mlflow)
|
|||
:ref:`See also here <tune-mlflow-ref>`.
|
||||
|
||||
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
|
||||
|
||||
|
@ -57,6 +66,7 @@ Weights and Biases (tune.integration.wandb)
|
|||
:ref:`See also here <tune-wandb-ref>`.
|
||||
|
||||
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
|
||||
:noindex:
|
||||
|
||||
.. autofunction:: ray.tune.integration.wandb.wandb_mixin
|
||||
|
||||
|
|
|
@ -114,15 +114,25 @@ class _Callback(KerasCallback):
|
|||
|
||||
@PublicAPI(stability="beta")
|
||||
class Callback(_Callback):
|
||||
def __init__(
|
||||
self,
|
||||
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
||||
on: Union[str, List[str]] = "epoch_end",
|
||||
frequency: Union[int, List[int]] = 1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
metrics: Metrics to report. If this is a list, each item describes
|
||||
"""
|
||||
Keras callback for Ray AIR reporting and checkpointing.
|
||||
|
||||
You can use this in both TuneSession and TrainSession.
|
||||
|
||||
Example:
|
||||
.. code-block: python
|
||||
|
||||
############# Using it in TrainSession ###############
|
||||
from ray.air.callbacks.keras import Callback
|
||||
def train_loop_per_worker():
|
||||
strategy = tf.distribute.MultiWorkerMirroredStrategy()
|
||||
with strategy.scope():
|
||||
model = build_model()
|
||||
#model.compile(...)
|
||||
model.fit(dataset_shard, callbacks=[Callback()])
|
||||
|
||||
Args:
|
||||
metrics: Metrics to report. If this is a list, each item describes
|
||||
the metric key reported to Keras, and it will reported under the
|
||||
same name. If this is a dict, each key will be the name reported
|
||||
and the respective value will be the metric key reported to Keras.
|
||||
|
@ -135,20 +145,14 @@ class Callback(_Callback):
|
|||
this is a list, it specifies the checkpoint frequencies for each
|
||||
hook individually.
|
||||
|
||||
You can use this in both TuneSession and TrainSession.
|
||||
"""
|
||||
|
||||
Example:
|
||||
.. code-block: python
|
||||
|
||||
############# Using it in TrainSession ###############
|
||||
from ray.air.callbacks.keras import Callback
|
||||
def train_loop_per_worker():
|
||||
strategy = tf.distribute.MultiWorkerMirroredStrategy()
|
||||
with strategy.scope():
|
||||
model = build_model()
|
||||
#model.compile(...)
|
||||
model.fit(dataset_shard, callbacks=[Callback()])
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
|
||||
on: Union[str, List[str]] = "epoch_end",
|
||||
frequency: Union[int, List[int]] = 1,
|
||||
):
|
||||
if isinstance(frequency, list):
|
||||
if not isinstance(on, list) or len(frequency) != len(on):
|
||||
raise ValueError(
|
||||
|
|
Loading…
Add table
Reference in a new issue