Training can be done with either a **Class API** (``tune.Trainable``) or a **function-based API** (``track.log``).
You can use the **function-based API** for fast prototyping. On the other hand, the ``tune.Trainable`` interface supports checkpoint/restore functionality and provides more control for advanced algorithms.
Function-based API
------------------
.. code-block:: python

    def trainable(config):
        """
        Args:
            config (dict): Parameters provided from the search algorithm
                or variant generation.
        """

        while True:
            # ...
            tune.track.log(**kwargs)
.. tip:: Do not use ``tune.track.log`` within a ``Trainable`` class.
Tune will run this function on a separate thread in a Ray actor process. Note that this API is not checkpointable, since the thread will never return control back to its caller.
.. note:: If you have a lambda function that you want to train, you will need to first register the function: ``tune.register_trainable("lambda_id", lambda x: ...)``. You can then use ``lambda_id`` in place of ``my_trainable``.
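For instance, here is a minimal sketch of a function trainable and how it might be launched with ``tune.run``. The ``alpha`` hyperparameter and the reported ``mean_accuracy`` metric are illustrative choices, not prescribed by Tune:

.. code-block:: python

    from ray import tune

    def trainable(config):
        # Report a metric derived from the sampled hyperparameter on each
        # iteration; Tune records every call to ``tune.track.log``.
        for step in range(100):
            tune.track.log(mean_accuracy=config["alpha"] * step)

    tune.run(trainable, config={"alpha": tune.grid_search([0.2, 0.4, 0.6])})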
Trainable API
-------------
.. caution:: Do not use ``tune.track.log`` within a ``Trainable`` class.
The Trainable **class API** requires users to subclass ``ray.tune.Trainable``. Here's a naive example of this API:
.. code-block:: python

    from ray import tune

    class Guesser(tune.Trainable):
        """Randomly picks 10 numbers from [1, 10000) to find the password."""
Since ``Guesser`` is a subclass of ``tune.Trainable``, Tune will create a ``Guesser`` object on a separate process (using the Ray Actor API).
1. ``_setup`` is invoked once training starts.
2. ``_train`` is invoked **multiple times**. Each time, the Guesser object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.
3. ``_stop`` is invoked when training is finished.
.. tip:: As a rule of thumb, the execution time of ``_train`` should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
In this example, we only implement the ``_setup`` and ``_train`` methods for simplicity; a sketch of what they might look like is shown below. Next, we'll implement ``_save`` and ``_restore`` for checkpointing and fault tolerance.
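For concreteness, here is one possible sketch of those two methods for the guessing example. The ``password`` config key and the guessing logic are illustrative assumptions, not part of the Tune API:

.. code-block:: python

    import random

    from ray import tune

    class Guesser(tune.Trainable):
        """Randomly picks 10 numbers from [1, 10000) to find the password."""

        def _setup(self, config):
            # Invoked once when the trial starts; ``config`` holds the
            # hyperparameters sampled for this trial.
            # "password" is a hypothetical config key used for illustration.
            self.password = config.get("password", 1024)
            self.iter = 0

        def _train(self):
            # One logical training iteration: make 10 guesses and report
            # whether any matched. Returning "done": True ends the trial.
            self.iter += 1
            guesses = [random.randint(1, 10000) for _ in range(10)]
            return {"done": self.password in guesses, "iterations": self.iter}

    tune.run(Guesser, config={"password": 1024})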
Save and Restore
~~~~~~~~~~~~~~~~
Many Tune features rely on ``_save`` and ``_restore``, including the usage of certain Trial Schedulers, fault tolerance, and checkpointing.
Checkpoints will be saved by training iteration to ``local_dir/exp_name/trial_name/checkpoint_<iter>``. You can restore a single trial checkpoint by using ``tune.run(restore=<checkpoint_dir>)``.
Tune also generates temporary checkpoints for pausing and switching between trials. For this purpose, it is important not to depend on absolute paths in the implementation of ``_save``.
Use ``validate_save_restore`` to catch ``_save``/``_restore`` errors before execution.
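Below is one possible sketch of how ``_save``, ``_restore``, and ``validate_save_restore`` fit together. The JSON checkpoint format is an illustrative choice, and the import path assumes ``validate_save_restore`` is exposed under ``ray.tune.utils``:

.. code-block:: python

    import json
    import os

    import ray
    from ray import tune
    from ray.tune.utils import validate_save_restore

    class Guesser(tune.Trainable):
        def _setup(self, config):
            # Guessing logic from the earlier sketch omitted for brevity.
            self.iter = 0

        def _train(self):
            self.iter += 1
            return {"iterations": self.iter}

        def _save(self, checkpoint_dir):
            # Write state under the directory Tune provides and return the
            # checkpoint path; avoid absolute paths so Tune can move
            # checkpoints when pausing or switching trials.
            path = os.path.join(checkpoint_dir, "checkpoint.json")
            with open(path, "w") as f:
                json.dump({"iter": self.iter}, f)
            return path

        def _restore(self, checkpoint_path):
            # Rebuild trial state from a previously saved checkpoint.
            with open(checkpoint_path) as f:
                self.iter = json.load(f)["iter"]

    ray.init()
    # Catch _save/_restore errors before launching a full experiment.
    validate_save_restore(Guesser)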
Starting a Trainable can take a long time. To avoid this overhead, you can pass ``tune.run(reuse_actors=True)`` to reuse the same Trainable Python process and object for multiple hyperparameter configurations.
This requires you to implement ``Trainable.reset_config``, which is passed a new set of hyperparameters. It is up to you to correctly update your Trainable with the new hyperparameters.
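A possible sketch of ``reset_config``, again reusing the hypothetical ``password`` hyperparameter from the earlier example:

.. code-block:: python

    from ray import tune

    class Guesser(tune.Trainable):
        def _setup(self, config):
            self.password = config.get("password", 1024)
            self.iter = 0

        def _train(self):
            self.iter += 1
            return {"iterations": self.iter}

        def reset_config(self, new_config):
            # Re-initialize trial state in place instead of tearing down the
            # actor; returning True signals that the reset succeeded.
            self.password = new_config.get("password", 1024)
            self.iter = 0
            return True

    tune.run(
        Guesser,
        reuse_actors=True,
        stop={"training_iteration": 10},
        config={"password": tune.grid_search([111, 2222, 3333, 4444])})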