fix tracking :)

This commit is contained in:
Valentin Boettcher 2022-11-30 13:26:37 -05:00
parent e8a39d5b03
commit 9f816d8a1b
No known key found for this signature in database
GPG key ID: E034E12B7AF56ACE

View file

@ -651,14 +651,14 @@ def integrate_array(
class WelfordAggregator:
__slots__ = ["n", "mean", "_m_2", "_tracker"]
def __init__(self, first_value: np.ndarray, track=False):
def __init__(self, first_value: np.ndarray, i: Optional[int] = None):
self.n = 1
self.mean = first_value
self._m_2 = np.zeros_like(first_value)
self._tracker: Optional[SortedList] = None
if track:
self._tracker = SortedList()
if i is not None:
self._tracker = SortedList([i])
def update(self, new_value: np.ndarray, i: Optional[int] = None):
if self._tracker is not None:
@ -669,8 +669,6 @@ class WelfordAggregator:
return
self._tracker.add(i)
with open("/tmp/out.txt", "a") as f:
f.write(f"adding {i}\n")
self.n += 1
delta = new_value - self.mean
@ -761,13 +759,13 @@ def _ensemble_remote_function(function, chunk: tuple, index: int):
def get_online_data_path(save: str):
return Path("results") / Path(f"online_{save}.npy")
return Path("results") / Path(f"online_{save}.pickle")
def load_online_cache(save: str):
path = get_online_data_path(save)
with path.open("rb") as agg_file:
with portalocker.Lock(path, "rb") as agg_file:
aggregate = pickle.load(agg_file)
return aggregate.ensemble_value
@ -796,7 +794,7 @@ def ensemble_mean_online(
if result is None:
raise RuntimeError("No cache and no result.")
aggregate = WelfordAggregator(result, i is not None)
aggregate = WelfordAggregator(result, i)
with portalocker.Lock(path, "wb") as agg_file:
pickle.dump(aggregate, agg_file)