mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] fix pbt flaky test (#12418)
This commit is contained in:
parent
f6a5b733d5
commit
323941c745
2 changed files with 14 additions and 6 deletions
|
@ -1,5 +1,6 @@
|
|||
# coding: utf-8
|
||||
import heapq
|
||||
import gc
|
||||
import logging
|
||||
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
|
@ -104,6 +105,13 @@ class CheckpointManager:
|
|||
def newest_memory_checkpoint(self):
    """Returns the most recently stored in-memory checkpoint.

    Returns:
        The checkpoint last assigned to ``self._newest_memory_checkpoint``
        (e.g. via ``replace_newest_memory_checkpoint``).
    """
    return self._newest_memory_checkpoint
|
||||
|
||||
def replace_newest_memory_checkpoint(self, new_checkpoint):
    """Replaces the tracked in-memory checkpoint with ``new_checkpoint``.

    The old checkpoint reference is explicitly deleted and a garbage
    collection pass is forced *before* the new checkpoint is stored, so
    the old object's memory is reclaimed immediately rather than at some
    later, nondeterministic collection.

    Args:
        new_checkpoint: Checkpoint to store as the newest memory
            checkpoint. NOTE(review): presumably a ``Checkpoint`` with
            ``storage == Checkpoint.MEMORY`` — confirm against callers.
    """
    # Forcibly remove the memory checkpoint
    del self._newest_memory_checkpoint
    # Apparently avoids memory leaks on k8s/k3s/pods
    gc.collect()
    self._newest_memory_checkpoint = new_checkpoint
|
||||
|
||||
def on_checkpoint(self, checkpoint):
|
||||
"""Starts tracking checkpoint metadata on checkpoint.
|
||||
|
||||
|
@ -115,9 +123,7 @@ class CheckpointManager:
|
|||
checkpoint (Checkpoint): Trial state checkpoint.
|
||||
"""
|
||||
if checkpoint.storage == Checkpoint.MEMORY:
|
||||
# Forcibly remove the memory checkpoint
|
||||
del self._newest_memory_checkpoint
|
||||
self._newest_memory_checkpoint = checkpoint
|
||||
self.replace_newest_memory_checkpoint(checkpoint)
|
||||
return
|
||||
|
||||
old_checkpoint = self.newest_persistent_checkpoint
|
||||
|
|
|
@ -12,6 +12,8 @@ from ray.tune import Trainable
|
|||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
MB = 1024**2
|
||||
|
||||
|
||||
class MockParam(object):
|
||||
def __init__(self, params):
|
||||
|
@ -26,7 +28,7 @@ class MockParam(object):
|
|||
|
||||
class PopulationBasedTrainingMemoryTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=1)
|
||||
ray.init(num_cpus=1, object_store_memory=100 * MB)
|
||||
|
||||
def tearDown(self):
    """Shuts down Ray after each test so cluster state does not leak
    between test cases."""
    ray.shutdown()
|
||||
|
@ -36,7 +38,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
|
|||
def setup(self, config):
|
||||
# Make sure this is large enough so ray uses object store
|
||||
# instead of in-process store.
|
||||
self.large_object = random.getrandbits(int(10e7))
|
||||
self.large_object = random.getrandbits(int(10e6))
|
||||
self.iter = 0
|
||||
self.a = config["a"]
|
||||
|
||||
|
@ -58,7 +60,7 @@ class PopulationBasedTrainingMemoryTest(unittest.TestCase):
|
|||
class CustomExecutor(RayTrialExecutor):
|
||||
def save(self, *args, **kwargs):
|
||||
checkpoint = super(CustomExecutor, self).save(*args, **kwargs)
|
||||
assert len(ray.objects()) <= 10
|
||||
assert len(ray.objects()) <= 12
|
||||
return checkpoint
|
||||
|
||||
param_a = MockParam([1, -1])
|
||||
|
|
Loading…
Add table
Reference in a new issue