diff --git a/examples/pool_map.py b/examples/pool_map.py index 8be0597..257373e 100644 --- a/examples/pool_map.py +++ b/examples/pool_map.py @@ -14,23 +14,26 @@ sys.path.insert(0, dirname(dirname(abspath(__file__)))) import jobmanager -def func(x): +def func(x, y, z): """Example function with only one argument""" - time.sleep(x[0]/10) - return np.sum(x) + time.sleep(x/10) + return np.sum([x, y, z]) + +def wrapper(data): + return func(*data) # Create list of parameters a = list() for i in range(10): - a.append([i,2.34]) + a.append([i, 2.34, 9]) # mp.Pool example: p_mp = mp.Pool() -res_mp = p_mp.map(func, a) +res_mp = p_mp.map(wrapper, a) # equivalent to mp.Pool() but with progress bar: p_jm = jobmanager.decorators.Pool() -res_jm = p_jm.map(func, a) +res_jm = p_jm.map(wrapper, a) assert res_mp == res_jm print("result: ", res_jm) diff --git a/jobmanager/decorators.py b/jobmanager/decorators.py index 075f43b..b24bcaa 100644 --- a/jobmanager/decorators.py +++ b/jobmanager/decorators.py @@ -436,6 +436,7 @@ def decorate_module_ProgressBar(module, decorator=ProgressBar, **kwargs): decorator = ProgressBar kwargs.pop("override_count") + vdict = module.__dict__ for key in list(vdict.keys()): if hasattr(vdict[key], "__call__"): @@ -456,13 +457,13 @@ def decorate_module_ProgressBar(module, decorator=ProgressBar, **kwargs): print("Jobmanager wrapped {}.{}".format( module.__name__, key)) - elif vdict[key] == mp.Pool: + # Decorate Pool + if vdict[key] == mp.Pool: # replace mp.Pool - setattr(module, vdict[key], Pool) + setattr(module, key, Pool) elif isinstance(vdict[key], ModuleType): # replace mp.Pool in submodules subdict = vdict[key].__dict__ for skey in list(subdict.keys()): - if subdict[skey] == mp.pool.Pool: - setattr(vdict[key], subdict[skey], Pool) - \ No newline at end of file + if subdict[skey] == mp.Pool: + setattr(vdict[key], skey, Pool)