[tune] Fix analysis without registered trainable (#21475)

This PR fixes issues with loading ExperimentAnalysis from path or pickle if the trainable used in the trials is not registered. Chiefly, it ensures that the stub attribute set in load_trials_from_experiment_checkpoint doesn't get overridden by the state of the loaded trial, and that when pickling, all trials in ExperimentAnalysis are turned into stubs if they aren't already. A test has also been added.
This commit is contained in:
Antoni Baum 2022-01-24 17:27:08 +01:00 committed by GitHub
parent 08b8f3065b
commit 850eb88cde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 2 deletions

View file

@ -2,6 +2,7 @@ import json
import logging
import os
import warnings
import traceback
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple
@ -703,13 +704,13 @@ class ExperimentAnalysis:
try:
self.trials += load_trials_from_experiment_checkpoint(
experiment_state, stub=True)
except Exception as e:
except Exception:
logger.warning(
f"Could not load trials from experiment checkpoint. "
f"This means your experiment checkpoint is likely "
f"faulty or incomplete, and you won't have access "
f"to all analysis methods. "
f"Observed error: {e}")
f"Observed error:\n{traceback.format_exc()}")
if not _trial_paths:
raise TuneError("No trials found.")
@ -760,6 +761,23 @@ class ExperimentAnalysis:
return rows
def __getstate__(self) -> Dict[str, Any]:
"""Ensure that trials are marked as stubs when pickling,
so that they can be loaded later without the trainable
being registered.
"""
state = self.__dict__.copy()
def make_stub_if_needed(trial: Trial) -> Trial:
if trial.stub:
return trial
trial_copy = Trial(trial.trainable_name, stub=True)
trial_copy.__setstate__(trial.__getstate__())
return trial_copy
state["trials"] = [make_stub_if_needed(t) for t in state["trials"]]
return state
@Deprecated
class Analysis(ExperimentAnalysis):

View file

@ -3,11 +3,14 @@ import shutil
import tempfile
import random
import os
import pickle
import pandas as pd
from numpy import nan
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
import ray.tune.registry
from ray.tune.utils.mock_trainable import MyTrainableClass
@ -300,6 +303,71 @@ class ExperimentAnalysisPropertySuite(unittest.TestCase):
self.assertEqual(var, 1)
class ExperimentAnalysisStubSuite(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.test_name = "analysis_exp"
self.num_samples = 2
self.metric = "episode_reward_mean"
self.test_path = os.path.join(self.test_dir, self.test_name)
self.run_test_exp()
def tearDown(self):
shutil.rmtree(self.test_dir, ignore_errors=True)
ray.shutdown()
def run_test_exp(self):
def training_function(config, checkpoint_dir=None):
tune.report(episode_reward_mean=config["alpha"])
return tune.run(
training_function,
name=self.test_name,
local_dir=self.test_dir,
stop={"training_iteration": 1},
num_samples=self.num_samples,
config={
"alpha": tune.sample_from(
lambda spec: 10 + int(90 * random.random())),
})
def testPickling(self):
analysis = self.run_test_exp()
pickle_path = os.path.join(self.test_dir, "analysis.pickle")
with open(pickle_path, "wb") as f:
pickle.dump(analysis, f)
self.assertTrue(
analysis.get_best_trial(metric=self.metric, mode="max"))
ray.shutdown()
ray.tune.registry._global_registry = ray.tune.registry._Registry(
prefix="global")
with open(pickle_path, "rb") as f:
analysis = pickle.load(f)
self.assertTrue(
analysis.get_best_trial(metric=self.metric, mode="max"))
def testFromPath(self):
self.run_test_exp()
analysis = ExperimentAnalysis(self.test_path)
self.assertTrue(
analysis.get_best_trial(metric=self.metric, mode="max"))
ray.shutdown()
ray.tune.registry._global_registry = ray.tune.registry._Registry(
prefix="global")
analysis = ExperimentAnalysis(self.test_path)
# This will be None if validate_trainable during loading fails
self.assertTrue(
analysis.get_best_trial(metric=self.metric, mode="max"))
if __name__ == "__main__":
import pytest
import sys

View file

@ -777,7 +777,11 @@ class Trial:
for key in self._nonjson_fields:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
# Ensure that stub doesn't get overriden
stub = state.pop("stub", True)
self.__dict__.update(state)
self.stub = stub or getattr(self, "stub", False)
if not self.stub:
validate_trainable(self.trainable_name)