mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Fix analysis without registered trainable (#21475)
This PR fixes issues with loading ExperimentAnalysis from path or pickle if the trainable used in the trials is not registered. Chiefly, it ensures that the stub attribute set in load_trials_from_experiment_checkpoint doesn't get overridden by the state of the loaded trial, and that when pickling, all trials in ExperimentAnalysis are turned into stubs if they aren't already. A test has also been added.
This commit is contained in:
parent
08b8f3065b
commit
850eb88cde
3 changed files with 92 additions and 2 deletions
|
@ -2,6 +2,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import warnings
|
||||
import traceback
|
||||
from numbers import Number
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -703,13 +704,13 @@ class ExperimentAnalysis:
|
|||
try:
|
||||
self.trials += load_trials_from_experiment_checkpoint(
|
||||
experiment_state, stub=True)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Could not load trials from experiment checkpoint. "
|
||||
f"This means your experiment checkpoint is likely "
|
||||
f"faulty or incomplete, and you won't have access "
|
||||
f"to all analysis methods. "
|
||||
f"Observed error: {e}")
|
||||
f"Observed error:\n{traceback.format_exc()}")
|
||||
|
||||
if not _trial_paths:
|
||||
raise TuneError("No trials found.")
|
||||
|
@ -760,6 +761,23 @@ class ExperimentAnalysis:
|
|||
|
||||
return rows
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
    """Ensure that trials are marked as stubs when pickling,
    so that they can be loaded later without the trainable
    being registered.
    """
    state = self.__dict__.copy()

    def _as_stub(trial: Trial) -> Trial:
        # Already a stub: nothing to convert.
        if trial.stub:
            return trial
        # Re-create the trial as a stub so unpickling does not
        # require the trainable to be registered.
        stubbed = Trial(trial.trainable_name, stub=True)
        stubbed.__setstate__(trial.__getstate__())
        return stubbed

    state["trials"] = list(map(_as_stub, state["trials"]))
    return state
|
||||
|
||||
|
||||
@Deprecated
|
||||
class Analysis(ExperimentAnalysis):
|
||||
|
|
|
@ -3,11 +3,14 @@ import shutil
|
|||
import tempfile
|
||||
import random
|
||||
import os
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from numpy import nan
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import ExperimentAnalysis
|
||||
import ray.tune.registry
|
||||
from ray.tune.utils.mock_trainable import MyTrainableClass
|
||||
|
||||
|
||||
|
@ -300,6 +303,71 @@ class ExperimentAnalysisPropertySuite(unittest.TestCase):
|
|||
self.assertEqual(var, 1)
|
||||
|
||||
|
||||
class ExperimentAnalysisStubSuite(unittest.TestCase):
    """Verify that ExperimentAnalysis survives pickling and re-loading
    from a checkpoint path when the trainable is not registered."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()
        self.test_name = "analysis_exp"
        self.num_samples = 2
        self.metric = "episode_reward_mean"
        self.test_path = os.path.join(self.test_dir, self.test_name)
        self.run_test_exp()

    def tearDown(self):
        shutil.rmtree(self.test_dir, ignore_errors=True)
        ray.shutdown()

    def run_test_exp(self):
        """Run a tiny experiment and return its ExperimentAnalysis."""

        def training_function(config, checkpoint_dir=None):
            tune.report(episode_reward_mean=config["alpha"])

        search_space = {
            "alpha": tune.sample_from(
                lambda spec: 10 + int(90 * random.random())),
        }
        return tune.run(
            training_function,
            name=self.test_name,
            local_dir=self.test_dir,
            stop={"training_iteration": 1},
            num_samples=self.num_samples,
            config=search_space)

    def _wipe_registry(self):
        # Simulate a fresh process in which the trainable was never
        # registered.
        ray.shutdown()
        ray.tune.registry._global_registry = ray.tune.registry._Registry(
            prefix="global")

    def testPickling(self):
        analysis = self.run_test_exp()
        pickle_path = os.path.join(self.test_dir, "analysis.pickle")
        with open(pickle_path, "wb") as f:
            pickle.dump(analysis, f)

        self.assertTrue(
            analysis.get_best_trial(metric=self.metric, mode="max"))

        self._wipe_registry()

        with open(pickle_path, "rb") as f:
            analysis = pickle.load(f)

        self.assertTrue(
            analysis.get_best_trial(metric=self.metric, mode="max"))

    def testFromPath(self):
        self.run_test_exp()
        analysis = ExperimentAnalysis(self.test_path)

        self.assertTrue(
            analysis.get_best_trial(metric=self.metric, mode="max"))

        self._wipe_registry()

        analysis = ExperimentAnalysis(self.test_path)

        # get_best_trial would be None if validate_trainable failed
        # while the checkpoint was being loaded.
        self.assertTrue(
            analysis.get_best_trial(metric=self.metric, mode="max"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
|
|
|
@ -777,7 +777,11 @@ class Trial:
|
|||
for key in self._nonjson_fields:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
# Ensure that stub doesn't get overridden
|
||||
stub = state.pop("stub", True)
|
||||
self.__dict__.update(state)
|
||||
self.stub = stub or getattr(self, "stub", False)
|
||||
|
||||
if not self.stub:
|
||||
validate_trainable(self.trainable_name)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue