Mirror of https://github.com/vale981/ray (synced 2025-03-07 02:51:39 -05:00)
[tune] make tests faster + fix flaky test (#10264)
This commit: 58891551d3 (parent: 9e63f7ccc3)
5 changed files with 201 additions and 167 deletions
@@ -87,7 +87,7 @@ py_test(
 py_test(
     name = "test_function_api",
-    size = "medium",
+    size = "small",
     srcs = ["tests/test_function_api.py"],
     deps = [":tune_lib"],
     tags = ["exclusive"],
@@ -415,7 +415,7 @@ py_test(
 py_test(
     name = "hyperband_function_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/hyperband_function_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -433,7 +433,7 @@ py_test(
 py_test(
     name = "lightgbm_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/lightgbm_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"]
@@ -441,7 +441,7 @@ py_test(
 py_test(
     name = "logging_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/logging_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -504,7 +504,7 @@ py_test(
 py_test(
     name = "nevergrad_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/nevergrad_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -513,7 +513,7 @@ py_test(
 py_test(
     name = "optuna_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/optuna_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -531,7 +531,7 @@ py_test(
 py_test(
     name = "pbt_convnet_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/pbt_convnet_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -540,7 +540,7 @@ py_test(
 py_test(
     name = "pbt_convnet_function_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/pbt_convnet_function_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
@@ -567,7 +567,7 @@ py_test(
 py_test(
     name = "pbt_example",
-    size = "medium",
+    size = "small",
     srcs = ["examples/pbt_example.py"],
     deps = [":tune_lib"],
     tags = ["exclusive", "example"],
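All nine hunks above make the same change: the affected py_test targets move from Bazel's "medium" size class to "small". Test size mainly drives the default per-test timeout (by Bazel's standard defaults, about 300 seconds for medium versus 60 seconds for small; those figures come from Bazel's documentation, not from this diff), so shrinking the size both records that these tests now finish quickly and lets Bazel flag them sooner if they slow down again.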
@@ -647,63 +647,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
         self.assertTrue(
             all(t.last_result.get("hello") == 123 for t in new_trials))

-    def testErrorReturn(self):
-        def train(config, reporter):
-            raise Exception("uh oh")
-
-        register_trainable("f1", train)
-
-        def f():
-            run_experiments({
-                "foo": {
-                    "run": "f1",
-                }
-            })
-
-        self.assertRaises(TuneError, f)
-
-    def testSuccess(self):
-        def train(config, reporter):
-            for i in range(100):
-                reporter(timesteps_total=i)
-
-        register_trainable("f1", train)
-        [trial] = run_experiments({
-            "foo": {
-                "run": "f1",
-            }
-        })
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
-
-    def testNoRaiseFlag(self):
-        def train(config, reporter):
-            raise Exception()
-
-        register_trainable("f1", train)
-
-        [trial] = run_experiments(
-            {
-                "foo": {
-                    "run": "f1",
-                }
-            }, raise_on_failed_trial=False)
-        self.assertEqual(trial.status, Trial.ERROR)
-
-    def testReportInfinity(self):
-        def train(config, reporter):
-            for _ in range(100):
-                reporter(mean_accuracy=float("inf"))
-
-        register_trainable("f1", train)
-        [trial] = run_experiments({
-            "foo": {
-                "run": "f1",
-            }
-        })
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
-
     def testTrialInfoAccess(self):
         class TestTrainable(Trainable):
             def step(self):
@@ -742,64 +685,16 @@ class TrainableFunctionApiTest(unittest.TestCase):
                 return result

             def cleanup(self):
-                time.sleep(2)
+                time.sleep(0.3)
                 open(os.path.join(self.logdir, "marker"), "a").close()
                 return 1

         analysis = tune.run(
             TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
         ray.shutdown()
         for trial in analysis.trials:
             path = os.path.join(trial.logdir, "marker")
             assert os.path.exists(path)

-    def testNestedResults(self):
-        def create_result(i):
-            return {"test": {"1": {"2": {"3": i, "4": False}}}}
-
-        flattened_keys = list(flatten_dict(create_result(0)))
-
-        class _MockScheduler(FIFOScheduler):
-            results = []
-
-            def on_trial_result(self, trial_runner, trial, result):
-                self.results += [result]
-                return TrialScheduler.CONTINUE
-
-            def on_trial_complete(self, trial_runner, trial, result):
-                self.complete_result = result
-
-        def train(config, reporter):
-            for i in range(100):
-                reporter(**create_result(i))
-
-        algo = _MockSuggestionAlgorithm()
-        scheduler = _MockScheduler()
-        [trial] = tune.run(
-            train,
-            scheduler=scheduler,
-            search_alg=algo,
-            stop={
-                "test/1/2/3": 20
-            }).trials
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result["test"]["1"]["2"]["3"], 20)
-        self.assertEqual(trial.last_result["test"]["1"]["2"]["4"], False)
-        self.assertEqual(trial.last_result[TRAINING_ITERATION], 21)
-        self.assertEqual(len(scheduler.results), 20)
-        self.assertTrue(
-            all(
-                set(result) >= set(flattened_keys)
-                for result in scheduler.results))
-        self.assertTrue(set(scheduler.complete_result) >= set(flattened_keys))
-        self.assertEqual(len(algo.results), 20)
-        self.assertTrue(
-            all(set(result) >= set(flattened_keys) for result in algo.results))
-        with self.assertRaises(TuneError):
-            [trial] = tune.run(train, stop={"1/2/3": 20})
-        with self.assertRaises(TuneError):
-            [trial] = tune.run(train, stop={"test": 1}).trials
-
     def testReportTimeStep(self):
         # Test that no timestep count are logged if never the Trainable never
         # returns any.
@@ -994,28 +889,6 @@ class TrainableFunctionApiTest(unittest.TestCase):
         self.assertEqual(trial.status, Trial.TERMINATED)
         self.assertTrue(trial.has_checkpoint())

-    def testIterationCounter(self):
-        def train(config, reporter):
-            for i in range(100):
-                reporter(itr=i, timesteps_this_iter=1)
-
-        register_trainable("exp", train)
-        config = {
-            "my_exp": {
-                "run": "exp",
-                "config": {
-                    "iterations": 100,
-                },
-                "stop": {
-                    "timesteps_total": 100
-                },
-            }
-        }
-        [trial] = run_experiments(config)
-        self.assertEqual(trial.status, Trial.TERMINATED)
-        self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
-        self.assertEqual(trial.last_result["itr"], 99)
-
     def testBackwardsCompat(self):
         class TestTrain(Trainable):
             def _setup(self, config):
@@ -1263,6 +1136,150 @@ class ShimCreationTest(unittest.TestCase):
         assert type(shim_searcher_hyperopt) is type(real_searcher_hyperopt)


+class ApiTestFast(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(
+            num_cpus=4, num_gpus=0, local_mode=True, include_dashboard=False)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+        _register_all()
+
+    def setUp(self):
+        self.tmpdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdir)
+
+    def testNestedResults(self):
+        def create_result(i):
+            return {"test": {"1": {"2": {"3": i, "4": False}}}}
+
+        flattened_keys = list(flatten_dict(create_result(0)))
+
+        class _MockScheduler(FIFOScheduler):
+            results = []
+
+            def on_trial_result(self, trial_runner, trial, result):
+                self.results += [result]
+                return TrialScheduler.CONTINUE
+
+            def on_trial_complete(self, trial_runner, trial, result):
+                self.complete_result = result
+
+        def train(config, reporter):
+            for i in range(100):
+                reporter(**create_result(i))
+
+        algo = _MockSuggestionAlgorithm()
+        scheduler = _MockScheduler()
+        [trial] = tune.run(
+            train,
+            scheduler=scheduler,
+            search_alg=algo,
+            stop={
+                "test/1/2/3": 20
+            }).trials
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result["test"]["1"]["2"]["3"], 20)
+        self.assertEqual(trial.last_result["test"]["1"]["2"]["4"], False)
+        self.assertEqual(trial.last_result[TRAINING_ITERATION], 21)
+        self.assertEqual(len(scheduler.results), 20)
+        self.assertTrue(
+            all(
+                set(result) >= set(flattened_keys)
+                for result in scheduler.results))
+        self.assertTrue(set(scheduler.complete_result) >= set(flattened_keys))
+        self.assertEqual(len(algo.results), 20)
+        self.assertTrue(
+            all(set(result) >= set(flattened_keys) for result in algo.results))
+        with self.assertRaises(TuneError):
+            [trial] = tune.run(train, stop={"1/2/3": 20})
+        with self.assertRaises(TuneError):
+            [trial] = tune.run(train, stop={"test": 1}).trials
+
+    def testIterationCounter(self):
+        def train(config, reporter):
+            for i in range(100):
+                reporter(itr=i, timesteps_this_iter=1)
+
+        register_trainable("exp", train)
+        config = {
+            "my_exp": {
+                "run": "exp",
+                "config": {
+                    "iterations": 100,
+                },
+                "stop": {
+                    "timesteps_total": 100
+                },
+            }
+        }
+        [trial] = run_experiments(config)
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
+        self.assertEqual(trial.last_result["itr"], 99)
+
+    def testErrorReturn(self):
+        def train(config, reporter):
+            raise Exception("uh oh")
+
+        register_trainable("f1", train)
+
+        def f():
+            run_experiments({
+                "foo": {
+                    "run": "f1",
+                }
+            })
+
+        self.assertRaises(TuneError, f)
+
+    def testSuccess(self):
+        def train(config, reporter):
+            for i in range(100):
+                reporter(timesteps_total=i)
+
+        register_trainable("f1", train)
+        [trial] = run_experiments({
+            "foo": {
+                "run": "f1",
+            }
+        })
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)
+
+    def testNoRaiseFlag(self):
+        def train(config, reporter):
+            raise Exception()
+
+        register_trainable("f1", train)
+
+        [trial] = run_experiments(
+            {
+                "foo": {
+                    "run": "f1",
+                }
+            }, raise_on_failed_trial=False)
+        self.assertEqual(trial.status, Trial.ERROR)
+
+    def testReportInfinity(self):
+        def train(config, reporter):
+            for _ in range(100):
+                reporter(mean_accuracy=float("inf"))
+
+        register_trainable("f1", train)
+        [trial] = run_experiments({
+            "foo": {
+                "run": "f1",
+            }
+        })
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))
+
+
 if __name__ == "__main__":
     import pytest
     import sys
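The final hunk above replaces per-test Ray setup with a single shared runtime: the new ApiTestFast suite calls ray.init(..., local_mode=True, include_dashboard=False) once in setUpClass and ray.shutdown() in tearDownClass. Below is a minimal, self-contained sketch of that pattern; the class name FastSuiteSketch and its test are placeholders of mine, not part of this commit.

import unittest

import ray


class FastSuiteSketch(unittest.TestCase):
    """Hypothetical suite showing the init-once pattern used by ApiTestFast."""

    @classmethod
    def setUpClass(cls):
        # Start one Ray runtime for the whole suite. local_mode=True executes
        # tasks inline in the driver process, so no worker processes need to
        # be launched, and include_dashboard=False skips the dashboard server.
        ray.init(
            num_cpus=4, num_gpus=0, local_mode=True, include_dashboard=False)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_remote_call(self):
        # Every test in the suite reuses the shared runtime directly.
        @ray.remote
        def add_one(x):
            return x + 1

        self.assertEqual(ray.get(add_one.remote(41)), 42)


if __name__ == "__main__":
    unittest.main()

Starting one Ray runtime per suite instead of one per test removes repeated startup and shutdown, which is typically the dominant fixed cost in small Tune API tests like these.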
@@ -12,8 +12,16 @@ from ray.tune.examples.async_hyperband_example import MyTrainableClass


 class ExperimentAnalysisSuite(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(
+            num_cpus=4, num_gpus=0, local_mode=True, include_dashboard=False)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def setUp(self):
-        ray.init(local_mode=False)
         self.test_dir = tempfile.mkdtemp()
         self.test_name = "analysis_exp"
         self.num_samples = 10
@@ -23,7 +31,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
     def tearDown(self):
         shutil.rmtree(self.test_dir, ignore_errors=True)
-        ray.shutdown()

     def run_test_exp(self):
         self.ea = tune.run(
@@ -15,6 +15,14 @@ from ray.tune.examples.async_hyperband_example import MyTrainableClass


 class ExperimentAnalysisInMemorySuite(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(local_mode=False, num_cpus=1)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def setUp(self):
         class MockTrainable(Trainable):
             scores_dict = {
@@ -42,11 +50,9 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):
         self.MockTrainable = MockTrainable
         self.test_dir = tempfile.mkdtemp()
-        ray.init(local_mode=False, num_cpus=1)

     def tearDown(self):
         shutil.rmtree(self.test_dir, ignore_errors=True)
-        ray.shutdown()

     def testInit(self):
         experiment_checkpoint_path = os.path.join(self.test_dir,
@@ -123,8 +129,15 @@ class ExperimentAnalysisInMemorySuite(unittest.TestCase):


 class AnalysisSuite(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(local_mode=True, include_dashboard=False)
+
+    @classmethod
+    def tearDownClass(cls):
+        ray.shutdown()
+
     def setUp(self):
-        ray.init(local_mode=True)
         self.test_dir = tempfile.mkdtemp()
         self.num_samples = 10
         self.metric = "episode_reward_mean"
@@ -145,7 +158,6 @@ class AnalysisSuite(unittest.TestCase):
     def tearDown(self):
         shutil.rmtree(self.test_dir, ignore_errors=True)
-        ray.shutdown()

     def testDataframe(self):
         analysis = Analysis(self.test_dir)
@@ -446,7 +446,7 @@ class TrialRunnerTest3(unittest.TestCase):
         runner.step()  # Start trial
         runner.step()  # Process result, dispatch save
         runner.step()  # Process save
-        self.assertEquals(trials[0].status, Trial.TERMINATED)
+        self.assertEqual(trials[0].status, Trial.TERMINATED)

         trials += [
             Trial(
@@ -461,7 +461,7 @@ class TrialRunnerTest3(unittest.TestCase):
         runner.step()  # Process result, dispatch save
         runner.step()  # Process save
         runner.step()  # Error
-        self.assertEquals(trials[1].status, Trial.ERROR)
+        self.assertEqual(trials[1].status, Trial.ERROR)

         trials += [
             Trial(
@@ -472,8 +472,8 @@ class TrialRunnerTest3(unittest.TestCase):
         ]
         runner.add_trial(trials[2])
         runner.step()  # Start trial
-        self.assertEquals(len(runner.trial_executor.get_checkpoints()), 3)
-        self.assertEquals(trials[2].status, Trial.RUNNING)
+        self.assertEqual(len(runner.trial_executor.get_checkpoints()), 3)
+        self.assertEqual(trials[2].status, Trial.RUNNING)

         runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
         for tid in ["trial_terminate", "trial_fail"]:
@@ -529,7 +529,7 @@ class TrialRunnerTest3(unittest.TestCase):
         runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=self.tmpdir)
         new_trials = runner2.get_trials()
-        self.assertEquals(len(new_trials), 3)
+        self.assertEqual(len(new_trials), 3)
         self.assertTrue(
             runner2.get_trial("non_checkpoint").status == Trial.TERMINATED)
         self.assertTrue(
@@ -575,15 +575,15 @@ class TrialRunnerTest3(unittest.TestCase):
         runner.step()
         # force checkpoint
         runner.checkpoint()
-        self.assertEquals(count_checkpoints(tmpdir), 1)
+        self.assertEqual(count_checkpoints(tmpdir), 1)

         runner2 = TrialRunner(resume="LOCAL", local_checkpoint_dir=tmpdir)
         for _ in range(5):
             runner2.step()
-        self.assertEquals(count_checkpoints(tmpdir), 2)
+        self.assertEqual(count_checkpoints(tmpdir), 2)

         runner2.checkpoint()
-        self.assertEquals(count_checkpoints(tmpdir), 2)
+        self.assertEqual(count_checkpoints(tmpdir), 2)
         shutil.rmtree(tmpdir)

     def testUserCheckpoint(self):
@@ -612,7 +612,13 @@ class TrialRunnerTest3(unittest.TestCase):


 class SearchAlgorithmTest(unittest.TestCase):
-    def tearDown(self):
+    @classmethod
+    def setUpClass(cls):
+        ray.init(
+            num_cpus=4, num_gpus=0, local_mode=True, include_dashboard=False)
+
+    @classmethod
+    def tearDownClass(cls):
         ray.shutdown()
         _register_all()

@@ -629,8 +635,6 @@ class SearchAlgorithmTest(unittest.TestCase):
         self.assertTrue("d=4" in trial.experiment_tag)

     def _test_repeater(self, num_samples, repeat):
-        ray.init(num_cpus=4)
-
         class TestSuggestion(Searcher):
             index = 0

@@ -660,25 +664,23 @@ class SearchAlgorithmTest(unittest.TestCase):
     def testRepeat1(self):
         trials = self._test_repeater(num_samples=2, repeat=1)
-        self.assertEquals(len(trials), 2)
+        self.assertEqual(len(trials), 2)
         parameter_set = {t.evaluated_params["test_variable"] for t in trials}
-        self.assertEquals(len(parameter_set), 2)
+        self.assertEqual(len(parameter_set), 2)

     def testRepeat4(self):
         trials = self._test_repeater(num_samples=12, repeat=4)
-        self.assertEquals(len(trials), 12)
+        self.assertEqual(len(trials), 12)
         parameter_set = {t.evaluated_params["test_variable"] for t in trials}
-        self.assertEquals(len(parameter_set), 3)
+        self.assertEqual(len(parameter_set), 3)

     def testOddRepeat(self):
         trials = self._test_repeater(num_samples=11, repeat=5)
-        self.assertEquals(len(trials), 11)
+        self.assertEqual(len(trials), 11)
         parameter_set = {t.evaluated_params["test_variable"] for t in trials}
-        self.assertEquals(len(parameter_set), 3)
+        self.assertEqual(len(parameter_set), 3)

     def testSetGetRepeater(self):
-        ray.init(num_cpus=4)
-
         class TestSuggestion(Searcher):
             def __init__(self, index):
                 self.index = index
@@ -729,8 +731,6 @@ class SearchAlgorithmTest(unittest.TestCase):
         assert new_repeater.searcher.returned_result[-1] == {"result": 3}

     def testSetGetLimiter(self):
-        ray.init(num_cpus=4)
-
         class TestSuggestion(Searcher):
             def __init__(self, index):
                 self.index = index
@@ -761,8 +761,6 @@ class SearchAlgorithmTest(unittest.TestCase):
         assert limiter2.suggest("test_3")["score"] == 3

     def testBatchLimiter(self):
-        ray.init(num_cpus=4)
-
         class TestSuggestion(Searcher):
             def __init__(self, index):
                 self.index = index
@@ -841,7 +839,7 @@ class ResourcesTest(unittest.TestCase):
         original = Resources(1, 0, 0, 1, custom_resources={"a": 1, "b": 2})
         jsoned = resources_to_json(original)
         new_resource = json_to_resources(jsoned)
-        self.assertEquals(original, new_resource)
+        self.assertEqual(original, new_resource)


 if __name__ == "__main__":
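Most of the remaining hunks in this file simply swap unittest's assertEquals for assertEqual. The two behave identically; assertEquals is only a long-deprecated alias (and was removed outright in Python 3.12), so keeping it produces DeprecationWarnings. A small standalone check of that deprecation, illustrative only and not part of the diff (it needs an interpreter older than 3.12, where the alias still exists):

import unittest
import warnings


class DeprecatedAliasCheck(unittest.TestCase):
    def test_assertEquals_is_deprecated_alias(self):
        # The alias delegates to assertEqual but emits a DeprecationWarning.
        # Note: the alias was removed in Python 3.12, so run this on 3.11 or
        # earlier.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            self.assertEquals(1 + 1, 2)
        self.assertTrue(
            any(issubclass(w.category, DeprecationWarning) for w in caught))


if __name__ == "__main__":
    unittest.main()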