[tune] schedulers: Add test for context finalization (#11889)

This commit is contained in:
Kai Fricke 2020-11-09 20:37:05 +01:00 committed by GitHub
parent a09e49ee94
commit 287aba6dc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,6 +2,7 @@ import os
import json
import random
import unittest
import numpy as np
import sys
import tempfile
@ -12,7 +13,8 @@ import ray
from ray import tune
from ray.tune import Trainable
from ray.tune.result import TRAINING_ITERATION
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
from ray.tune.schedulers import (FIFOScheduler, HyperBandScheduler,
AsyncHyperBandScheduler,
PopulationBasedTraining, MedianStoppingRule,
TrialScheduler, HyperBandForBOHB)
@ -1696,6 +1698,59 @@ class PopulationBasedTestingSuite(unittest.TestCase):
pbt._exploit(runner.trial_executor, trials[1], trials[2])
shutil.rmtree(tmpdir)
def testContextExit(self):
vals = [5, 1]
class MockContext:
def __init__(self, config):
self.config = config
self.active = False
def __enter__(self):
print("Set up resource.", self.config)
with open("status.txt", "wt") as fp:
fp.write("Activate\n")
self.active = True
return self
def __exit__(self, type, value, traceback):
print("Clean up resource.", self.config)
with open("status.txt", "at") as fp:
fp.write("Cleanup\n")
self.active = False
def train(config):
with MockContext(config):
for i in range(10):
tune.report(metric=i + config["x"])
class MockScheduler(FIFOScheduler):
def on_trial_result(self, trial_runner, trial, result):
return TrialScheduler.STOP
scheduler = MockScheduler()
out = tune.run(
train, config={"x": tune.grid_search(vals)}, scheduler=scheduler)
ever_active = set()
active = set()
for trial in out.trials:
with open(os.path.join(trial.logdir, "status.txt"), "rt") as fp:
status = fp.read()
print(f"Status for trial {trial}: {status}")
if "Activate" in status:
ever_active.add(trial)
active.add(trial)
if "Cleanup" in status:
active.remove(trial)
print(f"Ever active: {ever_active}")
print(f"Still active: {active}")
self.assertEqual(len(ever_active), len(vals))
self.assertEqual(len(active), 0)
class E2EPopulationBasedTestingSuite(unittest.TestCase):
def setUp(self):