mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[air] Consolidate Tune and Train report (#25558)
Consolidate Tune/Train report and checkpoint functionality behind a unified Session interface. The goal of this PR is to establish a solid Session and Session.report path. To reduce merge conflicts (other folks are doing the package renaming) and to keep the scope of this PR contained, I have intentionally left some migration out; more PRs will follow. Feel free to comment on the ideal final state.

To give an idea of the final directory structure, this is for 2-worker DP training:

```
├── TensorflowTrainer_ce44d_00000_0_2022-06-15_14-40-42
│   ├── checkpoint_000000
│   │   ├── _current_checkpoint_id.meta.pkl
│   │   ├── _preprocessor.meta.pkl
│   │   ├── _timestamp.meta.pkl
│   │   ├── assets
│   │   ├── keras_metadata.pb
│   │   ├── saved_model.pb
│   │   └── variables
│   │       ├── variables.data-00000-of-00001
│   │       └── variables.index
│   ├── events.out.tfevents.1655329242.xw
│   ├── params.json
│   ├── params.pkl
│   ├── progress.csv
│   ├── rank_0
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   ├── rank_1
│   │   └── my_model
│   │       ├── assets
│   │       ├── keras_metadata.pb
│   │       ├── saved_model.pb
│   │       └── variables
│   │           ├── variables.data-00000-of-00001
│   │           └── variables.index
│   └── result.json
├── basic-variant-state-2022-06-15_14-40-42.json
├── experiment_state-2022-06-15_14-40-42.json
├── trainable.pkl
└── tuner.pkl
```

Update:
1. Updated a few classes to be backward compatible while the legacy Ray Train deprecation is ongoing.
2. Marked all places from 1 with "# TODO(xwjiang): Legacy Ray Train trainer clean up!" so we can easily clean those up once Antoni's work has landed.
3. All CI and release tests are passing.

Co-authored-by: Eric Liang <ekhliang@gmail.com>
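As a rough sketch of the consolidated path this PR establishes (API names taken from the diff below; the loop body is illustrative, mirroring the new tests):

```python
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_loop_per_worker():
    for epoch in range(3):
        # One unified call replaces tune.report / train.report /
        # train.save_checkpoint; each call bumps the session's
        # iteration counter.
        session.report(
            {"epoch": epoch},
            checkpoint=Checkpoint.from_dict({"model": epoch}),
        )
```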
This commit is contained in:
parent
2b270fd9cb
commit
97f42425da
18 changed files with 978 additions and 252 deletions
```diff
@@ -193,6 +193,14 @@ py_test(
     deps = [":ml_lib"]
 )
 
+py_test(
+    name = "test_keras_callback",
+    size = "small",
+    srcs = ["tests/test_keras_callback.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "test_remote_storage",
     size = "small",
```
python/ray/air/_internal/session.py (new file, 87 lines)

```python
import abc
import logging
from typing import Dict, Optional

from ray.air.checkpoint import Checkpoint

logger = logging.getLogger(__name__)


class Session(abc.ABC):
    """The canonical session interface that both Tune and Train sessions implement.

    Users can interact with this interface to get session information,
    as well as to report metrics and save checkpoints.
    """

    @abc.abstractmethod
    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
        """Report metrics and optionally save a checkpoint.

        Each invocation of this method will automatically increment the underlying
        iteration number. The physical meaning of this "iteration" is defined by
        the user (or more specifically, by the way they call ``report``).
        It does not necessarily map to one epoch.

        This API is supposed to replace the legacy ``tune.report``,
        ``with tune.checkpoint_dir``, ``train.report`` and ``train.save_checkpoint``.
        Please avoid mixing them together.

        There is no requirement on the underlying representation of the
        checkpoint.

        All forms are accepted and (will be) handled by AIR in an efficient way.

        Specifically, if you are passing in a directory checkpoint, AIR will move
        the contents of the directory to an AIR-managed directory. By the time this
        method returns, one may safely write new content to the original directory
        without interfering with the AIR checkpointing flow.

        Args:
            metrics: The metrics you want to report.
            checkpoint: The optional checkpoint you want to report.
        """

        raise NotImplementedError

    @property
    @abc.abstractmethod
    def loaded_checkpoint(self) -> Optional[Checkpoint]:
        """Access the session's loaded checkpoint to resume from if applicable.

        Returns:
            Checkpoint object if the session is currently being resumed.
            Otherwise, return None.
        """

        raise NotImplementedError

    @property
    def trial_name(self) -> str:
        """Trial name for the corresponding trial."""
        raise NotImplementedError

    @property
    def trial_id(self) -> str:
        """Trial id for the corresponding trial."""
        raise NotImplementedError

    @property
    def trial_resources(self) -> Dict[str, float]:
        """Trial resources for the corresponding trial."""
        raise NotImplementedError


def _get_session() -> Optional[Session]:
    from ray.train._internal.session import _session_v2 as train_session
    from ray.tune.session import _session_v2 as tune_session

    if train_session and tune_session:
        logger.warning(
            "Expected to be either in tune session or train session but not both."
        )
        return None
    if not (train_session or tune_session):
        logger.warning("In neither tune session nor train session!")
        return None
    return train_session or tune_session
```
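For orientation, here is a minimal, purely hypothetical concrete implementation of this interface (the real implementations in this PR are `_TrainSessionImpl` and the Tune `_StatusReporter` changes shown later in the diff). `_ToySession` and its in-memory recording are assumptions for illustration only; `report` and `loaded_checkpoint` are the two abstract members a subclass must provide:

```python
from typing import Dict, Optional

from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint


class _ToySession(Session):
    """Illustrative only: records every report in memory."""

    def __init__(self):
        self._reports = []

    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
        # A real session would also bump its iteration counter here.
        self._reports.append((metrics, checkpoint))

    @property
    def loaded_checkpoint(self) -> Optional[Checkpoint]:
        # Nothing to resume from in this toy.
        return None
```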
python/ray/air/callbacks/keras.py (new file, 188 lines)

```python
from collections import Counter
from typing import Dict, List, Optional, Union

from tensorflow.keras.callbacks import Callback as KerasCallback

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.util.annotations import PublicAPI


class _Callback(KerasCallback):
    """Base class for AIR's Keras callbacks."""

    _allowed = [
        "batch_begin",
        "batch_end",
        "epoch_begin",
        "epoch_end",
        "train_batch_begin",
        "train_batch_end",
        "test_batch_begin",
        "test_batch_end",
        "predict_batch_begin",
        "predict_batch_end",
        "train_begin",
        "train_end",
        "test_begin",
        "test_end",
        "predict_begin",
        "predict_end",
    ]

    def __init__(self, on: Union[str, List[str]] = "validation_end"):
        super(_Callback, self).__init__()

        if not isinstance(on, list):
            on = [on]
        if any(w not in self._allowed for w in on):
            raise ValueError(
                "Invalid trigger time selected: {}. Must be one of {}".format(
                    on, self._allowed
                )
            )
        self._on = on

    def _handle(self, logs: Dict, when: str):
        raise NotImplementedError

    def on_batch_begin(self, batch, logs=None):
        if "batch_begin" in self._on:
            self._handle(logs, "batch_begin")

    def on_batch_end(self, batch, logs=None):
        if "batch_end" in self._on:
            self._handle(logs, "batch_end")

    def on_epoch_begin(self, epoch, logs=None):
        if "epoch_begin" in self._on:
            self._handle(logs, "epoch_begin")

    def on_epoch_end(self, epoch, logs=None):
        if "epoch_end" in self._on:
            self._handle(logs, "epoch_end")

    def on_train_batch_begin(self, batch, logs=None):
        if "train_batch_begin" in self._on:
            self._handle(logs, "train_batch_begin")

    def on_train_batch_end(self, batch, logs=None):
        if "train_batch_end" in self._on:
            self._handle(logs, "train_batch_end")

    def on_test_batch_begin(self, batch, logs=None):
        if "test_batch_begin" in self._on:
            self._handle(logs, "test_batch_begin")

    def on_test_batch_end(self, batch, logs=None):
        if "test_batch_end" in self._on:
            self._handle(logs, "test_batch_end")

    def on_predict_batch_begin(self, batch, logs=None):
        if "predict_batch_begin" in self._on:
            self._handle(logs, "predict_batch_begin")

    def on_predict_batch_end(self, batch, logs=None):
        if "predict_batch_end" in self._on:
            self._handle(logs, "predict_batch_end")

    def on_train_begin(self, logs=None):
        if "train_begin" in self._on:
            self._handle(logs, "train_begin")

    def on_train_end(self, logs=None):
        if "train_end" in self._on:
            self._handle(logs, "train_end")

    def on_test_begin(self, logs=None):
        if "test_begin" in self._on:
            self._handle(logs, "test_begin")

    def on_test_end(self, logs=None):
        if "test_end" in self._on:
            self._handle(logs, "test_end")

    def on_predict_begin(self, logs=None):
        if "predict_begin" in self._on:
            self._handle(logs, "predict_begin")

    def on_predict_end(self, logs=None):
        if "predict_end" in self._on:
            self._handle(logs, "predict_end")


@PublicAPI(stability="beta")
class Callback(_Callback):
    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        on: Union[str, List[str]] = "epoch_end",
        frequency: Union[int, List[int]] = 1,
    ):
        """
        Args:
            metrics: Metrics to report. If this is a list, each item describes
                the metric key reported to Keras, and it will be reported under
                the same name. If this is a dict, each key will be the name
                reported and the respective value will be the metric key
                reported to Keras. If this is None, all Keras logs will be
                reported.
            on: When to report metrics. Must be one of
                the Keras event hooks (less the ``on_``), e.g.
                "train_start", or "predict_end". Defaults to "epoch_end".
            frequency: Checkpoint frequency. If this is an integer `n`,
                checkpoints are saved every `n` times each hook was called. If
                this is a list, it specifies the checkpoint frequencies for each
                hook individually.

        You can use this in both TuneSession and TrainSession.

        Example:
            .. code-block:: python

                ############# Using it in TrainSession ###############
                from ray.air.callbacks.keras import Callback

                def train_loop_per_worker():
                    strategy = tf.distribute.MultiWorkerMirroredStrategy()
                    with strategy.scope():
                        model = build_model()
                        # model.compile(...)
                    model.fit(dataset_shard, callbacks=[Callback()])
        """
        if isinstance(frequency, list):
            if not isinstance(on, list) or len(frequency) != len(on):
                raise ValueError(
                    "If you pass a list for checkpoint frequencies, the `on` "
                    "parameter has to be a list with the same length."
                )

        self._frequency = frequency
        super(Callback, self).__init__(on)
        self._metrics = metrics
        self._counter = Counter()

    def _handle(self, logs: Dict, when: str = None):
        self._counter[when] += 1

        if isinstance(self._frequency, list):
            index = self._on.index(when)
            freq = self._frequency[index]
        else:
            freq = self._frequency

        checkpoint = None
        if freq > 0 and self._counter[when] % freq == 0:
            self.model.save("my_model", overwrite=True)
            checkpoint = Checkpoint.from_directory("my_model")

        if not self._metrics:
            report_dict = logs
        else:
            report_dict = {}
            for key in self._metrics:
                if isinstance(self._metrics, dict):
                    metric = self._metrics[key]
                else:
                    metric = key
                report_dict[key] = logs[metric]

        session.report(report_dict, checkpoint=checkpoint)
```
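A usage sketch for the callback above; the model and training setup are assumed, and the argument semantics follow the docstring:

```python
from ray.air.callbacks.keras import Callback

# Report the Keras "loss" log under the name "train_loss" at the end of
# every epoch, and save a checkpoint on every second epoch_end call.
callback = Callback(metrics={"train_loss": "loss"}, on="epoch_end", frequency=2)
# model.fit(dataset, callbacks=[callback])  # inside a Train/Tune session
```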
python/ray/air/session.py (new file, 258 lines)

```python
from typing import TYPE_CHECKING, Dict, Optional, Union

from ray.air._internal.session import _get_session
from ray.air.checkpoint import Checkpoint
from ray.train.session import _TrainSessionImpl

if TYPE_CHECKING:
    from ray.data import Dataset, DatasetPipeline


def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
    """Report metrics and optionally save a checkpoint.

    Each invocation of this method will automatically increment the underlying
    iteration number. The physical meaning of this "iteration" is defined by
    the user (or more specifically, by the way they call ``report``).
    It does not necessarily map to one epoch.

    This API is the canonical way to report metrics from Tune and Train, and
    replaces the legacy ``tune.report``, ``with tune.checkpoint_dir``,
    ``train.report`` and ``train.save_checkpoint`` calls.

    Note on directory checkpoints: AIR will take ownership of checkpoints passed
    to ``report()`` by moving them to a new path. The original directory will no
    longer be accessible to the caller after the report call.

    Example:
        .. code-block:: python

            from ray.air import session
            from ray.air.checkpoint import Checkpoint

            ######## Using it in the *per worker* train loop (TrainSession) #######
            def train_func():
                model = build_model()
                model.save("my_model", overwrite=True)
                session.report(
                    metrics={"foo": "bar"},
                    checkpoint=Checkpoint.from_directory("my_model")
                )
                # Air guarantees by this point, you can safely write new stuff to
                # "my_model" directory.

            scaling_config = {"num_workers": 2}
            trainer = TensorflowTrainer(
                train_loop_per_worker=train_func, scaling_config=scaling_config
            )
            result = trainer.fit()
            # If you navigate to result.checkpoint's path, you will find the
            # content of ``model.save()`` under it.
            # If you have `SyncConfig` configured, the content should also
            # show up in the corresponding cloud storage path.

    Args:
        metrics: The metrics you want to report.
        checkpoint: The optional checkpoint you want to report.
    """

    _get_session().report(metrics, checkpoint=checkpoint)


def get_checkpoint() -> Optional[Checkpoint]:
    """Access the session's last checkpoint to resume from if applicable.

    Returns:
        Checkpoint object if the session is currently being resumed.
        Otherwise, return None.

    Example:
        .. code-block:: python

            ######## Using it in the *per worker* train loop (TrainSession) ######
            from ray.air import session
            from ray.air.checkpoint import Checkpoint

            def train_func():
                if session.get_checkpoint():
                    with session.get_checkpoint().as_directory() as loaded_checkpoint_dir:
                        import tensorflow as tf

                        model = tf.keras.models.load_model(loaded_checkpoint_dir)
                else:
                    model = build_model()

                model.save("my_model", overwrite=True)
                session.report(
                    metrics={"iter": 1},
                    checkpoint=Checkpoint.from_directory("my_model")
                )

            scaling_config = {"num_workers": 2}
            trainer = TensorflowTrainer(
                train_loop_per_worker=train_func, scaling_config=scaling_config
            )
            result = trainer.fit()

            # trainer2 will pick up from the checkpoint saved by trainer1.
            trainer2 = TensorflowTrainer(
                train_loop_per_worker=train_func,
                scaling_config=scaling_config,
                # this is ultimately what is accessed through
                # ``Session.get_checkpoint()``
                resume_from_checkpoint=result.checkpoint,
            )
            result2 = trainer2.fit()
    """

    return _get_session().loaded_checkpoint


def get_trial_name() -> str:
    """Trial name for the corresponding trial."""
    return _get_session().trial_name


def get_trial_id() -> str:
    """Trial id for the corresponding trial."""
    return _get_session().trial_id


def get_trial_resources() -> Dict[str, float]:
    """Trial resources for the corresponding trial."""
    return _get_session().trial_resources


def get_world_size() -> int:
    """Get the current world size (i.e. total number of workers) for this run.

    .. code-block:: python

        import time
        from ray.air import session

        def train_loop_per_worker(config):
            assert session.get_world_size() == 4

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TensorflowTrainer(train_loop_per_worker,
            scaling_config={"num_workers": 1},
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_world_size` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_size


def get_world_rank() -> int:
    """Get the world rank of this worker.

    .. code-block:: python

        import time
        from ray.air import session

        def train_loop_per_worker():
            for iter in range(100):
                time.sleep(1)
                if session.get_world_rank() == 0:
                    print("Worker 0")

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TensorflowTrainer(train_loop_per_worker,
            scaling_config={"num_workers": 1},
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_world_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.world_rank


def get_local_rank() -> int:
    """Get the local rank of this worker (rank of the worker on its node).

    .. code-block:: python

        import time
        from ray.air import session

        def train_loop_per_worker():
            if torch.cuda.is_available():
                torch.cuda.set_device(session.get_local_rank())
            ...

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TensorflowTrainer(train_loop_per_worker,
            scaling_config={"num_workers": 1},
            datasets={"train": train_dataset})
        trainer.fit()
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_local_rank` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.local_rank


def get_dataset_shard(
    dataset_name: Optional[str] = None,
) -> Optional[Union["Dataset", "DatasetPipeline"]]:
    """Returns the Ray Dataset or DatasetPipeline shard for this worker.

    You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
    it to the appropriate framework-specific Dataset.

    .. code-block:: python

        import ray
        from ray import train
        from ray.air import session

        def train_loop_per_worker():
            model = Net()
            for iter in range(100):
                # Trainer will automatically handle sharding.
                data_shard = session.get_dataset_shard().to_torch()
                model.train(data_shard)
            return model

        train_dataset = ray.data.from_items(
            [{"x": x, "y": x + 1} for x in range(32)])
        trainer = TorchTrainer(train_loop_per_worker,
            scaling_config={"num_workers": 2},
            datasets={"train": train_dataset})
        trainer.fit()

    Args:
        dataset_name: If a Dictionary of Datasets was passed to ``Trainer``, then
            specifies which dataset shard to return.

    Returns:
        The ``Dataset`` or ``DatasetPipeline`` shard to use for this worker.
        If no dataset is passed into Trainer, then return None.
    """
    session = _get_session()
    if not isinstance(session, _TrainSessionImpl):
        raise RuntimeError(
            "`get_dataset_shard` can only be called for TrainSession! "
            "Make sure you only use that in `train_loop_per_worker` function "
            "that is passed into `DataParallelTrainer`."
        )
    return session.get_dataset_shard(dataset_name)
```
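Note the guard pattern above: every Train-specific helper dispatches through `_get_session()` and refuses to run when the active session is not a `_TrainSessionImpl`. A small illustrative sketch (the function body is assumed):

```python
from ray.air import session


def train_loop_per_worker():
    # Valid inside a DataParallelTrainer worker; outside a TrainSession
    # (e.g. in a plain Tune trial) these calls raise RuntimeError by design.
    if session.get_world_rank() == 0:
        print("world size:", session.get_world_size())
```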
python/ray/air/tests/test_keras_callback.py (new file, 64 lines)

```python
import os

import tensorflow as tf

from ray.air import session
from ray.air.callbacks.keras import Callback
from ray.air.examples.tf.tensorflow_linear_dataset_example import (
    build_model,
    get_dataset,
)
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard


def train_func(config: dict):
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    dataset = session.get_dataset_shard("train")

    for _ in range(config.get("epoch", 3)):
        tf_dataset = prepare_dataset_shard(
            dataset.to_tf(
                label_column="y",
                output_signature=(
                    tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
                    tf.TensorSpec(shape=(None), dtype=tf.float32),
                ),
                batch_size=32,
            )
        )
        multi_worker_model.fit(tf_dataset, callbacks=[Callback()])


def test_keras_callback():
    epochs = 3
    scaling_config = {"num_workers": 2}
    config = {
        "epochs": epochs,
    }
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=config,
        scaling_config=scaling_config,
        datasets={TRAIN_DATASET_KEY: get_dataset()},
    )
    checkpoint = trainer.fit().checkpoint
    with checkpoint.as_directory() as ckpt_dir:
        assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", "-x", __file__]))
```
```diff
@@ -1,23 +1,29 @@
 import logging
 import os
 from collections import defaultdict
-from typing import Callable, List, Optional, Dict, Type, Tuple, TypeVar
+from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
 
 import ray
+from ray.air.checkpoint import Checkpoint
 from ray.exceptions import RayActorError
 from ray.ray_constants import env_integer
+from ray.train._internal.dataset_spec import RayDatasetSpec
+from ray.train._internal.session import (
+    TrainingResult,
+    TrialInfo,
+    get_session,
+    init_session,
+    shutdown_session,
+)
+from ray.train._internal.utils import check_for_failure
+from ray.train._internal.worker_group import WorkerGroup
+from ray.train.backend import BackendConfig
 from ray.train.constants import (
     ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
-    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
+    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
 )
-from ray.train.backend import BackendConfig
-from ray.train._internal.dataset_spec import RayDatasetSpec
-from ray.train._internal.session import TrainingResult
-from ray.train._internal.session import init_session, get_session, shutdown_session
-from ray.train._internal.utils import check_for_failure
-from ray.train._internal.worker_group import WorkerGroup
 from ray.util.placement_group import get_current_placement_group, remove_placement_group
 
 T = TypeVar("T")
@@ -57,6 +63,8 @@ class BackendExecutor:
     def __init__(
         self,
         backend_config: BackendConfig,
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
+        trial_info: Optional[TrialInfo] = None,
         num_workers: int = 1,
         num_cpus_per_worker: float = 1,
         num_gpus_per_worker: float = 0,
@@ -76,6 +84,8 @@ class BackendExecutor:
         self._initialization_hook = None
         self._placement_group = None
 
+        self._trial_info = trial_info
+
         self.worker_group = InactiveWorkerGroup()
         self.dataset_shards = None
 
@@ -264,7 +274,7 @@ class BackendExecutor:
         self,
         train_func: Callable[[], T],
         dataset_spec: RayDatasetSpec,
-        checkpoint: Optional[Dict] = None,
+        checkpoint: Optional[Checkpoint] = None,
     ) -> None:
         """Executes a training function on all workers in a separate thread.
 
@@ -290,6 +300,7 @@ class BackendExecutor:
             world_rank,
             local_rank,
             world_size,
+            trial_info,
             checkpoint,
             dataset_shard,
             encode_data_fn,
@@ -300,6 +311,7 @@ class BackendExecutor:
                 world_rank=world_rank,
                 local_rank=local_rank,
                 world_size=world_size,
+                trial_info=trial_info,
                 dataset_shard=dataset_shard,
                 checkpoint=checkpoint,
                 encode_data_fn=encode_data_fn,
@@ -328,6 +340,7 @@ class BackendExecutor:
             world_rank=index,
             local_rank=local_rank_map[index],
             world_size=len(self.worker_group),
+            trial_info=self._trial_info,
             train_func=train_func,
            dataset_shard=self.dataset_shards[index],
             checkpoint=checkpoint,
```
```diff
@@ -1,22 +1,21 @@
 import logging
 from pathlib import Path
-from typing import List, Optional, Dict, Union, Callable
+from typing import Callable, Dict, List, Optional, Union
 
+from ray.air import Checkpoint
+from ray.train._internal.session import TrainingResult
+from ray.train._internal.utils import construct_path
 from ray.train.constants import (
     TIMESTAMP,
     TRAIN_CHECKPOINT_SUBDIR,
     TUNE_CHECKPOINT_ID,
     TUNE_INSTALLED,
 )
-from ray.train._internal.session import TrainingResult
-from ray.train._internal.utils import construct_path
+from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, CheckpointStorage
 from ray.util.ml_utils.checkpoint_manager import (
     _CheckpointManager as CommonCheckpointManager,
-    _TrackedCheckpoint,
-    CheckpointStrategy,
-    CheckpointStorage,
 )
+from ray.util.ml_utils.checkpoint_manager import _TrackedCheckpoint
 
 if TUNE_INSTALLED:
     from ray import tune
@@ -80,14 +79,17 @@ class CheckpointManager(CommonCheckpointManager):
         if self._checkpoint_strategy.checkpoint_score_attribute is None:
             self._checkpoint_strategy.checkpoint_score_attribute = TIMESTAMP
 
+    # TODO(xwjiang): Legacy Ray Train trainer clean up!
     def _load_checkpoint(
-        self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
-    ) -> Optional[Dict]:
+        self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
+    ) -> Optional[Union[Dict, Checkpoint]]:
         """Load the checkpoint dictionary from the input dict or path."""
         if checkpoint_to_load is None:
             return None
         if isinstance(checkpoint_to_load, Dict):
             return checkpoint_to_load
+        if isinstance(checkpoint_to_load, Checkpoint):
+            return checkpoint_to_load
         else:
             # Load checkpoint from path.
             return load_checkpoint_from_path(checkpoint_to_load)
@@ -198,9 +200,17 @@ class CheckpointManager(CommonCheckpointManager):
 
 class TuneCheckpointManager(CheckpointManager):
     def _load_checkpoint(
-        self, checkpoint_to_load: Optional[Union[Dict, str, Path]]
-    ) -> Optional[Dict]:
+        self, checkpoint_to_load: Optional[Union[Dict, str, Path, Checkpoint]]
+    ) -> Optional[Union[Dict, Checkpoint]]:
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
         loaded_checkpoint = super()._load_checkpoint(checkpoint_to_load)
+        # New path...
+        if isinstance(loaded_checkpoint, Checkpoint):
+            # The new logic
+            checkpoint_dict = loaded_checkpoint.to_dict()
+            self._latest_checkpoint_id = checkpoint_dict[TUNE_CHECKPOINT_ID]
+            return loaded_checkpoint
+        # legacy path...
         if loaded_checkpoint is not None:
             # If the Tune trial is restarted, a new Trainer is instantiated.
             # However, we want the checkpoint_id to continue incrementing
```
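The `_load_checkpoint` override above is the compatibility shim: dicts and paths follow the legacy path, while a `Checkpoint` carries its own `TUNE_CHECKPOINT_ID` entry so the id counter survives trial restarts. Illustratively (the payload values are hypothetical):

```python
from ray.air.checkpoint import Checkpoint
from ray.train.constants import TUNE_CHECKPOINT_ID

# A "new path" checkpoint: the Tune checkpoint id rides along in the dict
# payload, so a restarted trial restores _latest_checkpoint_id from it.
ckpt = Checkpoint.from_dict({"model": 42, TUNE_CHECKPOINT_ID: 3})
```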
```diff
@@ -3,29 +3,30 @@ import platform
 import queue
 import threading
 import time
-from datetime import datetime
+from dataclasses import dataclass
+from datetime import datetime
 from enum import Enum, auto
-from typing import Callable
-from typing import Optional, Dict, Type, Union
+from typing import Callable, Dict, Optional, Type, Union
 
 import ray
+from ray.air.checkpoint import Checkpoint
+from ray.data import Dataset, DatasetPipeline
 from ray.train._internal.accelerator import Accelerator
-from ray.train.constants import (
-    DETAILED_AUTOFILLED_KEYS,
-    TIME_THIS_ITER_S,
-    PID,
-    TIMESTAMP,
-    TIME_TOTAL_S,
-    NODE_IP,
-    TRAINING_ITERATION,
-    HOSTNAME,
-    DATE,
-    RESULT_FETCH_TIMEOUT,
-)
 from ray.train._internal.utils import PropagatingThread
+from ray.train.constants import (
+    DATE,
+    DETAILED_AUTOFILLED_KEYS,
+    HOSTNAME,
+    NODE_IP,
+    PID,
+    RESULT_FETCH_TIMEOUT,
+    TIME_THIS_ITER_S,
+    TIME_TOTAL_S,
+    TIMESTAMP,
+    TRAINING_ITERATION,
+)
 from ray.train.error import SessionMisuseError
+from ray.train.session import _TrainSessionImpl
 
 
 class TrainingResultType(Enum):
@@ -33,13 +34,24 @@ class TrainingResultType(Enum):
     CHECKPOINT = auto()
 
 
+@dataclass
+class TrialInfo:
+    """The trial information to propagate to TrainSession."""
+
+    name: str
+    id: str
+    resources: Dict[str, float]
+    logdir: str
+
+
 @dataclass
 class TrainingResult:
     type: TrainingResultType
     data: Dict
 
 
-class Session:
+# TODO(xwjiang): This needs a better name.
+class _TrainSession:
     """Holds information for training on each worker."""
 
     def __init__(
@@ -48,8 +60,11 @@ class Session:
         world_rank: int,
         local_rank: int,
         world_size: int,
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
+        trial_info: Optional[TrialInfo] = None,
         dataset_shard: Optional[Union[Dataset, DatasetPipeline]] = None,
-        checkpoint: Optional[Dict] = None,
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
+        checkpoint: Optional[Union[Dict, Checkpoint]] = None,
         encode_data_fn: Callable = None,
         detailed_autofilled_metrics: bool = False,
     ):
@@ -61,7 +76,9 @@ class Session:
         self.world_rank = world_rank
         self.local_rank = local_rank
         self.world_size = world_size
-        self.loaded_checkpoint = checkpoint
+        self.trial_info = trial_info
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
+        self.loaded_checkpoint: Optional[Union[Dict, Checkpoint]] = checkpoint
 
         # Function to encode checkpoint dict before sending to the driver.
         if not encode_data_fn:
@@ -72,6 +89,13 @@ class Session:
             encode_data_fn = noop
         self._encode_data_fn = encode_data_fn
 
+        # TODO(xwjiang): Legacy Ray Train trainer clean up!
+        if trial_info:
+            # Change the working directory to `logdir`.
+            logdir = os.path.join(trial_info.logdir, f"rank_{self.world_rank}")
+            os.makedirs(logdir, exist_ok=True)
+            os.chdir(logdir)
+
         # This lock is used to control the execution of the training thread.
         self.continue_lock = threading.Semaphore(0)
 
@@ -184,7 +208,7 @@ class Session:
         result.update(auto_filled_metrics)
         return result
 
-    def report(self, **kwargs):
+    def _report_legacy(self, **kwargs):
         """Adds kwargs to the queue to be consumed by main thread."""
         if self.ignore_report:
             return
@@ -234,22 +258,32 @@ class Session:
         # checkpoint has been processed.
         self.continue_lock.acquire()
 
+    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
+        # TODO(xwjiang): tons of optimizations.
+        if checkpoint:
+            checkpoint_dict = checkpoint.to_dict()
+            self.checkpoint(**checkpoint_dict)
+        self._report_legacy(**metrics)
+
 
-_session = None
+_session: Optional[_TrainSession] = None
+# V2 Session API
+_session_v2: Optional[_TrainSessionImpl] = None
 
 
 def init_session(*args, **kwargs) -> None:
     global _session
+    global _session_v2
     if _session:
         raise ValueError(
             "A Train session is already in use. Do not call "
             "`init_session()` manually."
         )
-    _session = Session(*args, **kwargs)
+    _session = _TrainSession(*args, **kwargs)
+    _session_v2 = _TrainSessionImpl(session=_session)
 
 
-def get_session() -> Optional[Session]:
-    global _session
+def get_session() -> Optional[_TrainSession]:
     return _session
```
```diff
@@ -1,33 +1,22 @@
 import inspect
 import logging
 import os
 from pathlib import Path
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Optional,
-    Tuple,
-    Union,
-    Type,
-    TYPE_CHECKING,
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
 
 import ray
 from ray import tune
-from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
-from ray.train.constants import (
-    TRAIN_DATASET_KEY,
-    WILDCARD_KEY,
-)
-from ray.train.trainer import BaseTrainer
-from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
-from ray.train.trainer import GenDataset
+from ray.air import session
+from ray.air.checkpoint import Checkpoint
-from ray.train._internal.dataset_spec import DataParallelIngestSpec
+from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
+from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY
+from ray.train import BackendConfig, TrainingIterator
-from ray.train._internal.backend_executor import BackendExecutor
+from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
+from ray.train._internal.checkpoint import TuneCheckpointManager
+from ray.train._internal.dataset_spec import DataParallelIngestSpec
+from ray.train._internal.utils import construct_train_func
+from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
+from ray.train.trainer import BaseTrainer, GenDataset
 from ray.util.annotations import DeveloperAPI
 from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy, _TrackedCheckpoint
@@ -335,8 +324,16 @@ class DataParallelTrainer(BaseTrainer):
             scaling_config_dataclass.additional_resources_per_worker
         )
 
+        trial_info = TrialInfo(
+            name=session.get_trial_name(),
+            id=session.get_trial_id(),
+            resources=session.get_trial_resources(),
+            logdir=os.getcwd(),
+        )
+
         backend_executor = BackendExecutor(
             backend_config=self._backend_config,
+            trial_info=trial_info,
             num_workers=scaling_config_dataclass.num_workers,
             num_cpus_per_worker=scaling_config_dataclass.num_cpus_per_worker,
             num_gpus_per_worker=scaling_config_dataclass.num_gpus_per_worker,
@@ -351,20 +348,13 @@ class DataParallelTrainer(BaseTrainer):
         # Start the remote actors.
         backend_executor.start(initialization_hook=None)
 
-        if self.resume_from_checkpoint:
-            resume_checkpoint_dict = self.resume_from_checkpoint.to_dict()
-        else:
-            resume_checkpoint_dict = None
-
-        # TODO(amog): Have TrainingIterator also accept a checkpoint ObjectRef instead
-        # of just a Dict.
         training_iterator = TrainingIterator(
             backend_executor=backend_executor,
             backend_config=self._backend_config,
             train_func=train_loop_per_worker,
             dataset_spec=self._ingest_spec,
             checkpoint_manager=checkpoint_manager,
-            checkpoint=resume_checkpoint_dict,
+            checkpoint=self.resume_from_checkpoint,
             checkpoint_strategy=None,
         )
```
|
@ -16,6 +16,7 @@ from torch.utils.data import Dataset as TorchDataset
|
|||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
|
||||
from ray import train
|
||||
from ray.air import session
|
||||
from ray.air._internal.checkpointing import (
|
||||
load_preprocessor_from_dir,
|
||||
save_preprocessor_to_dir,
|
||||
|
@ -517,12 +518,14 @@ def _huggingface_train_loop_per_worker(config):
|
|||
|
||||
trainer.add_callback(TrainReportCallback)
|
||||
|
||||
checkpoint = train.load_checkpoint()
|
||||
checkpoint = session.get_checkpoint()
|
||||
checkpoint_path = None
|
||||
remove_checkpoint_path = False
|
||||
if checkpoint:
|
||||
source_ip = checkpoint[NODE_IP_KEY]
|
||||
source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY]
|
||||
assert isinstance(checkpoint, Checkpoint)
|
||||
checkpoint_dict = checkpoint.to_dict()
|
||||
source_ip = checkpoint_dict[NODE_IP_KEY]
|
||||
source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
|
||||
target_ip = get_node_ip_address()
|
||||
if source_ip == target_ip:
|
||||
checkpoint_path = source_path
|
||||
|
|
python/ray/train/session.py (new file, 79 lines)

```python
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Union

from ray.air._internal.session import Session
from ray.air.checkpoint import Checkpoint

if TYPE_CHECKING:
    # avoid circular import
    from ray.data import Dataset, DatasetPipeline
    from ray.train._internal.session import _TrainSession


class _TrainSessionImpl(Session):
    """Session client that "per worker train loop" can interact with.

    Notice that each worker will automatically switch to its working
    directory on entering the train loop. This is to ensure that
    each worker can safely write to a local directory without racing
    and overwriting each other."""

    def __init__(self, session: "_TrainSession"):
        self._session = session

    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
        self._session.report(metrics, checkpoint)

    @property
    def loaded_checkpoint(self) -> Optional[Checkpoint]:
        ckpt = self._session.loaded_checkpoint
        if ckpt:
            # The new API should only interact with Checkpoint object.
            assert isinstance(ckpt, Checkpoint)
        return ckpt

    @property
    def trial_name(self) -> str:
        return self._session.trial_info.name

    @property
    def trial_id(self) -> str:
        return self._session.trial_info.id

    @property
    def trial_resources(self) -> Dict[str, float]:
        return self._session.trial_info.resources

    @property
    def world_size(self) -> int:
        return self._session.world_size

    @property
    def world_rank(self) -> int:
        return self._session.world_rank

    @property
    def local_rank(self) -> int:
        return self._session.local_rank

    def get_dataset_shard(
        self,
        dataset_name: Optional[str] = None,
    ) -> Optional[Union["Dataset", "DatasetPipeline"]]:
        shard = self._session.dataset_shard
        if shard is None:
            warnings.warn(
                "No dataset passed in. Returning None. Make sure to "
                "pass in a Ray Dataset to Trainer.run to use this "
                "function."
            )
        elif isinstance(shard, dict):
            if not dataset_name:
                raise RuntimeError(
                    "Multiple datasets were passed into ``Trainer``, "
                    "but no ``dataset_name`` is passed into "
                    "``get_dataset_shard``. Please specify which "
                    "dataset shard to retrieve."
                )
            return shard.get(dataset_name)
        return shard
```
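Given the dict dispatch above, a worker that receives multiple datasets must name the shard it wants; omitting `dataset_name` raises the `RuntimeError` shown. A sketch (dataset setup assumed, mirroring `test_datasets` later in this diff):

```python
from ray.air import session


def train_loop_per_worker():
    # With multiple datasets passed to the Trainer, name each shard.
    train_shard = session.get_dataset_shard("train")
    val_shard = session.get_dataset_shard("val")
```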
```diff
@@ -1,11 +1,11 @@
 import pytest
 
 import ray
-from ray import train, tune
+from ray import tune
+from ray.air import session
+from ray.air.checkpoint import Checkpoint
-from ray.train.constants import PREPROCESSOR_KEY
-
 from ray.data.preprocessor import Preprocessor
+from ray.train.constants import PREPROCESSOR_KEY
 from ray.train.data_parallel_trainer import DataParallelTrainer
 from ray.tune.tune_config import TuneConfig
 from ray.tune.tuner import Tuner
@@ -24,7 +24,7 @@ scale_config = {"num_workers": 2}
 
 def test_fit_train(ray_start_4_cpus):
     def train_func():
-        train.report(loss=1)
+        session.report({"loss": 1})
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func, scaling_config=scale_config
@@ -35,7 +35,7 @@ def test_fit_train(ray_start_4_cpus):
 def test_scaling_config(ray_start_4_cpus):
     def train_func():
         assert ray.available_resources()["CPU"] == 1
-        train.report(loss=1)
+        session.report({"loss": 1})
 
     assert ray.available_resources()["CPU"] == 4
     trainer = DataParallelTrainer(
@@ -46,7 +46,7 @@ def test_scaling_config(ray_start_4_cpus):
 
 def test_fit_train_config(ray_start_4_cpus):
     def train_func(config):
-        train.report(loss=config["x"])
+        session.report({"loss": config["x"]})
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func,
@@ -65,10 +65,10 @@ def test_datasets(ray_start_4_cpus):
 
     def get_dataset():
         # Train dataset should be sharded.
-        train_dataset = train.get_dataset_shard("train")
+        train_dataset = session.get_dataset_shard("train")
         assert train_dataset.count() == num_train_data / scale_config["num_workers"]
         # All other datasets should not be sharded.
-        val_dataset = train.get_dataset_shard("val")
+        val_dataset = session.get_dataset_shard("val")
         assert val_dataset.count() == num_val_data
 
     trainer = DataParallelTrainer(
@@ -82,7 +82,7 @@ def test_datasets(ray_start_4_cpus):
 def test_checkpoint(ray_start_4_cpus):
     def train_func():
         for i in range(3):
-            train.save_checkpoint(model=i)
+            session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func, scaling_config=scale_config
@@ -99,7 +99,7 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
 
     def train_func():
         for i in range(3):
-            train.save_checkpoint(model=i)
+            session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"model": i}))
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func,
@@ -113,13 +113,13 @@ def test_preprocessor_in_checkpoint(ray_start_4_cpus):
 
 def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
     def train_func():
-        checkpoint = train.load_checkpoint()
+        checkpoint = session.get_checkpoint()
         if checkpoint:
-            epoch = checkpoint["epoch"]
+            epoch = checkpoint.to_dict()["epoch"]
         else:
             epoch = 0
         for i in range(epoch, epoch + 2):
-            train.save_checkpoint(epoch=i)
+            session.report({"epoch": i}, checkpoint=Checkpoint.from_dict({"epoch": i}))
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func, scaling_config=scale_config
@@ -152,7 +152,7 @@ def test_invalid_train_loop(ray_start_4_cpus):
 
 def test_tune(ray_start_4_cpus):
     def train_func(config):
-        train.report(loss=config["x"])
+        session.report({"loss": config["x"]})
 
     trainer = DataParallelTrainer(
         train_loop_per_worker=train_func,
@@ -173,7 +173,8 @@ def test_tune(ray_start_4_cpus):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", "-x", __file__]))
```
```diff
@@ -1,8 +1,12 @@
 import os
 
 import numpy as np
 import pytest
 
 import ray
-from ray import train
+from ray.air import session
+from ray.air.checkpoint import Checkpoint
 from ray.air.examples.tf.tensorflow_linear_dataset_example import get_dataset
 from ray.air.examples.tf.tensorflow_linear_dataset_example import (
     train_func as tensorflow_linear_train_func,
@@ -34,6 +38,8 @@ def build_model():
 
 @pytest.mark.parametrize("num_workers", [1, 2])
 def test_tensorflow_linear(ray_start_4_cpus, num_workers):
+    """Also tests air Keras callback."""
+
     def train_func(config):
         result = tensorflow_linear_train_func(config)
         assert len(result) == epochs
@@ -83,6 +89,39 @@ def test_tensorflow_e2e(ray_start_4_cpus):
     assert predictions.count() == 3
 
 
+def test_report_and_load_using_ml_session(ray_start_4_cpus):
+    def train_func():
+        if session.get_checkpoint():
+            with session.get_checkpoint().as_directory() as checkpoint_dir:
+                import tensorflow as tf
+
+                model = tf.keras.models.load_model(checkpoint_dir)
+        else:
+            model = build_model()
+
+        model.save("my_model", overwrite=True)
+        session.report(
+            metrics={"iter": 1}, checkpoint=Checkpoint.from_directory("my_model")
+        )
+
+    scaling_config = {"num_workers": 2}
+    trainer = TensorflowTrainer(
+        train_loop_per_worker=train_func, scaling_config=scaling_config
+    )
+    result = trainer.fit()
+
+    trainer2 = TensorflowTrainer(
+        train_loop_per_worker=train_func,
+        scaling_config=scaling_config,
+        resume_from_checkpoint=result.checkpoint,
+    )
+    result = trainer2.fit()
+    checkpoint = result.checkpoint
+    with checkpoint.as_directory() as ckpt_dir:
+        assert os.path.exists(os.path.join(ckpt_dir, "saved_model.pb"))
+    assert result.metrics["iter"] == 1
+
+
 if __name__ == "__main__":
     import sys
```
```diff
@@ -1,11 +1,8 @@
-from typing import TYPE_CHECKING
-from typing import Optional, Dict, Union
 import warnings
+from typing import TYPE_CHECKING, Dict, Optional, Union
 
-from ray.train._internal.session import get_session
-from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY
+from ray.train._internal.session import (
+    get_session,
+)
 from ray.util import PublicAPI, log_once
 
 if TYPE_CHECKING:
@@ -118,7 +115,7 @@ def report(**kwargs) -> None:
     if session is None:
         _warn_session_misuse(report.__name__)
         return
-    session.report(**kwargs)
+    session._report_legacy(**kwargs)
 
 
 @PublicAPI(stability="beta")
```
```diff
@@ -1,23 +1,14 @@
 import copy
-from datetime import datetime
 import logging
 import os
-from pathlib import Path
-from typing import Union, Callable, List, TypeVar, Optional, Any, Dict, Type
 import warnings
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
 
 import ray
 from ray.actor import ActorHandle
-from ray.train.backend import (
-    BackendConfig,
-)
-from ray.train.callbacks.callback import TrainingCallback
-from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
-from ray.train._internal.session import TrainingResultType
-from ray.train._internal.utils import (
-    construct_train_func,
-    ActorWrapper,
-)
+from ray.air.checkpoint import Checkpoint
 from ray.train._internal.backend_executor import (
     BackendExecutor,
     InactiveWorkerGroupError,
@@ -25,35 +16,37 @@ from ray.train._internal.backend_executor import (
     TrainingWorkerError,
 )
 from ray.train._internal.checkpoint import (
-    TuneCheckpointManager,
     CheckpointManager,
+    TuneCheckpointManager,
     load_checkpoint_from_path,
 )
-from ray.train.constants import (
-    TUNE_INSTALLED,
-    DEFAULT_RESULTS_DIR,
-    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
-    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
-    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
-    TRAIN_ENABLE_WORKER_SPREAD_ENV,
-)
+from ray.train._internal.dataset_spec import RayDataset, RayDatasetSpec
+from ray.train._internal.session import TrainingResultType
 
 # Ray Train should be usable even if Tune is not installed.
-from ray.train._internal.utils import construct_path
+from ray.train._internal.utils import ActorWrapper, construct_path, construct_train_func
 from ray.train._internal.worker_group import WorkerGroup
-from ray.util.annotations import DeveloperAPI, Deprecated
-from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
-
+from ray.train.backend import BackendConfig
 from ray.train.base_trainer import (  # noqa: F401
     BaseTrainer,
     GenDataset,
     TrainingFailedError,
 )
+from ray.train.callbacks.callback import TrainingCallback
+from ray.train.constants import (
+    DEFAULT_RESULTS_DIR,
+    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
+    ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
+    TRAIN_ENABLE_WORKER_SPREAD_ENV,
+    TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
+    TUNE_INSTALLED,
+)
+from ray.util.annotations import Deprecated, DeveloperAPI
+from ray.util.ml_utils.checkpoint_manager import CheckpointStrategy
 
 if TUNE_INSTALLED:
     from ray import tune
-    from ray.tune import Trainable
-    from ray.tune import PlacementGroupFactory
+    from ray.tune import PlacementGroupFactory, Trainable
     from ray.tune.function_runner import wrap_function
 else:
     tune = PlacementGroupFactory = Trainable = object
@@ -676,7 +669,7 @@ class TrainingIterator:
         train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]],
         dataset_spec: RayDatasetSpec,
         checkpoint_manager: CheckpointManager,
-        checkpoint: Optional[Union[Dict, str, Path]],
+        checkpoint: Optional[Union[Dict, str, Path, Checkpoint]],
         checkpoint_strategy: Optional[CheckpointStrategy],
         run_dir: Optional[Path] = None,
     ):
@@ -715,12 +708,12 @@ class TrainingIterator:
             run_dir=run_dir,
             latest_checkpoint_id=latest_checkpoint_id,
         )
-        checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint)
+        checkpoint = self._checkpoint_manager._load_checkpoint(checkpoint)
         self._run_with_error_handling(
             lambda: self._backend_executor.start_training(
                 train_func=train_func,
                 dataset_spec=dataset_spec,
-                checkpoint=checkpoint_dict,
+                checkpoint=checkpoint,
             )
         )
```
```diff
@@ -1,34 +1,34 @@
+import inspect
 import logging
 import os
-import sys
-import time
-import inspect
 import shutil
+import sys
 import threading
+import time
 import uuid
 from functools import partial
 from numbers import Number
+from typing import Any, Callable, Dict, Optional
 
-from typing import Any, Callable, Optional
-
-from ray.util.annotations import DeveloperAPI
 from six.moves import queue
 
-from ray.util.debug import log_once
+from ray.air.checkpoint import Checkpoint
 from ray.tune import TuneError, session
-from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.result import (
     DEFAULT_METRIC,
-    TIME_THIS_ITER_S,
     RESULT_DUPLICATE,
     SHOULD_CHECKPOINT,
+    TIME_THIS_ITER_S,
 )
+from ray.tune.trainable import Trainable, TrainableUtil
 from ray.tune.utils import (
     detect_checkpoint_function,
     detect_config_single,
     detect_reporter,
 )
 from ray.tune.utils.trainable import with_parameters  # noqa: F401
+from ray.util.annotations import DeveloperAPI
+from ray.util.debug import log_once
 
 logger = logging.getLogger(__name__)
@@ -123,8 +123,6 @@ class FuncCheckpointUtil:
 
 
 class _StatusReporter:
-    """Object passed into your function that you can report status through."""
-
     def __init__(
         self,
         result_queue,
@@ -145,6 +143,10 @@ class _StatusReporter:
         self._last_checkpoint = None
         self._fresh_checkpoint = False
         self._trial_resources = trial_resources
+        # Also used as a marker of whether new `report()` API is being used,
+        # in which case, `_iter` will be incremented from 0 every time `report`
+        # is called.
+        self._iter = None
 
     def reset(self, trial_name=None, trial_id=None, logdir=None, trial_resources=None):
         self._trial_name = trial_name
@@ -153,6 +155,7 @@ class _StatusReporter:
         self._last_checkpoint = None
         self._fresh_checkpoint = False
         self._trial_resources = trial_resources
+        self._iter = None
 
     def __call__(self, _metric=None, **kwargs):
         """Report updated training status.
@@ -168,7 +171,7 @@ class _StatusReporter:
         """
 
         assert self._last_report_time is not None, (
-            "StatusReporter._start() must be called before the first "
+            "_StatusReporter._start() must be called before the first "
             "report __call__ is made to ensure correct runtime metrics."
         )
 
@@ -229,6 +232,26 @@ class _StatusReporter:
     def _start(self):
         self._last_report_time = time.time()
 
+    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
+        # TODO(xwjiang): Tons of optimizations.
+        if not self._iter:
+            self._iter = 0
+        if checkpoint:
+            checkpoint_dir = self.make_checkpoint_dir(step=self._iter)
+            self.set_checkpoint(checkpoint_dir)
+            checkpoint.to_directory(checkpoint_dir)
+            # TODO(krfricke): Remove this once support is added in Checkpoint.
+            open(os.path.join(checkpoint_dir, ".is_checkpoint"), "a").close()
+        self.__call__(**metrics)
+        self._iter += 1
+
+    @property
+    def loaded_checkpoint(self) -> Optional[Checkpoint]:
+        if self._last_checkpoint:
+            assert isinstance(self._last_checkpoint, str)
+            return Checkpoint.from_directory(self._last_checkpoint)
+        return None
+
     @property
     def logdir(self):
         return self._logdir
@@ -484,7 +507,7 @@ class FunctionRunner(Trainable):
         obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
         return obj
 
-    def load_checkpoint(self, checkpoint):
+    def load_checkpoint(self, checkpoint: str):
         # This should be removed once Trainables are refactored.
         if "tune_checkpoint_path" in checkpoint:
             del checkpoint["tune_checkpoint_path"]
```
python/ray/tune/integration/keras.py
@@ -1,113 +1,9 @@
+import os
 from collections import Counter
-from typing import Dict, List, Union, Optional
+from typing import Dict, List, Optional, Union

-from tensorflow.keras.callbacks import Callback
 from ray import tune
-
-
-class TuneCallback(Callback):
-    """Base class for Tune's Keras callbacks."""
-
-    _allowed = [
-        "batch_begin",
-        "batch_end",
-        "epoch_begin",
-        "epoch_end",
-        "train_batch_begin",
-        "train_batch_end",
-        "test_batch_begin",
-        "test_batch_end",
-        "predict_batch_begin",
-        "predict_batch_end",
-        "train_begin",
-        "train_end",
-        "test_begin",
-        "test_end",
-        "predict_begin",
-        "predict_end",
-    ]
-
-    def __init__(self, on: Union[str, List[str]] = "validation_end"):
-        super(TuneCallback, self).__init__()
-
-        if not isinstance(on, list):
-            on = [on]
-        if any(w not in self._allowed for w in on):
-            raise ValueError(
-                "Invalid trigger time selected: {}. Must be one of {}".format(
-                    on, self._allowed
-                )
-            )
-        self._on = on
-
-    def _handle(self, logs: Dict, when: str):
-        raise NotImplementedError
-
-    def on_batch_begin(self, batch, logs=None):
-        if "batch_begin" in self._on:
-            self._handle(logs, "batch_begin")
-
-    def on_batch_end(self, batch, logs=None):
-        if "batch_end" in self._on:
-            self._handle(logs, "batch_end")
-
-    def on_epoch_begin(self, epoch, logs=None):
-        if "epoch_begin" in self._on:
-            self._handle(logs, "epoch_begin")
-
-    def on_epoch_end(self, epoch, logs=None):
-        if "epoch_end" in self._on:
-            self._handle(logs, "epoch_end")
-
-    def on_train_batch_begin(self, batch, logs=None):
-        if "train_batch_begin" in self._on:
-            self._handle(logs, "train_batch_begin")
-
-    def on_train_batch_end(self, batch, logs=None):
-        if "train_batch_end" in self._on:
-            self._handle(logs, "train_batch_end")
-
-    def on_test_batch_begin(self, batch, logs=None):
-        if "test_batch_begin" in self._on:
-            self._handle(logs, "test_batch_begin")
-
-    def on_test_batch_end(self, batch, logs=None):
-        if "test_batch_end" in self._on:
-            self._handle(logs, "test_batch_end")
-
-    def on_predict_batch_begin(self, batch, logs=None):
-        if "predict_batch_begin" in self._on:
-            self._handle(logs, "predict_batch_begin")
-
-    def on_predict_batch_end(self, batch, logs=None):
-        if "predict_batch_end" in self._on:
-            self._handle(logs, "predict_batch_end")
-
-    def on_train_begin(self, logs=None):
-        if "train_begin" in self._on:
-            self._handle(logs, "train_begin")
-
-    def on_train_end(self, logs=None):
-        if "train_end" in self._on:
-            self._handle(logs, "train_end")
-
-    def on_test_begin(self, logs=None):
-        if "test_begin" in self._on:
-            self._handle(logs, "test_begin")
-
-    def on_test_end(self, logs=None):
-        if "test_end" in self._on:
-            self._handle(logs, "test_end")
-
-    def on_predict_begin(self, logs=None):
-        if "predict_begin" in self._on:
-            self._handle(logs, "predict_begin")
-
-    def on_predict_end(self, logs=None):
-        if "predict_end" in self._on:
-            self._handle(logs, "predict_end")
+from ray.air.callbacks.keras import _Callback as TuneCallback


 class TuneReportCallback(TuneCallback):
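With the base class now re-exported from `ray.air.callbacks.keras`, existing subclasses such as `TuneReportCallback` keep their call signature. A minimal usage sketch follows; the synthetic data, the one-layer model, and the `{"loss": "loss"}` metric mapping are illustrative assumptions, not part of the diff.

```python
import numpy as np
from tensorflow import keras

from ray import tune
from ray.tune.integration.keras import TuneReportCallback


def train_fn(config):
    # Tiny synthetic regression task; any compiled tf.keras model works here.
    x = np.random.rand(64, 4)
    y = np.random.rand(64, 1)
    model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=keras.optimizers.SGD(config["lr"]), loss="mse")
    model.fit(
        x,
        y,
        epochs=3,
        verbose=0,
        # Forward the Keras "loss" log entry to Tune at the end of each epoch.
        callbacks=[TuneReportCallback({"loss": "loss"}, on="epoch_end")],
    )


tune.run(train_fn, config={"lr": tune.grid_search([0.01, 0.1])})
```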
python/ray/tune/session.py
@@ -1,23 +1,54 @@
-from contextlib import contextmanager
 import inspect
-import os
 import logging
+import os
 import traceback
+from contextlib import contextmanager
 from typing import Dict, Optional, Set

 import ray
+from ray.air._internal.session import Session
+from ray.air.checkpoint import Checkpoint
+from ray.tune.error import TuneError
+from ray.tune.function_runner import _StatusReporter
+from ray.util.annotations import DeveloperAPI, PublicAPI
 from ray.util.debug import log_once
-from ray.util.annotations import PublicAPI, DeveloperAPI
 from ray.util.placement_group import _valid_resource_shape
 from ray.util.scheduling_strategies import (
-    SchedulingStrategyT,
     PlacementGroupSchedulingStrategy,
+    SchedulingStrategyT,
 )
-from ray.tune.error import TuneError

 logger = logging.getLogger(__name__)

-_session = None
+_session: Optional[_StatusReporter] = None
+# V2 Session API.
+_session_v2: Optional["_TuneSessionImpl"] = None
+
+
+class _TuneSessionImpl(Session):
+    """Session client that a function trainable can interact with."""
+
+    def __init__(self, status_reporter: _StatusReporter):
+        self._status_reporter = status_reporter
+
+    def report(self, metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
+        self._status_reporter.report(metrics, checkpoint=checkpoint)
+
+    @property
+    def loaded_checkpoint(self) -> Optional[Checkpoint]:
+        return self._status_reporter.loaded_checkpoint
+
+    @property
+    def trial_name(self) -> str:
+        return self._status_reporter.trial_name
+
+    @property
+    def trial_id(self) -> str:
+        return self._status_reporter.trial_id
+
+    @property
+    def trial_resources(self) -> Dict[str, float]:
+        return self._status_reporter.trial_resources.required_resources


 @PublicAPI
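To illustrate the intent of `_TuneSessionImpl`, here is a sketch of a function trainable written against the V2 interface. Reaching through the module-level `_session_v2` is an assumption made for illustration only; a public accessor analogous to `get_session()` is the expected end state.

```python
from ray.air.checkpoint import Checkpoint
from ray.tune import session


def train_fn(config):
    sess = session._session_v2  # illustrative direct access to the V2 handle

    # On restore, pick up where the last report() left off.
    start = 0
    if sess.loaded_checkpoint is not None:
        start = sess.loaded_checkpoint.to_dict()["epoch"] + 1

    for epoch in range(start, 5):
        # One report() per iteration; attaching a checkpoint is optional.
        sess.report(
            {"loss": 1.0 / (epoch + 1), "trial": sess.trial_name},
            checkpoint=Checkpoint.from_dict({"epoch": epoch}),
        )
```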
@@ -50,6 +81,7 @@ def get_session():
 def init(reporter, ignore_reinit_error=True):
     """Initializes the global trial context for this process."""
     global _session
+    global _session_v2

     if _session is not None:
         # TODO(ng): would be nice to stack crawl at creation time to report
@@ -83,6 +115,7 @@ def init(reporter, ignore_reinit_error=True):
     remote_function._task_launch_hook = tune_task_and_actor_launch_hook

     _session = reporter
+    _session_v2 = _TuneSessionImpl(status_reporter=reporter)


 # Cache of resource dicts that have been checked by the launch hook already.
@@ -183,6 +216,11 @@ def report(_metric=None, **kwargs):
     """
     _session = get_session()
     if _session:
+        if _session._iter:
+            raise ValueError(
+                "It is not allowed to mix `tune.report` with `session.report`."
+            )
         return _session(_metric, **kwargs)
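The `_iter` check turns accidental API mixing into an immediate failure instead of silently corrupted iteration counts. A hedged illustration of what now raises inside a single trial (the direct `_session_v2` access is again an assumption for illustration):

```python
from ray import tune
from ray.tune import session


def mixed_train_fn(config):
    # The first call goes through the new unified path and starts
    # incrementing the reporter's `_iter` counter...
    session._session_v2.report({"loss": 1.0})
    # ...so falling back to the legacy entry point now raises:
    # ValueError: It is not allowed to mix `tune.report` with `session.report`.
    tune.report(loss=0.5)
```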
@@ -242,6 +280,11 @@ def checkpoint_dir(step: int):
         raise ValueError("checkpoint_dir(step) must be provided - got None.")

     if _session:
+        if _session._iter:
+            raise ValueError(
+                "It is not allowed to mix `with tune.checkpoint_dir` "
+                "with `session.report`."
+            )
         _checkpoint_dir = _session.make_checkpoint_dir(step=step)
     else:
         _checkpoint_dir = os.path.abspath("./")
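For comparison, the legacy pattern that this second guard protects keeps working on its own; it only becomes an error once `session.report` has been used in the same trial. A minimal sketch of that legacy flow:

```python
import os

from ray import tune


def legacy_train_fn(config):
    for step in range(3):
        # Legacy checkpointing: Tune hands back a step-scoped directory.
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "state.txt"), "w") as f:
                f.write(str(step))
        # Legacy reporting: keyword metrics, no explicit session object.
        tune.report(mean_loss=1.0 / (step + 1))
```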