[Docs] Update Train user guide to use the new APIs (#26091)

This commit is contained in:
Antoni Baum 2022-07-12 00:10:10 +02:00 committed by GitHub
parent 2c5c0f6cee
commit 65ea710e30
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 415 additions and 524 deletions

View file

@ -29,6 +29,7 @@ tabulate
uvicorn==0.16.0
werkzeug
wandb
tensorflow
transformers
# Ray libraries

View file

@ -138,32 +138,13 @@ MOCK_MODULES = [
"scipy.stats",
"setproctitle",
"tensorflow_probability",
"tensorflow",
"tensorflow.contrib",
"tensorflow.contrib.all_reduce",
"transformers",
"transformers.modeling_utils",
"transformers.models",
"transformers.models.auto",
"transformers.pipelines",
"transformers.pipelines.table_question_answering",
"transformers.trainer",
"transformers.training_args",
"transformers.trainer_callback",
"transformers.utils",
"transformers.utils.logging",
"transformers.utils.versions",
"tree",
"tensorflow.contrib.all_reduce.python",
"tensorflow.contrib.layers",
"tensorflow.contrib.rnn",
"tensorflow.contrib.slim",
"tensorflow.core",
"tensorflow.core.util",
"tensorflow.keras.callbacks",
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"tree",
"wandb",
"zoopt",
]
@ -188,8 +169,6 @@ def mock_modules():
sys.modules["ray._raylet"].ObjectRef = make_typing_mock("ray", "ObjectRef")
sys.modules["tensorflow"].VERSION = "9.9.9"
# Add doc files from external repositories to be downloaded during build here
# (repo, ref, path to get, path to save on disk)

View file

@ -126,6 +126,33 @@ Configs
.. autoclass:: ray.air.config.CheckpointConfig
.. _air-builtin-callbacks:
Callbacks
~~~~~~~~~
Comet
#####
.. autoclass:: ray.air.callbacks.comet.CometLoggerCallback
Keras
#####
.. autoclass:: ray.air.callbacks.keras.Callback
:members:
MLflow
######
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
Weights and Biases
##################
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
.. _air-session-ref:
Session

View file

@ -4,11 +4,13 @@
Ray Train API
=============
.. _train-api-trainer:
Trainer
-------
.. warning::
This Trainer API is deprecated and no longer supported. For an overview of the new :ref:`air` Trainer API,
see :ref:`air-trainer-ref`.
.. autoclass:: ray.train.Trainer
:members:

File diff suppressed because it is too large Load diff

View file

@ -8,6 +8,14 @@ External library integrations (tune.integration)
:depth: 1
Comet (tune.integration.comet)
-------------------------------------------
:ref:`See also here <tune-comet-ref>`.
.. autoclass:: ray.air.callbacks.comet.CometLoggerCallback
:noindex:
.. _tune-integration-keras:
Keras (tune.integration.keras)
@ -26,6 +34,7 @@ MLflow (tune.integration.mlflow)
:ref:`See also here <tune-mlflow-ref>`.
.. autoclass:: ray.air.callbacks.mlflow.MLflowLoggerCallback
:noindex:
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
@ -57,6 +66,7 @@ Weights and Biases (tune.integration.wandb)
:ref:`See also here <tune-wandb-ref>`.
.. autoclass:: ray.air.callbacks.wandb.WandbLoggerCallback
:noindex:
.. autofunction:: ray.tune.integration.wandb.wandb_mixin

View file

@ -114,15 +114,25 @@ class _Callback(KerasCallback):
@PublicAPI(stability="beta")
class Callback(_Callback):
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "epoch_end",
frequency: Union[int, List[int]] = 1,
):
"""
Args:
metrics: Metrics to report. If this is a list, each item describes
"""
Keras callback for Ray AIR reporting and checkpointing.
You can use this in both TuneSession and TrainSession.
Example:
.. code-block: python
############# Using it in TrainSession ###############
from ray.air.callbacks.keras import Callback
def train_loop_per_worker():
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
model = build_model()
#model.compile(...)
model.fit(dataset_shard, callbacks=[Callback()])
Args:
metrics: Metrics to report. If this is a list, each item describes
the metric key reported to Keras, and it will reported under the
same name. If this is a dict, each key will be the name reported
and the respective value will be the metric key reported to Keras.
@ -135,20 +145,14 @@ class Callback(_Callback):
this is a list, it specifies the checkpoint frequencies for each
hook individually.
You can use this in both TuneSession and TrainSession.
"""
Example:
.. code-block: python
############# Using it in TrainSession ###############
from ray.air.callbacks.keras import Callback
def train_loop_per_worker():
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
model = build_model()
#model.compile(...)
model.fit(dataset_shard, callbacks=[Callback()])
"""
def __init__(
self,
metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
on: Union[str, List[str]] = "epoch_end",
frequency: Union[int, List[int]] = 1,
):
if isinstance(frequency, list):
if not isinstance(on, list) or len(frequency) != len(on):
raise ValueError(