[RLlib] Fix bugs and speed up SegmentTree

This commit is contained in:
Sven Mika
2020-03-13 09:03:07 +01:00
committed by GitHub
parent 6022eb53c4
commit 552cfb37ea
7 changed files with 526 additions and 110 deletions
+19 -15
View File
@@ -852,15 +852,26 @@ py_test(
# Tag: optimizers
# --------------------------------------------------------------------
# This has bugs: See PR https://github.com/ray-project/ray/pull/7534
# which fixes these and re-adds this test.
py_test(
name = "test_optimizers",
tags = ["optimizers"],
size = "large",
srcs = ["optimizers/tests/test_optimizers.py"]
)
# py_test(
# name = "test_segment_tree",
# tags = ["optimizers"],
# size = "small",
# srcs = ["optimizers/tests/test_segment_tree.py"]
# )
py_test(
name = "test_segment_tree",
tags = ["optimizers"],
size = "small",
srcs = ["optimizers/tests/test_segment_tree.py"]
)
py_test(
name = "test_prioritized_replay_buffer",
tags = ["optimizers"],
size = "small",
srcs = ["optimizers/tests/test_prioritized_replay_buffer.py"]
)
# --------------------------------------------------------------------
# Policies
@@ -1044,13 +1055,6 @@ py_test(
srcs = ["tests/test_nested_spaces.py"]
)
py_test(
name = "tests/test_optimizers",
tags = ["tests_dir", "tests_dir_O"],
size = "large",
srcs = ["tests/test_optimizers.py"]
)
py_test(
name = "tests/test_exec_api",
tags = ["tests_dir", "tests_dir_E"],
+1 -1
View File
@@ -227,7 +227,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
Array of shape (batch_size,) and dtype np.int32
indexes in buffer of sampled experiences
"""
assert beta > 0
assert beta >= 0.0
self._num_sampled += batch_size
idxes = self._sample_proportional(batch_size)
+149 -94
View File
@@ -2,140 +2,195 @@ import operator
class SegmentTree:
def __init__(self, capacity, operation, neutral_element):
"""Build a Segment Tree data structure.
"""A Segment Tree data structure.
https://en.wikipedia.org/wiki/Segment_tree
https://en.wikipedia.org/wiki/Segment_tree
Can be used as regular array, but with two
important differences:
Can be used as regular array, but with two important differences:
a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce`
operation which reduces `operation` over
a contiguous subsequence of items in the
array.
a) Setting an item's value is slightly slower. It is O(lg capacity),
instead of O(1).
b) Offers efficient `reduce` operation which reduces the tree's values
over some specified contiguous subsequence of items in the array.
Operation could be e.g. min/max/sum.
Paramters
---------
capacity: int
Total size of the array - must be a power of two.
operation: lambda obj, obj -> obj
and operation for combining elements (eg. sum, max)
must for a mathematical group together with the set of
possible values for array elements.
neutral_element: obj
neutral element for the operation above. eg. float('-inf')
for max and 0 for sum.
The data is stored in a list, where the length is 2 * capacity.
The second half of the list stores the actual values for each index, so if
capacity=8, values are stored at indices 8 to 15. The first half of the
array contains the reduced-values of the different (binary divided)
segments, e.g. (capacity=4):
0=not used
1=reduced-value over all elements (array indices 4 to 7).
2=reduced-value over array indices (4 and 5).
3=reduced-value over array indices (6 and 7).
4-7: values of the tree.
NOTE that the values of the tree are accessed by indices starting at 0, so
`tree[0]` accesses `internal_array[4]` in the above example.
"""
def __init__(self, capacity, operation, neutral_element=None):
"""Initializes a Segment Tree object.
Args:
capacity (int): Total size of the array - must be a power of two.
operation (operation): Lambda obj, obj -> obj
The operation for combining elements (eg. sum, max).
Must be a mathematical group together with the set of
possible values for array elements.
neutral_element (Optional[obj]): The neutral element for
`operation`. Use None for automatically finding a value:
max: float("-inf"), min: float("inf"), sum: 0.0.
"""
assert capacity > 0 and capacity & (capacity - 1) == 0, \
"capacity must be positive and a power of 2."
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
return self._value[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._reduce_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
node_end)
else:
return self._operation(
self._reduce_helper(start, mid, 2 * node, node_start, mid),
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
node_end))
"Capacity must be positive and a power of 2!"
self.capacity = capacity
if neutral_element is None:
neutral_element = 0.0 if operation is operator.add else \
float("-inf") if operation is max else float("inf")
self.neutral_element = neutral_element
self.value = [self.neutral_element for _ in range(2 * capacity)]
self.operation = operation
def reduce(self, start=0, end=None):
"""Returns result of applying `self.operation`
to a contiguous subsequence of the array.
"""Applies `self.operation` to subsequence of our values.
Subsequence is contiguous, includes `start` and excludes `end`.
self.operation(
arr[start], operation(arr[start+1], operation(... arr[end])))
Parameters
----------
start: int
beginning of the subsequence
end: int
end of the subsequences
Args:
start (int): Start index to apply reduction to.
end (Optional[int]): End index to apply reduction to (excluded).
Returns
-------
reduced: obj
result of reducing self.operation over the specified range of array
elements.
Returns:
any: The result of reducing self.operation over the specified
range of `self._value` elements.
"""
if end is None:
end = self._capacity - 1
if end < 0:
end += self._capacity
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
end = self.capacity
elif end < 0:
end += self.capacity
# Init result with neutral element.
result = self.neutral_element
# Map start/end to our actual index space (second half of array).
start += self.capacity
end += self.capacity
# Example:
# internal-array (first half=sums, second half=actual values):
# 0 1 2 3 | 4 5 6 7
# - 6 1 5 | 1 0 2 3
# tree.sum(0, 3) = 3
# internally: start=4, end=7 -> sum values 1 0 2 = 3.
# Iterate over tree starting in the actual-values (second half)
# section.
# 1) start=4 is even -> do nothing.
# 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
# 3) int-divide start and end by 2: start=2, end=3
# 4) start still smaller end -> iterate once more.
# 5) start=2 is even -> do nothing.
# 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=1
# NOTE: This adds the sum of indices 4 and 5 to the result.
# Iterate as long as start != end.
while start < end:
# If start is odd: Add its value to result and move start to
# next even value.
if start & 1:
result = self.operation(result, self.value[start])
start += 1
# If end is odd: Move end to previous even value, then add its
# value to result. NOTE: This takes care of excluding `end` in any
# situation.
if end & 1:
end -= 1
result = self.operation(result, self.value[end])
# Divide both start and end by 2 to make them "jump" into the
# next upper level reduce-index space.
start //= 2
end //= 2
# Then repeat till start == end.
return result
def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
"""
Inserts/overwrites a value in/into the tree.
Args:
idx (int): The index to insert to. Must be in [0, `self.capacity`[
val (float): The value to insert.
"""
assert 0 <= idx < self.capacity
# Index of the leaf to insert into (always insert in "second half"
# of the tree, the first half is reserved for already calculated
# reduction-values).
idx += self.capacity
self.value[idx] = val
# Recalculate all affected reduction values (in "first half" of tree).
idx = idx >> 1 # Divide by 2 (faster than division).
while idx >= 1:
self._value[idx] = self._operation(self._value[2 * idx],
self._value[2 * idx + 1])
idx //= 2
update_idx = 2 * idx # calculate only once
# Update the reduction value at the correct "first half" idx.
self.value[idx] = self.operation(self.value[update_idx],
self.value[update_idx + 1])
idx = idx >> 1 # Divide by 2 (faster than division).
def __getitem__(self, idx):
assert 0 <= idx < self._capacity
return self._value[self._capacity + idx]
assert 0 <= idx < self.capacity
return self.value[idx + self.capacity]
class SumSegmentTree(SegmentTree):
"""A SegmentTree with the reduction `operation`=operator.add."""
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity, operation=operator.add, neutral_element=0.0)
capacity=capacity, operation=operator.add)
def sum(self, start=0, end=None):
"""Returns arr[start] + ... + arr[end]"""
return super(SumSegmentTree, self).reduce(start, end)
"""Returns the sum over a sub-segment of the tree."""
return self.reduce(start, end)
def find_prefixsum_idx(self, prefixsum):
"""Find the highest index `i` in the array such that
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
"""Finds highest i, for which: sum(arr[0]+..+arr[i - 1]) <= prefixsum.
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
Args:
prefixsum (float): `prefixsum` upper bound in above constraint.
Parameters
----------
perfixsum: float
upperbound on the sum of array prefix
Returns
-------
idx: int
highest index satisfying the prefixsum constraint
Returns:
int: Largest possible index (i) satisfying above constraint.
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
# Global sum node.
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
# While non-leaf (first half of tree).
while idx < self.capacity:
update_idx = 2 * idx
if self.value[update_idx] > prefixsum:
idx = update_idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
return idx - self._capacity
prefixsum -= self.value[update_idx]
idx = update_idx + 1
return idx - self.capacity
class MinSegmentTree(SegmentTree):
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity, operation=min, neutral_element=float("inf"))
super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)
def min(self, start=0, end=None):
"""Returns min(arr[start], ..., arr[end])"""
return super(MinSegmentTree, self).reduce(start, end)
return self.reduce(start, end)
+143
View File
@@ -0,0 +1,143 @@
import operator
class OldSegmentTree(object):
    """Recursive, array-backed segment tree (baseline implementation).

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used like a regular array, with two important differences:

        a) Setting an item's value is slower: O(lg capacity) instead of O(1).
        b) `reduce` efficiently applies `operation` over a contiguous
           subsequence of items.

    Storage is one flat list of length 2 * capacity. Leaf `i` lives at
    position `capacity + i`; position `n` (1 <= n < capacity) holds
    `operation` applied to positions 2n and 2n + 1. Position 0 is unused.
    """

    def __init__(self, capacity, operation, neutral_element):
        """Builds an empty tree in which every slot holds `neutral_element`.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a positive power of two.
        operation: lambda obj, obj -> obj
            Operation for combining elements (e.g. sum, max). It should be
            associative, because partial results are grouped by subtree.
        neutral_element: obj
            Neutral element for the operation above, e.g. float('-inf')
            for max and 0 for sum. Also returned by `reduce` for an empty
            range.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "capacity must be positive and a power of 2."
        self._capacity = capacity
        self._neutral_element = neutral_element
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Reduces the inclusive leaf range [start, end] below `node`.

        `node` covers the inclusive leaf range [node_start, node_end]; the
        requested range must be non-empty and lie inside it, otherwise the
        recursion never hits its base case.
        """
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Requested range lies entirely in the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        if mid + 1 <= start:
            # Requested range lies entirely in the right child.
            return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
                                       node_end)
        # Range straddles both children: split at `mid` and combine.
        return self._operation(
            self._reduce_helper(start, mid, 2 * node, node_start, mid),
            self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
                                node_end))

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(
                arr[start], operation(arr[start+1], operation(... arr[end-1])))

        Parameters
        ----------
        start: int
            Beginning of the subsequence (inclusive).
        end: int
            End of the subsequence (exclusive). None means "up to capacity";
            a negative value is counted from the end of the array.

        Returns
        -------
        reduced: obj
            Result of reducing self.operation over the specified range of
            array elements; the neutral element if the range is empty.

        Raises
        ------
        IndexError
            If the (non-empty) range reaches outside [0, capacity).
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        # The recursive helper works on inclusive bounds.
        end -= 1
        if start > end:
            # Empty range: the helper would recurse forever on it.
            return self._neutral_element
        if start < 0 or end >= self._capacity:
            raise IndexError(
                "reduce range [{}, {}) is outside [0, {})".format(
                    start, end + 1, self._capacity))
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """Stores `val` at `idx` and refreshes all reductions above it.

        Raises
        ------
        IndexError
            If `idx` is not in [0, capacity). Without this check a negative
            index would silently overwrite an internal reduction node.
        """
        if not 0 <= idx < self._capacity:
            raise IndexError(
                "index {} is outside [0, {})".format(idx, self._capacity))
        # Position of the leaf in the flat list.
        node = idx + self._capacity
        self._value[node] = val
        # Walk up to the root, recomputing every ancestor from its children.
        node //= 2
        while node >= 1:
            self._value[node] = self._operation(self._value[2 * node],
                                                self._value[2 * node + 1])
            node //= 2

    def __getitem__(self, idx):
        """Returns the value stored at `idx` (must be in [0, capacity))."""
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
class OldSumSegmentTree(OldSegmentTree):
    """An `OldSegmentTree` whose reduction is addition (neutral value 0.0)."""

    def __init__(self, capacity):
        super(OldSumSegmentTree, self).__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end - 1] (`end` is exclusive)."""
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum

        If the array values are probabilities, this allows sampling indexes
        according to the discrete distribution efficiently.

        Parameters
        ----------
        prefixsum: float
            Upper bound on the sum of the array prefix.

        Returns
        -------
        idx: int
            Highest index satisfying the prefixsum constraint.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        # Start at the root and descend until a leaf position is reached.
        node = 1
        while node < self._capacity:
            left = 2 * node
            left_total = self._value[left]
            if left_total > prefixsum:
                # The target falls inside the left subtree.
                node = left
            else:
                # Skip the whole left subtree and continue on the right.
                prefixsum -= left_total
                node = left + 1
        return node - self._capacity
class OldMinSegmentTree(OldSegmentTree):
    """An `OldSegmentTree` whose reduction is `min` (neutral value +inf)."""

    def __init__(self, capacity):
        super(OldMinSegmentTree, self).__init__(capacity, min, float("inf"))

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end - 1]) (`end` is exclusive)."""
        return self.reduce(start, end)
@@ -0,0 +1,180 @@
from collections import Counter
import numpy as np
import unittest
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
from ray.rllib.utils.test_utils import check
class TestPrioritizedReplayBuffer(unittest.TestCase):
    """
    Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer.
    """

    capacity = 10
    alpha = 1.0
    beta = 1.0
    max_priority = 1.0

    def _generate_data(self):
        """Returns a random (obs_t, action, reward, obs_tp1, done) tuple."""
        return (
            np.random.random((4, )),  # obs_t
            np.random.choice([0, 1]),  # action
            np.random.rand(),  # reward
            np.random.random((4, )),  # obs_tp1
            np.random.choice([False, True]),  # done
        )

    def _assert_index_sum_between(self, indices, low, high):
        """Asserts low < sum(indices) < high, reporting the actual sum."""
        total = np.sum(indices)
        self.assertGreater(total, low)
        self.assertLess(total, high)

    def test_add(self):
        """Checks size and write cursor while filling past capacity."""
        memory = PrioritizedReplayBuffer(
            size=2,
            alpha=self.alpha,
        )

        # Assert indices 0 before insert.
        self.assertEqual(len(memory), 0)
        self.assertEqual(memory._next_idx, 0)

        # Insert single record.
        data = self._generate_data()
        memory.add(*data, weight=0.5)
        self.assertEqual(len(memory), 1)
        self.assertEqual(memory._next_idx, 1)

        # Insert single record (buffer is now full, cursor wraps to 0).
        data = self._generate_data()
        memory.add(*data, weight=0.1)
        self.assertEqual(len(memory), 2)
        self.assertEqual(memory._next_idx, 0)

        # Insert over capacity.
        data = self._generate_data()
        memory.add(*data, weight=1.0)
        self.assertEqual(len(memory), 2)
        self.assertEqual(memory._next_idx, 1)

    def test_update_priorities(self):
        """Checks that sampling frequencies follow updated priorities."""
        memory = PrioritizedReplayBuffer(size=self.capacity, alpha=self.alpha)

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(*data, weight=1.0)
            self.assertEqual(len(memory), i + 1)
            self.assertEqual(memory._next_idx, i + 1)

        # Fetch records, their indices and weights.
        _, _, _, _, _, weights, indices = \
            memory.sample(3, beta=self.beta)
        check(weights, np.ones(shape=(3, )))
        self.assertEqual(3, len(indices))
        self.assertEqual(len(memory), num_records)
        self.assertEqual(memory._next_idx, num_records)

        # Update weight of indices 0, 2, 3, 4 to very small.
        memory.update_priorities(
            np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01]))
        # Expect to sample almost only index 1
        # (which still has a weight of 1.0).
        for _ in range(10):
            _, _, _, _, _, _, indices = memory.sample(
                1000, beta=self.beta)
            self._assert_index_sum_between(indices, 970, 1100)

        # Update weight of indices 0 and 1 to >> 0.01.
        # Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s).
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(np.array([0, 1]), np.array([rand, rand]))
            _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
            # Expect biased to higher values due to some 2s, 3s, and 4s.
            self._assert_index_sum_between(indices, 400, 800)

        # Update weights to be 1:2.
        # Expect to sample double as often index 1 over index 0
        # plus very few times indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(
                np.array([0, 1]), np.array([rand, rand * 2]))
            _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
            self._assert_index_sum_between(indices, 600, 850)

        # Update weights to be 1:4.
        # Expect to sample quadruple as often index 1 over index 0
        # plus very few times indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(
                np.array([0, 1]), np.array([rand, rand * 4]))
            _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
            self._assert_index_sum_between(indices, 750, 950)

        # Update weights to be 1:9.
        # Expect to sample 9 times as often index 1 over index 0.
        # plus very few times indices 2, 3, or 4.
        for _ in range(10):
            rand = np.random.random() + 0.2
            memory.update_priorities(
                np.array([0, 1]), np.array([rand, rand * 9]))
            _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
            self._assert_index_sum_between(indices, 850, 1100)

        # Insert n more samples (fills the buffer up to its capacity).
        num_new_records = 5
        for i in range(num_new_records):
            data = self._generate_data()
            memory.add(*data, weight=1.0)
            expected_size = num_records + i + 1
            self.assertEqual(len(memory), expected_size)
            self.assertEqual(memory._next_idx, expected_size % self.capacity)

        # Give all 10 slots strictly increasing priorities (0.001 ... 512)
        # and draw several batches of 100 to 599 samples each.
        memory.update_priorities(
            np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
            np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.]))
        counts = Counter()
        for _ in range(10):
            _, _, _, _, _, _, indices = memory.sample(
                np.random.randint(100, 600), beta=self.beta)
            for i in indices:
                counts[i] += 1

        # Expect an approximately correct distribution of indices: a slot
        # with higher priority must not be sampled less often than one with
        # lower priority, i.e. the per-slot counts are already sorted.
        per_slot = [counts[i] for i in range(10)]
        self.assertEqual(per_slot, sorted(per_slot))

    def test_alpha_parameter(self):
        """Checks that a tiny alpha yields (almost) uniform sampling."""
        # Test sampling from a PR with a very small alpha (should behave just
        # like a regular ReplayBuffer).
        memory = PrioritizedReplayBuffer(size=self.capacity, alpha=0.01)

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(*data, weight=np.random.rand())
            self.assertEqual(len(memory), i + 1)
            self.assertEqual(memory._next_idx, i + 1)

        # Fetch records, their indices and weights.
        _, _, _, _, _, _, indices = \
            memory.sample(1000, beta=self.beta)
        counts = Counter()
        for i in indices:
            counts[i] += 1

        # Expect an approximately uniform distribution of indices. Check
        # every inserted slot explicitly so that a slot which was never
        # sampled (and hence is missing from `counts`) fails the test.
        for i in range(num_records):
            self.assertGreater(counts[i], 100, msg=str(counts))
            self.assertLess(counts[i], 300, msg=str(counts))
if __name__ == "__main__":
    # Allows running this test file directly as a script: pytest is invoked
    # verbosely on this file only and its return value becomes the process
    # exit code.
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
@@ -1,4 +1,5 @@
import numpy as np
import timeit
import unittest
from ray.rllib.optimizers.segment_tree import SumSegmentTree, MinSegmentTree
@@ -17,6 +18,7 @@ class TestSegmentTree(unittest.TestCase):
assert np.isclose(tree.sum(2, 3), 1.0)
assert np.isclose(tree.sum(2, -1), 1.0)
assert np.isclose(tree.sum(2, 4), 4.0)
assert np.isclose(tree.sum(2), 4.0)
def test_tree_set_overlap(self):
tree = SumSegmentTree(4)
@@ -28,6 +30,7 @@ class TestSegmentTree(unittest.TestCase):
assert np.isclose(tree.sum(2, 3), 3.0)
assert np.isclose(tree.sum(2, -1), 3.0)
assert np.isclose(tree.sum(2, 4), 3.0)
assert np.isclose(tree.sum(2), 3.0)
assert np.isclose(tree.sum(1, 2), 0.0)
def test_prefixsum_idx(self):
@@ -92,6 +95,37 @@ class TestSegmentTree(unittest.TestCase):
assert np.isclose(tree.min(2, -1), 4.0)
assert np.isclose(tree.min(3, 4), 3.0)
def test_microbenchmark_vs_old_version(self):
    """Checks that SumSegmentTree.sum() beats the old recursive version.

    Each tree is timed several times and only the fastest run is compared:
    slower runs mostly reflect interference from other processes, so a
    single sample would make this assertion flaky.

    Results from March 2020 (capacity=1048576):

    New tree:
    0.049599366000000256s
    results = timeit.timeit("tree.sum(5, 60000)",
        setup="from ray.rllib.optimizers.segment_tree import
        SumSegmentTree; tree = SumSegmentTree({})".format(capacity),
        number=10000)

    Old tree:
    0.13390400999999974s
    results = timeit.timeit("tree.sum(5, 60000)",
        setup="from ray.rllib.optimizers.tests.old_segment_tree import
        OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity),
        number=10000)
    """
    capacity = 2**20
    statement = "tree.sum(5, 60000)"

    def best_time(module, cls):
        # Fastest of a few repeats; the setup (import + tree construction)
        # runs once per repeat and is not part of the measured time.
        setup = "from {} import {}; tree = {}({})".format(
            module, cls, cls, capacity)
        return min(
            timeit.repeat(statement, setup=setup, repeat=3, number=10000))

    new = best_time("ray.rllib.optimizers.segment_tree", "SumSegmentTree")
    old = best_time("ray.rllib.optimizers.tests.old_segment_tree",
                    "OldSumSegmentTree")
    self.assertGreater(
        old, new, msg="old tree: {}s, new tree: {}s".format(old, new))
if __name__ == "__main__":
import pytest