implement sampling of vector valued functions

This commit is contained in:
hiro98 2020-04-17 14:28:01 +02:00
parent ad236db35e
commit fe8633fb66

View file

@ -3,7 +3,9 @@ Simple monte carlo integration implementation.
Author: Valentin Boettcher <hiro@protagon.space>
"""
import numpy as np
from scipy.optimize import minimize_scalar, root
import inspect
import functools
from scipy.optimize import minimize_scalar, root, shgo
from dataclasses import dataclass
@ -106,6 +108,51 @@ def find_upper_bound(f, interval, **kwargs):
raise RuntimeError("Could not find an upper bound.")
def _negate(f):
"""A helper that multiplies the given function with -1."""
@functools.wraps(f)
def negated(*args, **kwargs):
return -f(*args, **kwargs)
return negated
def sample_unweighted_vector(
f, interval, seed=None, upper_bound=None, report_efficiency=False, num=None
):
dimension = len(interval)
interval = np.array([_process_interval(i) for i in interval])
if not upper_bound:
result = shgo(_negate(f), bounds=interval)
if not result.success:
raise RuntimeError("Could not find an upper bound.")
upper_bound = -result.fun + 0.1
def allocate_random_chunk():
return np.random.uniform(
[*interval[:, 0], 0], [*interval[:, 1], 1], [1, 1 + dimension],
)
total_points = 0
total_accepted = 0
while True:
points = allocate_random_chunk()
if report_efficiency:
total_points += 1
arg = points[:, 0:-1][0]
if f(arg) > points[:, -1] * upper_bound:
if report_efficiency:
total_accepted += 1
yield (arg, total_accepted / total_points,) if report_efficiency else arg
return
def sample_unweighted(
f,
interval,
@ -476,17 +523,29 @@ def sample_stratified(
def sample_unweighted_array(
num, *args, increment_borders=None, report_efficiency=False, **kwargs
num, f, interval, *args, increment_borders=None, report_efficiency=False, **kwargs
):
"""Sample `num` elements from a distribution. The rest of the
arguments is analogous to `sample_unweighted`.
"""
sample_arr = np.empty(num)
interval = np.array(interval)
vectorized = len(interval.shape) > 1
sample_arr = np.empty((num, interval.shape[0]) if vectorized else num)
samples = None
if len(interval.shape) > 1:
samples = sample_unweighted_vector(
f, interval, *args, report_efficiency=report_efficiency, **kwargs
)
else:
if "chunk_size" not in kwargs:
kwargs["chunk_size"] = num * 10
samples = (
sample_unweighted(*args, report_efficiency=report_efficiency, **kwargs)
sample_unweighted(
f, interval, *args, report_efficiency=report_efficiency, **kwargs
)
if increment_borders is None
else sample_stratified(
*args,