diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index b78f87164..5476fa506 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -2,6 +2,7 @@ import json import logging import os import warnings +import traceback from numbers import Number from typing import Any, Dict, List, Optional, Tuple @@ -703,13 +704,13 @@ class ExperimentAnalysis: try: self.trials += load_trials_from_experiment_checkpoint( experiment_state, stub=True) - except Exception as e: + except Exception: logger.warning( f"Could not load trials from experiment checkpoint. " f"This means your experiment checkpoint is likely " f"faulty or incomplete, and you won't have access " f"to all analysis methods. " - f"Observed error: {e}") + f"Observed error:\n{traceback.format_exc()}") if not _trial_paths: raise TuneError("No trials found.") @@ -760,6 +761,23 @@ class ExperimentAnalysis: return rows + def __getstate__(self) -> Dict[str, Any]: + """Ensure that trials are marked as stubs when pickling, + so that they can be loaded later without the trainable + being registered. + """ + state = self.__dict__.copy() + + def make_stub_if_needed(trial: Trial) -> Trial: + if trial.stub: + return trial + trial_copy = Trial(trial.trainable_name, stub=True) + trial_copy.__setstate__(trial.__getstate__()) + return trial_copy + + state["trials"] = [make_stub_if_needed(t) for t in state["trials"]] + return state + @Deprecated class Analysis(ExperimentAnalysis): diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 591deae1b..c953fd50f 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -3,11 +3,14 @@ import shutil import tempfile import random import os +import pickle import pandas as pd from numpy import nan import ray from ray import tune +from ray.tune import ExperimentAnalysis +import ray.tune.registry from ray.tune.utils.mock_trainable import MyTrainableClass @@ -300,6 +303,71 @@ class ExperimentAnalysisPropertySuite(unittest.TestCase): self.assertEqual(var, 1) +class ExperimentAnalysisStubSuite(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.test_name = "analysis_exp" + self.num_samples = 2 + self.metric = "episode_reward_mean" + self.test_path = os.path.join(self.test_dir, self.test_name) + self.run_test_exp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + ray.shutdown() + + def run_test_exp(self): + def training_function(config, checkpoint_dir=None): + tune.report(episode_reward_mean=config["alpha"]) + + return tune.run( + training_function, + name=self.test_name, + local_dir=self.test_dir, + stop={"training_iteration": 1}, + num_samples=self.num_samples, + config={ + "alpha": tune.sample_from( + lambda spec: 10 + int(90 * random.random())), + }) + + def testPickling(self): + analysis = self.run_test_exp() + pickle_path = os.path.join(self.test_dir, "analysis.pickle") + with open(pickle_path, "wb") as f: + pickle.dump(analysis, f) + + self.assertTrue( + analysis.get_best_trial(metric=self.metric, mode="max")) + + ray.shutdown() + ray.tune.registry._global_registry = ray.tune.registry._Registry( + prefix="global") + + with open(pickle_path, "rb") as f: + analysis = pickle.load(f) + + self.assertTrue( + analysis.get_best_trial(metric=self.metric, mode="max")) + + def testFromPath(self): + self.run_test_exp() + analysis = ExperimentAnalysis(self.test_path) + + self.assertTrue( + analysis.get_best_trial(metric=self.metric, mode="max")) + + ray.shutdown() + ray.tune.registry._global_registry = ray.tune.registry._Registry( + prefix="global") + + analysis = ExperimentAnalysis(self.test_path) + + # This will be None if validate_trainable during loading fails + self.assertTrue( + analysis.get_best_trial(metric=self.metric, mode="max")) + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 41558cef1..69b9b78a1 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -777,7 +777,11 @@ class Trial: for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) + # Ensure that stub doesn't get overriden + stub = state.pop("stub", True) self.__dict__.update(state) + self.stub = stub or getattr(self, "stub", False) + if not self.stub: validate_trainable(self.trainable_name)