[tune] fix pbt flakey test (#12418)

This commit is contained in:
Richard Liaw 2020-11-25 16:58:37 -08:00 committed by GitHub
parent f6a5b733d5
commit 323941c745
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 6 deletions

View file

@ -1,5 +1,6 @@
# coding: utf-8
import heapq
import gc
import logging
from ray.tune.result import TRAINING_ITERATION
@ -104,6 +105,13 @@ class CheckpointManager:
def newest_memory_checkpoint(self):
return self._newest_memory_checkpoint
def replace_newest_memory_checkpoint(self, new_checkpoint):
# Forcibly remove the memory checkpoint
del self._newest_memory_checkpoint
# Apparently avoids memory leaks on k8s/k3s/pods
gc.collect()
self._newest_memory_checkpoint = new_checkpoint
def on_checkpoint(self, checkpoint):
"""Starts tracking checkpoint metadata on checkpoint.
@ -115,9 +123,7 @@ class CheckpointManager:
checkpoint (Checkpoint): Trial state checkpoint.
"""
if checkpoint.storage == Checkpoint.MEMORY:
# Forcibly remove the memory checkpoint
del self._newest_memory_checkpoint
self._newest_memory_checkpoint = checkpoint
self.replace_newest_memory_checkpoint(checkpoint)
return
old_checkpoint = self.newest_persistent_checkpoint

View file

@ -12,6 +12,8 @@ from ray.tune import Trainable
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.schedulers import PopulationBasedTraining
MB = 1024**2
class MockParam(object):
def __init__(self, params):
@ -26,7 +28,7 @@ class MockParam(object):
class PopulationBasedTrainingMemoryTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1)
ray.init(num_cpus=1, object_store_memory=100 * MB)
def tearDown(self):
ray.shutdown()
@ -36,7 +38,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
def setup(self, config):
# Make sure this is large enough so ray uses object store
# instead of in-process store.
self.large_object = random.getrandbits(int(10e7))
self.large_object = random.getrandbits(int(10e6))
self.iter = 0
self.a = config["a"]
@ -58,7 +60,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
class CustomExecutor(RayTrialExecutor):
def save(self, *args, **kwargs):
checkpoint = super(CustomExecutor, self).save(*args, **kwargs)
assert len(ray.objects()) <= 10
assert len(ray.objects()) <= 12
return checkpoint
param_a = MockParam([1, -1])