mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Fix Trial Serialization (#3743)
This commit is contained in:
parent
597abb24ea
commit
574f0b73bc
2 changed files with 28 additions and 2 deletions
|
@ -1772,6 +1772,30 @@ class TrialRunnerTest(unittest.TestCase):
|
|||
runner2.step()
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testCheckpointWithFunction(self):
|
||||
ray.init()
|
||||
trial = Trial(
|
||||
"__fake",
|
||||
config={
|
||||
"callbacks": {
|
||||
"on_episode_start": tune.function(lambda i: i),
|
||||
}
|
||||
},
|
||||
checkpoint_freq=1)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
runner = TrialRunner(
|
||||
BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir)
|
||||
runner.add_trial(trial)
|
||||
for i in range(5):
|
||||
runner.step()
|
||||
# force checkpoint
|
||||
runner.checkpoint()
|
||||
runner2 = TrialRunner.restore(tmpdir)
|
||||
new_trial = runner2.get_trials()[0]
|
||||
self.assertTrue("callbacks" in new_trial.config)
|
||||
self.assertTrue("on_episode_start" in new_trial.config["callbacks"])
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
class SearchAlgorithmTest(unittest.TestCase):
|
||||
def testNestedSuggestion(self):
|
||||
|
|
|
@ -409,7 +409,8 @@ class Trial(object):
|
|||
"_checkpoint": self._checkpoint,
|
||||
"config": self.config,
|
||||
"custom_loggers": self.custom_loggers,
|
||||
"sync_function": self.sync_function
|
||||
"sync_function": self.sync_function,
|
||||
"last_result": self.last_result
|
||||
}
|
||||
|
||||
for key, value in pickle_data.items():
|
||||
|
@ -430,7 +431,8 @@ class Trial(object):
|
|||
logger_started = state.pop("__logger_started__")
|
||||
state["resources"] = json_to_resources(state["resources"])
|
||||
for key in [
|
||||
"_checkpoint", "config", "custom_loggers", "sync_function"
|
||||
"_checkpoint", "config", "custom_loggers", "sync_function",
|
||||
"last_result"
|
||||
]:
|
||||
state[key] = cloudpickle.loads(hex_to_binary(state[key]))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue