ray/rllib/execution/segment_tree.py
Balaji Veeramani 7f1bacc7dc
[CI] Format Python code with Black (#21975)
See #21316 and #21311 for the motivation behind these changes.
2022-01-29 18:41:57 -08:00

212 lines
7.7 KiB
Python

import operator
from typing import Any, Optional
class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as regular array, but with two important differences:

      a) Setting an item's value is slightly slower. It is O(lg capacity),
         instead of O(1).
      b) Offers an efficient `reduce` operation which reduces the tree's
         values over some specified contiguous subsequence of items in the
         array. Operation could be e.g. min/max/sum.

    The data is stored in a list, where the length is 2 * capacity.
    The second half of the list stores the actual values for each index, so if
    capacity=8, values are stored at indices 8 to 15. The first half of the
    array contains the reduced-values of the different (binary divided)
    segments, e.g. (capacity=4):

        0=not used
        1=reduced-value over all elements (array indices 4 to 7).
        2=reduced-value over array indices (4 and 5).
        3=reduced-value over array indices (6 and 7).
        4-7: values of the tree.

    NOTE that the values of the tree are accessed by indices starting at 0, so
    `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(
        self, capacity: int, operation: Any, neutral_element: Optional[Any] = None
    ):
        """Initializes a Segment Tree object.

        Args:
            capacity: Total size of the array - must be a power of two.
            operation: Binary callable (obj, obj) -> obj used to combine
                elements (e.g. operator.add, max, min). It should be
                associative and commutative (`reduce` combines segments out
                of left-to-right order), with `neutral_element` as identity.
            neutral_element: The neutral element for `operation`. Use None to
                pick one automatically: operator.add -> 0.0,
                max -> float("-inf"), any other operation (e.g. min) ->
                float("inf").
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "Capacity must be positive and a power of 2!"
        self.capacity = capacity
        if neutral_element is None:
            if operation is operator.add:
                neutral_element = 0.0
            elif operation is max:
                neutral_element = float("-inf")
            else:
                neutral_element = float("inf")
        self.neutral_element = neutral_element
        # Index 0 is unused, [1, capacity) are internal nodes and
        # [capacity, 2 * capacity) are the leaves (the actual values).
        self.value = [self.neutral_element for _ in range(2 * capacity)]
        self.operation = operation

    def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Applies `self.operation` to a contiguous subsequence of our values.

        The subsequence includes `start` and excludes `end`, i.e. it combines
        arr[start], arr[start + 1], ..., arr[end - 1] with `self.operation`.
        A negative `start` or `end` is offset once by `self.capacity`.

        Args:
            start: Start index to apply reduction to (included).
            end: End index to apply reduction to (excluded). None means
                `self.capacity`.

        Returns:
            The result of reducing self.operation over the specified range of
            `self.value` elements; `self.neutral_element` if the range is
            empty.
        """
        if end is None:
            end = self.capacity
        elif end < 0:
            end += self.capacity
        # Wrap a negative start the same way as a negative end. Otherwise
        # `start + capacity` would point into the internal-node half of
        # `self.value` and sum up already-reduced node values.
        if start < 0:
            start += self.capacity

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        start += self.capacity
        end += self.capacity

        # Example (capacity=4, operation=sum):
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3
        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value[6]=2: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller than end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value[2]=1: result=3
        #    NOTE: value[2] holds the sum of indices 4 and 5.
        # 7) int-divide: start=1, end=1 -> loop ends.
        while start < end:
            # If start is odd, it is a right child whose parent also covers
            # the element before it: add its own value and step right.
            if start & 1:
                result = self.operation(result, self.value[start])
                start += 1
            # If end is odd, the (excluded) end sits right after a right
            # child: step back onto it and add it. This keeps `end` excluded.
            if end & 1:
                end -= 1
                result = self.operation(result, self.value[end])
            # Move both bounds one level up into the parent index space.
            start //= 2
            end //= 2
        return result

    def __setitem__(self, idx: int, val: float) -> None:
        """Inserts/overwrites a value in/into the tree.

        Args:
            idx: The index to insert to. Must be in [0, `self.capacity`).
            val: The value to insert.
        """
        assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"
        # Leaves live in the "second half"; the first half holds reductions.
        idx += self.capacity
        self.value[idx] = val
        # Recalculate all affected reduction values on the path to the root.
        idx >>= 1
        while idx >= 1:
            left = 2 * idx
            self.value[idx] = self.operation(self.value[left], self.value[left + 1])
            idx >>= 1

    def __getitem__(self, idx: int) -> Any:
        """Returns the value stored at `idx` (must be in [0, capacity))."""
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]

    def get_state(self):
        """Returns the internal value list (not a copy)."""
        return self.value

    def set_state(self, state):
        """Replaces the internal value list with `state` (stored by reference).

        `state` must have length 2 * capacity, e.g. as returned by
        `get_state()`.
        """
        assert len(state) == self.capacity * 2
        self.value = state
class SumSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=operator.add."""

    def __init__(self, capacity: int):
        super().__init__(capacity=capacity, operation=operator.add)

    def sum(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns the sum over a sub-segment of the tree.

        Args:
            start: First index of the segment (included).
            end: End index of the segment (excluded). None means the end of
                the array.
        """
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        """Finds highest i, for which: sum(arr[0] + ... + arr[i - 1]) <= prefixsum.

        Descends from the root: go left if the left subtree's sum exceeds the
        remaining `prefixsum`, otherwise subtract that sum and go right.

        NOTE(review): assuming non-negative values, a `prefixsum` at or above
        the total sum always descends right and yields capacity - 1, which may
        be a slot holding 0.

        Args:
            prefixsum: `prefixsum` upper bound in above constraint.

        Returns:
            Largest possible index (i) satisfying above constraint.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        # Start at the global sum node (root).
        idx = 1
        # While idx is a non-leaf (first half of the tree).
        while idx < self.capacity:
            left = 2 * idx
            if self.value[left] > prefixsum:
                idx = left
            else:
                prefixsum -= self.value[left]
                idx = left + 1
        return idx - self.capacity
class MinSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=min."""

    def __init__(self, capacity: int):
        super().__init__(capacity=capacity, operation=min)

    def min(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Returns min(arr[start], ..., arr[end - 1]); `end` is excluded.

        Args:
            start: First index of the segment (included).
            end: End index of the segment (excluded). None means the end of
                the array.
        """
        return self.reduce(start, end)