# flake8: noqa
# isort: skip_file

# __session_report_start__
from ray.air import session, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
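    # Report a metric back to AIR from each training iteration.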
    for i in range(10):
        session.report({"step": i})


trainer = DataParallelTrainer(
    train_loop_per_worker=train_fn, scaling_config=ScalingConfig(num_workers=1)
)
trainer.fit()
# __session_report_end__


# __session_data_info_start__
import ray.data
from ray.air import session, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
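    # Fetch this worker's shard of the "train" dataset.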
    dataset_shard = session.get_dataset_shard("train")

    session.report(
        {
            # Global world size
            "world_size": session.get_world_size(),
            # Global worker rank on the cluster
            "world_rank": session.get_world_rank(),
            # Local worker rank on the current machine
            "local_rank": session.get_local_rank(),
            # This worker's data shard, materialized as a Python list
            "data_shard": dataset_shard.to_pandas().to_numpy().tolist(),
        }
    )


trainer = DataParallelTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.from_items([1, 2, 3, 4])},
)
trainer.fit()
# __session_data_info_end__


# __session_checkpoint_start__
from ray.air import session, ScalingConfig, Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_fn(config):
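    # Restore training state from the provided checkpoint, if any.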
    checkpoint = session.get_checkpoint()

    if checkpoint:
        state = checkpoint.to_dict()
    else:
        state = {"step": 0}

    # Resume from the restored step and report a new checkpoint every iteration.
    for i in range(state["step"], 10):
        state["step"] += 1
        session.report(
            metrics={"step": state["step"]}, checkpoint=Checkpoint.from_dict(state)
        )


trainer = DataParallelTrainer(
    train_loop_per_worker=train_fn,
    scaling_config=ScalingConfig(num_workers=1),
    # Resume training as if 4 steps had already been completed.
    resume_from_checkpoint=Checkpoint.from_dict({"step": 4}),
)
trainer.fit()
# __session_checkpoint_end__


# __scaling_config_start__
from ray.air import ScalingConfig

scaling_config = ScalingConfig(
    # Number of distributed workers.
    num_workers=2,
    # Turn on/off GPU.
    use_gpu=True,
    # Specify resources used for the trainer.
    trainer_resources={"CPU": 1},
    # Try to schedule workers on different nodes.
    placement_strategy="SPREAD",
)
# __scaling_config_end__

# __run_config_start__
from ray.air import RunConfig

run_config = RunConfig(
    # Name of the training run (directory name).
    name="my_train_run",
    # Directory to store results in (will be local_dir/name).
    local_dir="~/ray_results",
    # Low training verbosity.
    verbose=1,
)
# __run_config_end__
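
# A RunConfig like the one above is typically passed to a Trainer,
# e.g. DataParallelTrainer(..., run_config=run_config), or to a Tuner.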

# __failure_config_start__
from ray.air import RunConfig, FailureConfig

run_config = RunConfig(
    failure_config=FailureConfig(
        # Tries to recover a run up to this many times.
        max_failures=2
    )
)
# __failure_config_end__

# __sync_config_start__
from ray.air import RunConfig
from ray.tune import SyncConfig

run_config = RunConfig(
    sync_config=SyncConfig(
        # This will store checkpoints on S3.
        upload_dir="s3://remote-bucket/location"
    )
)
# __sync_config_end__

# __checkpoint_config_start__
from ray.air import RunConfig, CheckpointConfig

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        # Only keep this many checkpoints.
        num_to_keep=2
    )
)
# __checkpoint_config_end__


# __results_start__
result = trainer.fit()

# Print metrics
print("Observed metrics:", result.metrics)
checkpoint_data = result.checkpoint.to_dict()
print("Checkpoint data:", checkpoint_data["step"])
# __results_end__