mirror of
https://github.com/vale981/ray
synced 2025-03-06 18:41:40 -05:00
[tune] schedulers: Add test for context finalization (#11889)
This commit is contained in:
parent
a09e49ee94
commit
287aba6dc3
1 changed files with 56 additions and 1 deletions
|
@ -2,6 +2,7 @@ import os
|
|||
import json
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import sys
|
||||
import tempfile
|
||||
|
@ -12,7 +13,8 @@ import ray
|
|||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
from ray.tune.schedulers import (FIFOScheduler, HyperBandScheduler,
|
||||
AsyncHyperBandScheduler,
|
||||
PopulationBasedTraining, MedianStoppingRule,
|
||||
TrialScheduler, HyperBandForBOHB)
|
||||
|
||||
|
@ -1696,6 +1698,59 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
|||
pbt._exploit(runner.trial_executor, trials[1], trials[2])
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testContextExit(self):
|
||||
vals = [5, 1]
|
||||
|
||||
class MockContext:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.active = False
|
||||
|
||||
def __enter__(self):
|
||||
print("Set up resource.", self.config)
|
||||
with open("status.txt", "wt") as fp:
|
||||
fp.write("Activate\n")
|
||||
self.active = True
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print("Clean up resource.", self.config)
|
||||
with open("status.txt", "at") as fp:
|
||||
fp.write("Cleanup\n")
|
||||
self.active = False
|
||||
|
||||
def train(config):
|
||||
with MockContext(config):
|
||||
for i in range(10):
|
||||
tune.report(metric=i + config["x"])
|
||||
|
||||
class MockScheduler(FIFOScheduler):
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
return TrialScheduler.STOP
|
||||
|
||||
scheduler = MockScheduler()
|
||||
|
||||
out = tune.run(
|
||||
train, config={"x": tune.grid_search(vals)}, scheduler=scheduler)
|
||||
|
||||
ever_active = set()
|
||||
active = set()
|
||||
for trial in out.trials:
|
||||
with open(os.path.join(trial.logdir, "status.txt"), "rt") as fp:
|
||||
status = fp.read()
|
||||
print(f"Status for trial {trial}: {status}")
|
||||
if "Activate" in status:
|
||||
ever_active.add(trial)
|
||||
active.add(trial)
|
||||
if "Cleanup" in status:
|
||||
active.remove(trial)
|
||||
|
||||
print(f"Ever active: {ever_active}")
|
||||
print(f"Still active: {active}")
|
||||
|
||||
self.assertEqual(len(ever_active), len(vals))
|
||||
self.assertEqual(len(active), 0)
|
||||
|
||||
|
||||
class E2EPopulationBasedTestingSuite(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
Loading…
Add table
Reference in a new issue