diff --git a/python/ray/dag/py_obj_scanner.py b/python/ray/dag/py_obj_scanner.py index 6ecbe9b4e..20798b844 100644 --- a/python/ray/dag/py_obj_scanner.py +++ b/python/ray/dag/py_obj_scanner.py @@ -106,5 +106,10 @@ class _PyObjScanner(ray.cloudpickle.CloudPickler, Generic[SourceType, Transforme def _replace_index(self, i: int) -> SourceType: return self._replace_table[self._found[i]] + def clear(self): + """Clear the scanner from the _instances""" + if id(self) in _instances: + del _instances[id(self)] + def __del__(self): - del _instances[id(self)] + self.clear() diff --git a/python/ray/dag/tests/test_py_obj_scanner.py b/python/ray/dag/tests/test_py_obj_scanner.py index 00c96252e..e70bf0b22 100644 --- a/python/ray/dag/tests/test_py_obj_scanner.py +++ b/python/ray/dag/tests/test_py_obj_scanner.py @@ -1,4 +1,5 @@ -from ray.dag.py_obj_scanner import _PyObjScanner +from ray.dag.py_obj_scanner import _PyObjScanner, _instances +import pytest class Source: @@ -31,3 +32,35 @@ def test_not_serializing_objects(): replaced = scanner.replace_nodes({obj: 1 for obj in found}) assert replaced == [not_serializable, {"key": 1}] + + +def test_scanner_clear(): + """Test scanner clear to make the scanner GCable""" + prev_len = len(_instances) + + def call_find_nodes(): + scanner = _PyObjScanner(source_type=Source) + my_objs = [Source(), [Source(), {"key": Source()}]] + scanner.find_nodes(my_objs) + scanner.clear() + assert id(scanner) not in _instances + + call_find_nodes() + assert prev_len == len(_instances) + + def call_find_and_replace_nodes(): + scanner = _PyObjScanner(source_type=Source) + my_objs = [Source(), [Source(), {"key": Source()}]] + found = scanner.find_nodes(my_objs) + scanner.replace_nodes({obj: 1 for obj in found}) + scanner.clear() + assert id(scanner) not in _instances + + call_find_and_replace_nodes() + assert prev_len == len(_instances) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index afa360c1e..f0a1b1d06 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -56,6 +56,9 @@ class Query: replacement_table = dict(zip(tasks, resolved)) self.args, self.kwargs = scanner.replace_nodes(replacement_table) + # Make the scanner GCable to avoid memory leak + scanner.clear() + class ReplicaSet: """Data structure representing a set of replica actor handles"""