mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
207 lines
7.7 KiB
Python
207 lines
7.7 KiB
Python
import operator
|
|
from typing import Any, Optional
|
|
|
|
|
|
class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as regular array, but with two important differences:

        a) Setting an item's value is slightly slower. It is O(lg capacity),
           instead of O(1).
        b) Offers efficient `reduce` operation which reduces the tree's values
           over some specified contiguous subsequence of items in the array.
           Operation could be e.g. min/max/sum.

    The data is stored in a list, where the length is 2 * capacity.
    The second half of the list stores the actual values for each index, so if
    capacity=8, values are stored at indices 8 to 15. The first half of the
    array contains the reduced-values of the different (binary divided)
    segments, e.g. (capacity=4):
        0=not used
        1=reduced-value over all elements (array indices 4 to 7).
        2=reduced-value over array indices (4 and 5).
        3=reduced-value over array indices (6 and 7).
        4-7: values of the tree.
    NOTE that the values of the tree are accessed by indices starting at 0, so
    `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(self,
                 capacity: int,
                 operation: Any,
                 neutral_element: Optional[Any] = None):
        """Initializes a Segment Tree object.

        Args:
            capacity (int): Total size of the array - must be a power of two.
            operation (operation): Lambda obj, obj -> obj
                The operation for combining elements (eg. sum, max).
                It must be associative and have `neutral_element` as its
                identity. Because `reduce` combines partial results out of
                index order, it should also be commutative (add, min and max
                all qualify).
            neutral_element (Optional[obj]): The neutral element for
                `operation`. Use None for automatically picking a value:
                `operator.add` -> 0.0, `max` -> float("-inf"), and any other
                operation (e.g. `min`) -> float("inf"). Pass it explicitly for
                any operation that is not add/min/max.
        """

        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "Capacity must be positive and a power of 2!"
        self.capacity = capacity
        if neutral_element is None:
            # Identity checks: only these exact callables are recognized.
            neutral_element = 0.0 if operation is operator.add else \
                float("-inf") if operation is max else float("inf")
        self.neutral_element = neutral_element
        # Index 0 is unused, 1..capacity-1 hold reduced values, and
        # capacity..2*capacity-1 hold the leaves (the actual values).
        self.value = [self.neutral_element for _ in range(2 * capacity)]
        self.operation = operation

    def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Applies `self.operation` to a contiguous subsequence of our values.

        The subsequence includes `start` and excludes `end`, i.e. the result
        is (up to reordering of the operands):

            self.operation(
                arr[start], operation(arr[start+1], operation(... arr[end-1])))

        Args:
            start (int): Start index to apply reduction to (included).
                Negative values count from the end of the array, like `end`.
            end (Optional[int]): End index to apply reduction to (excluded).
                None means `self.capacity`; negative values count from the
                end of the array.

        Returns:
            any: The result of reducing self.operation over the specified
                range of `self.value` elements. If the range is empty
                (start >= end), `self.neutral_element` is returned.
        """
        if end is None:
            end = self.capacity
        elif end < 0:
            end += self.capacity
        # Without this, a negative start would land in the internal-node
        # half of `self.value` after the offset below and yield garbage.
        if start < 0:
            start += self.capacity

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        start += self.capacity
        end += self.capacity

        # Example:
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3

        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.

        # Iterate over tree starting in the actual-values (second half)
        # section.
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value[2]=1 to result:
        #    result=3
        #    NOTE: This adds the sum of indices 4 and 5 to the result.
        # 7) int-divide: start=1, end=1 -> loop ends.

        # Iterate as long as start < end.
        while start < end:

            # If start is odd: Add its value to result and move start to
            # next even value.
            if start & 1:
                result = self.operation(result, self.value[start])
                start += 1

            # If end is odd: Move end to previous even value, then add its
            # value to result. NOTE: This takes care of excluding `end` in any
            # situation.
            if end & 1:
                end -= 1
                result = self.operation(result, self.value[end])

            # Divide both start and end by 2 to make them "jump" into the
            # next upper level reduce-index space.
            start //= 2
            end //= 2

            # Then repeat till start >= end.

        return result

    def __setitem__(self, idx: int, val: float) -> None:
        """Inserts/overwrites a value in/into the tree.

        Every reduced value on the path from the leaf up to the root is
        recomputed, so this costs O(lg capacity).

        Args:
            idx (int): The index to insert to. Must satisfy
                0 <= idx < `self.capacity`.
            val (float): The value to insert.
        """
        assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"

        # Index of the leaf to insert into (always insert in "second half"
        # of the tree, the first half is reserved for already calculated
        # reduction-values).
        idx += self.capacity
        self.value[idx] = val

        # Recalculate all affected reduction values (in "first half" of tree).
        idx = idx >> 1  # Divide by 2 (faster than division).
        while idx >= 1:
            update_idx = 2 * idx  # calculate only once
            # Update the reduction value at the correct "first half" idx.
            self.value[idx] = self.operation(self.value[update_idx],
                                             self.value[update_idx + 1])
            idx = idx >> 1  # Divide by 2 (faster than division).

    def __getitem__(self, idx: int) -> Any:
        """Returns the stored value at `idx` (0 <= idx < capacity)."""
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]

    def get_state(self):
        """Returns the internal `value` list itself (not a copy)."""
        return self.value

    def set_state(self, state):
        """Adopts `state` (a list of length 2 * capacity) as the internal list.

        The list is used as-is (not copied), so it must already contain
        consistent reduced values in its first half.
        """
        assert len(state) == self.capacity * 2
        self.value = state
|
|
|
|
|
class SumSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=operator.add."""

    def __init__(self, capacity: int):
        super(SumSegmentTree, self).__init__(
            capacity=capacity, operation=operator.add)

    def sum(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns the sum over a sub-segment of the tree.

        Args:
            start (int): First index of the segment (included).
            end (Optional[int]): End index of the segment (excluded);
                None means the end of the array.

        Returns:
            any: arr[start] + ... + arr[end - 1].
        """
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        """Finds highest i, for which: sum(arr[0]+..+arr[i - 1]) <= prefixsum.

        Walks from the root down to a leaf: whenever the left child's sum
        exceeds the remaining `prefixsum` we descend left, otherwise we
        subtract the left sum and descend right. Runs in O(lg capacity).

        Args:
            prefixsum (float): `prefixsum` upper bound in above constraint.
                Must lie in [0, total sum] (a tolerance of 1e-5 above the
                total is accepted).

        Returns:
            int: Largest possible index (i) satisfying above constraint.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        # Global sum node.
        idx = 1

        # While non-leaf (first half of tree).
        while idx < self.capacity:
            update_idx = 2 * idx  # Left child.
            if self.value[update_idx] > prefixsum:
                idx = update_idx
            else:
                # Skip the whole left subtree and search the right one.
                prefixsum -= self.value[update_idx]
                idx = update_idx + 1
        # Convert leaf position back to a user-facing index.
        return idx - self.capacity
|
|
|
|
|
|
class MinSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=min."""

    def __init__(self, capacity: int):
        super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)

    def min(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns min(arr[start], ..., arr[end - 1]).

        Args:
            start (int): First index of the segment (included).
            end (Optional[int]): End index of the segment (excluded);
                None means the end of the array.

        Returns:
            any: The minimum over the segment, or float("inf") if the
                segment is empty.
        """
        return self.reduce(start, end)
|