mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[Train] MLflow start run under correct experiment (#23662)
Start Mlflow run under correct mlflow experiment
This commit is contained in:
parent
56a5007b84
commit
dd28d45261
4 changed files with 37 additions and 3 deletions
|
@ -25,3 +25,15 @@ result = trainer.run(
|
|||
print("Run directory:", trainer.latest_run_dir)
|
||||
|
||||
trainer.shutdown()
|
||||
|
||||
# How to visualize the logs
|
||||
|
||||
# Navigate to the run directory of the trainer.
|
||||
# For example `cd /home/ray_results/train_2021-09-01_12-00-00/run_001`
|
||||
# $ cd <TRAINER_RUN_DIR>
|
||||
#
|
||||
# # View the MLflow UI.
|
||||
# $ mlflow ui
|
||||
#
|
||||
# # View the tensorboard UI.
|
||||
# $ tensorboard --logdir .
|
||||
|
|
|
@ -226,7 +226,8 @@ def test_mlflow(ray_start_4_cpus, tmp_path):
|
|||
client = MlflowClient(
|
||||
tracking_uri=callback.mlflow_util._mlflow.get_tracking_uri())
|
||||
|
||||
all_runs = callback.mlflow_util._mlflow.search_runs(experiment_ids=["0"])
|
||||
experiment_id = client.get_experiment_by_name("test_exp").experiment_id
|
||||
all_runs = callback.mlflow_util._mlflow.search_runs(experiment_ids=[experiment_id])
|
||||
assert len(all_runs) == 1
|
||||
# all_runs is a pandas dataframe.
|
||||
all_runs = all_runs.to_dict(orient="records")
|
||||
|
|
|
@ -161,6 +161,7 @@ class MLflowLoggerUtil:
|
|||
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
|
||||
|
||||
client = self._get_client()
|
||||
tags = tags or {}
|
||||
tags[MLFLOW_RUN_NAME] = run_name
|
||||
run = client.create_run(experiment_id=self.experiment_id, tags=tags)
|
||||
|
||||
|
@ -177,7 +178,9 @@ class MLflowLoggerUtil:
|
|||
if active_run:
|
||||
return active_run
|
||||
|
||||
return self._mlflow.start_run(run_name=run_name, tags=tags)
|
||||
return self._mlflow.start_run(
|
||||
run_name=run_name, experiment_id=self.experiment_id, tags=tags
|
||||
)
|
||||
|
||||
def _run_exists(self, run_id: str) -> bool:
|
||||
"""Check if run with the provided id exists."""
|
||||
|
|
|
@ -36,6 +36,22 @@ class MLflowTest(unittest.TestCase):
|
|||
experiment_name="existing_experiment")
|
||||
assert self.mlflow_util.experiment_id == "0"
|
||||
|
||||
def test_run_started_with_correct_experiment(self):
|
||||
experiment_name = "my_experiment_name"
|
||||
# Make sure run is started under the correct experiment.
|
||||
self.mlflow_util.setup_mlflow(
|
||||
tracking_uri=self.tracking_uri, experiment_name=experiment_name
|
||||
)
|
||||
run = self.mlflow_util.start_run(set_active=True)
|
||||
assert (
|
||||
run.info.experiment_id
|
||||
== self.mlflow_util._mlflow.get_experiment_by_name(
|
||||
experiment_name
|
||||
).experiment_id
|
||||
)
|
||||
|
||||
self.mlflow_util.end_run()
|
||||
|
||||
def test_experiment_name_env_var(self):
|
||||
os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment"
|
||||
self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri)
|
||||
|
@ -75,10 +91,12 @@ class MLflowTest(unittest.TestCase):
|
|||
params2 = {"b": "b"}
|
||||
self.mlflow_util.start_run(set_active=True)
|
||||
self.mlflow_util.log_params(params_to_log=params2, run_id=run_id)
|
||||
assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.params == {
|
||||
run = self.mlflow_util._mlflow.get_run(run_id=run_id)
|
||||
assert run.data.params == {
|
||||
**params,
|
||||
**params2
|
||||
}
|
||||
|
||||
self.mlflow_util.end_run()
|
||||
|
||||
def test_log_metrics(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue