mirror of
https://github.com/vale981/SecondaryValue
synced 2025-03-04 16:41:38 -05:00
implement vectorization
This commit is contained in:
parent
f789e2b4d6
commit
c13a93b3fe
1 changed files with 72 additions and 17 deletions
|
@ -33,6 +33,9 @@ class SecondaryValue:
|
|||
self._parsed = sympify(self._expr, _clash) if isinstance(self._expr, str) \
|
||||
else self._expr
|
||||
|
||||
self._parsed_lambda = sympy.lambdify(self._parsed.free_symbols,
|
||||
self._expr)
|
||||
|
||||
self._symbols = {symbol.__str__() \
|
||||
for symbol in self._parsed.free_symbols}
|
||||
|
||||
|
@ -85,7 +88,7 @@ class SecondaryValue:
|
|||
|
||||
return kwargs
|
||||
|
||||
def _calculate(self, values, derivs, max_uncertainties, errors):
|
||||
def _calculate(self, values, derivs, error):
|
||||
"""Calculates a value from the expression by substituting
|
||||
variables by the values of the given keyword arguments. If an
|
||||
argument is specified as a tuplpe of (value, error) the
|
||||
|
@ -97,15 +100,13 @@ class SecondaryValue:
|
|||
a tuple the beforementioned as first element
|
||||
"""
|
||||
|
||||
# ugly, but works for now
|
||||
terms = [np.array([(derivs[var](**values) * err[i]) \
|
||||
for var, err in errors.items() \
|
||||
if len(err) > i and err[i] > 0],
|
||||
dtype=self._dtype) for i in range(1, max_uncertainties)]
|
||||
term = np.array([(derivs[var](**values) * err) \
|
||||
for var, err in error.items() \
|
||||
if err > 0], dtype=self._dtype)
|
||||
|
||||
terms = np.array([np.sqrt(t.dot(t)) for t in terms], dtype=self._dtype)
|
||||
term = np.sqrt(term.dot(term))
|
||||
|
||||
return np.insert(terms, 0, self._parsed.subs(values))
|
||||
return self._dtype(term)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calculates a value from the expression by substituting
|
||||
|
@ -137,24 +138,64 @@ class SecondaryValue:
|
|||
if isinstance(val, Iterable)] or [0])
|
||||
|
||||
# filter out the error values
|
||||
errors = {var: val for var, val in kwargs.items() \
|
||||
if isinstance(val, Iterable) and len(val) > 1}
|
||||
errors = [{var: val[i] for var, val in kwargs.items() \
|
||||
if isinstance(val, Iterable) and len(val) > i} \
|
||||
for i in range(1, max_uncertainties)]
|
||||
|
||||
if not errors:
|
||||
return self._dtype(self._parsed.subs(kwargs))
|
||||
|
||||
values = {var: (val[0] if isinstance(val, Iterable) else val) \
|
||||
for var, val in kwargs.items()}
|
||||
|
||||
# do the actual calulation
|
||||
terms = []
|
||||
scalar_values, vector_values = filter_out_vecotrized(values)
|
||||
value = 0
|
||||
value_length = length = max([len(elem) \
|
||||
for elem in vector_values.values()] or [0])
|
||||
if vector_values:
|
||||
value = np.empty(value_length)
|
||||
for i in range(0, value_length):
|
||||
current_values = {**scalar_values,
|
||||
**{key: val[i] \
|
||||
for key, val in vector_values.items()}}
|
||||
value[i] = self._parsed_lambda(**current_values)
|
||||
else:
|
||||
value = self._parsed_lambda(**values)
|
||||
|
||||
if not errors:
|
||||
return value
|
||||
|
||||
# get them cached
|
||||
derivs = self._get_derivatives(*list(errors.keys()))
|
||||
derivs = self._get_derivatives(*list(errors[0].keys()))
|
||||
|
||||
terms = self._calculate(values, derivs, max_uncertainties, errors)
|
||||
for error in errors:
|
||||
scalar_errors, vector_errors = filter_out_vecotrized(error)
|
||||
length = max([len(elem) for elem in (list(vector_values.values())
|
||||
+ list(vector_errors.values()))] or [0])
|
||||
if length == 0:
|
||||
terms.append(self._calculate(values,
|
||||
derivs, error))
|
||||
else:
|
||||
tmp = np.empty(length, dtype=self._dtype)
|
||||
for i in range(0, length):
|
||||
current_values = {**scalar_values,
|
||||
**{key: val[i] \
|
||||
for key, val in vector_values.items()}}
|
||||
|
||||
current_errors = {**scalar_errors,
|
||||
**{key: val[i] \
|
||||
for key, val in vector_errors.items()}}
|
||||
|
||||
tmp[i] = self._calculate(current_values,
|
||||
derivs, current_errors)
|
||||
terms.append(tmp)
|
||||
|
||||
|
||||
result = np.array([self._dtype(value)] + terms, dtype=self._dtype)
|
||||
if dep_values:
|
||||
return terms, dep_values
|
||||
return result, dep_values
|
||||
|
||||
return terms
|
||||
return result
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
|
@ -166,7 +207,8 @@ class SecondaryValue:
|
|||
for var in args:
|
||||
if var not in self._derivatives:
|
||||
self._derivatives[var] = \
|
||||
sympy.lambdify(args, diff(self._parsed, var))
|
||||
sympy.lambdify(self._parsed.free_symbols
|
||||
, diff(self._parsed, var))
|
||||
|
||||
return {var: self._derivatives[var] for var in args}
|
||||
|
||||
|
@ -193,3 +235,16 @@ class SecondaryValue:
|
|||
return {symbol: self._deps[symbol].get_symbols() \
|
||||
if symbol in self._deps else {} \
|
||||
for symbol in self._symbols}
|
||||
|
||||
|
||||
def filter_out_vecotrized(dictionary):
|
||||
scalar = dict()
|
||||
vector = dict()
|
||||
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, Iterable):
|
||||
vector[key] = value
|
||||
else:
|
||||
scalar[key] = value
|
||||
|
||||
return scalar, vector
|
||||
|
|
Loading…
Add table
Reference in a new issue