[tune] Fix Trial Serialization (#3743)

This commit is contained in:
Richard Liaw 2019-01-10 19:26:10 -08:00 committed by GitHub
parent 597abb24ea
commit 574f0b73bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 2 deletions

View file

@ -1772,6 +1772,30 @@ class TrialRunnerTest(unittest.TestCase):
runner2.step()
shutil.rmtree(tmpdir)
def testCheckpointWithFunction(self):
ray.init()
trial = Trial(
"__fake",
config={
"callbacks": {
"on_episode_start": tune.function(lambda i: i),
}
},
checkpoint_freq=1)
tmpdir = tempfile.mkdtemp()
runner = TrialRunner(
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
runner.add_trial(trial)
for i in range(5):
runner.step()
# force checkpoint
runner.checkpoint()
runner2 = TrialRunner.restore(tmpdir)
new_trial = runner2.get_trials()[0]
self.assertTrue("callbacks" in new_trial.config)
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
shutil.rmtree(tmpdir)
class SearchAlgorithmTest(unittest.TestCase):
def testNestedSuggestion(self):

View file

@ -409,7 +409,8 @@ class Trial(object):
"_checkpoint": self._checkpoint,
"config": self.config,
"custom_loggers": self.custom_loggers,
"sync_function": self.sync_function
"sync_function": self.sync_function,
"last_result": self.last_result
}
for key, value in pickle_data.items():
@ -430,7 +431,8 @@ class Trial(object):
logger_started = state.pop("__logger_started__")
state["resources"] = json_to_resources(state["resources"])
for key in [
"_checkpoint", "config", "custom_loggers", "sync_function"
"_checkpoint", "config", "custom_loggers", "sync_function",
"last_result"
]:
state[key] = cloudpickle.loads(hex_to_binary(state[key]))