diff --git a/rllib/BUILD b/rllib/BUILD index f114184d8..fa22c2549 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -852,15 +852,26 @@ py_test( # Tag: optimizers # -------------------------------------------------------------------- -# This has bugs: See PR https://github.com/ray-project/ray/pull/7534 -# which fixes these and re-adds this test. +py_test( + name = "test_optimizers", + tags = ["optimizers"], + size = "large", + srcs = ["optimizers/tests/test_optimizers.py"] +) -# py_test( -# name = "test_segment_tree", -# tags = ["optimizers"], -# size = "small", -# srcs = ["optimizers/tests/test_segment_tree.py"] -# ) +py_test( + name = "test_segment_tree", + tags = ["optimizers"], + size = "small", + srcs = ["optimizers/tests/test_segment_tree.py"] +) + +py_test( + name = "test_prioritized_replay_buffer", + tags = ["optimizers"], + size = "small", + srcs = ["optimizers/tests/test_prioritized_replay_buffer.py"] +) # -------------------------------------------------------------------- # Policies @@ -1044,13 +1055,6 @@ py_test( srcs = ["tests/test_nested_spaces.py"] ) -py_test( - name = "tests/test_optimizers", - tags = ["tests_dir", "tests_dir_O"], - size = "large", - srcs = ["tests/test_optimizers.py"] -) - py_test( name = "tests/test_exec_api", tags = ["tests_dir", "tests_dir_E"], diff --git a/rllib/optimizers/replay_buffer.py b/rllib/optimizers/replay_buffer.py index 2f59c1d49..63a514794 100644 --- a/rllib/optimizers/replay_buffer.py +++ b/rllib/optimizers/replay_buffer.py @@ -227,7 +227,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): Array of shape (batch_size,) and dtype np.int32 idexes in buffer of sampled experiences """ - assert beta > 0 + assert beta >= 0.0 self._num_sampled += batch_size idxes = self._sample_proportional(batch_size) diff --git a/rllib/optimizers/segment_tree.py b/rllib/optimizers/segment_tree.py index ac7e06481..e436f3a5a 100644 --- a/rllib/optimizers/segment_tree.py +++ b/rllib/optimizers/segment_tree.py @@ -2,140 +2,195 @@ import operator class SegmentTree: - def __init__(self, capacity, operation, neutral_element): - """Build a Segment Tree data structure. + """A Segment Tree data structure. - https://en.wikipedia.org/wiki/Segment_tree + https://en.wikipedia.org/wiki/Segment_tree - Can be used as regular array, but with two - important differences: + Can be used as regular array, but with two important differences: - a) setting item's value is slightly slower. - It is O(lg capacity) instead of O(1). - b) user has access to an efficient `reduce` - operation which reduces `operation` over - a contiguous subsequence of items in the - array. + a) Setting an item's value is slightly slower. It is O(lg capacity), + instead of O(1). + b) Offers efficient `reduce` operation which reduces the tree's values + over some specified contiguous subsequence of items in the array. + Operation could be e.g. min/max/sum. - Paramters - --------- - capacity: int - Total size of the array - must be a power of two. - operation: lambda obj, obj -> obj - and operation for combining elements (eg. sum, max) - must for a mathematical group together with the set of - possible values for array elements. - neutral_element: obj - neutral element for the operation above. eg. float('-inf') - for max and 0 for sum. + The data is stored in a list, where the length is 2 * capacity. + The second half of the list stores the actual values for each index, so if + capacity=8, values are stored at indices 8 to 15. The first half of the + array contains the reduced-values of the different (binary divided) + segments, e.g. (capacity=4): + 0=not used + 1=reduced-value over all elements (array indices 4 to 7). + 2=reduced-value over array indices (4 and 5). + 3=reduced-value over array indices (6 and 7). + 4-7: values of the tree. + NOTE that the values of the tree are accessed by indices starting at 0, so + `tree[0]` accesses `internal_array[4]` in the above example. + """ + + def __init__(self, capacity, operation, neutral_element=None): + """Initializes a Segment Tree object. + + Args: + capacity (int): Total size of the array - must be a power of two. + operation (operation): Lambda obj, obj -> obj + The operation for combining elements (eg. sum, max). + Must be a mathematical group together with the set of + possible values for array elements. + neutral_element (Optional[obj]): The neutral element for + `operation`. Use None for automatically finding a value: + max: float("-inf"), min: float("inf"), sum: 0.0. """ assert capacity > 0 and capacity & (capacity - 1) == 0, \ - "capacity must be positive and a power of 2." - self._capacity = capacity - self._value = [neutral_element for _ in range(2 * capacity)] - self._operation = operation - - def _reduce_helper(self, start, end, node, node_start, node_end): - if start == node_start and end == node_end: - return self._value[node] - mid = (node_start + node_end) // 2 - if end <= mid: - return self._reduce_helper(start, end, 2 * node, node_start, mid) - else: - if mid + 1 <= start: - return self._reduce_helper(start, end, 2 * node + 1, mid + 1, - node_end) - else: - return self._operation( - self._reduce_helper(start, mid, 2 * node, node_start, mid), - self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, - node_end)) + "Capacity must be positive and a power of 2!" + self.capacity = capacity + if neutral_element is None: + neutral_element = 0.0 if operation is operator.add else \ + float("-inf") if operation is max else float("inf") + self.neutral_element = neutral_element + self.value = [self.neutral_element for _ in range(2 * capacity)] + self.operation = operation def reduce(self, start=0, end=None): - """Returns result of applying `self.operation` - to a contiguous subsequence of the array. + """Applies `self.operation` to subsequence of our values. + + Subsequence is contiguous, includes `start` and excludes `end`. self.operation( arr[start], operation(arr[start+1], operation(... arr[end]))) - Parameters - ---------- - start: int - beginning of the subsequence - end: int - end of the subsequences + Args: + start (int): Start index to apply reduction to. + end (Optional[int]): End index to apply reduction to (excluded). - Returns - ------- - reduced: obj - result of reducing self.operation over the specified range of array - elements. + Returns: + any: The result of reducing self.operation over the specified + range of `self._value` elements. """ if end is None: - end = self._capacity - 1 - if end < 0: - end += self._capacity - return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + end = self.capacity + elif end < 0: + end += self.capacity + + # Init result with neutral element. + result = self.neutral_element + # Map start/end to our actual index space (second half of array). + start += self.capacity + end += self.capacity + + # Example: + # internal-array (first half=sums, second half=actual values): + # 0 1 2 3 | 4 5 6 7 + # - 6 1 5 | 1 0 2 3 + + # tree.sum(0, 3) = 3 + # internally: start=4, end=7 -> sum values 1 0 2 = 3. + + # Iterate over tree starting in the actual-values (second half) + # section. + # 1) start=4 is even -> do nothing. + # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2 + # 3) int-divide start and end by 2: start=2, end=3 + # 4) start still smaller end -> iterate once more. + # 5) start=2 is even -> do nothing. + # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=1 + # NOTE: This adds the sum of indices 4 and 5 to the result. + + # Iterate as long as start != end. + while start < end: + + # If start is odd: Add its value to result and move start to + # next even value. + if start & 1: + result = self.operation(result, self.value[start]) + start += 1 + + # If end is odd: Move end to previous even value, then add its + # value to result. NOTE: This takes care of excluding `end` in any + # situation. + if end & 1: + end -= 1 + result = self.operation(result, self.value[end]) + + # Divide both start and end by 2 to make them "jump" into the + # next upper level reduce-index space. + start //= 2 + end //= 2 + + # Then repeat till start == end. + + return result def __setitem__(self, idx, val): - # index of the leaf - idx += self._capacity - self._value[idx] = val - idx //= 2 + """ + Inserts/overwrites a value in/into the tree. + + Args: + idx (int): The index to insert to. Must be in [0, `self.capacity`[ + val (float): The value to insert. + """ + assert 0 <= idx < self.capacity + + # Index of the leaf to insert into (always insert in "second half" + # of the tree, the first half is reserved for already calculated + # reduction-values). + idx += self.capacity + self.value[idx] = val + + # Recalculate all affected reduction values (in "first half" of tree). + idx = idx >> 1 # Divide by 2 (faster than division). while idx >= 1: - self._value[idx] = self._operation(self._value[2 * idx], - self._value[2 * idx + 1]) - idx //= 2 + update_idx = 2 * idx # calculate only once + # Update the reduction value at the correct "first half" idx. + self.value[idx] = self.operation(self.value[update_idx], + self.value[update_idx + 1]) + idx = idx >> 1 # Divide by 2 (faster than division). def __getitem__(self, idx): - assert 0 <= idx < self._capacity - return self._value[self._capacity + idx] + assert 0 <= idx < self.capacity + return self.value[idx + self.capacity] class SumSegmentTree(SegmentTree): + """A SegmentTree with the reduction `operation`=operator.add.""" + def __init__(self, capacity): super(SumSegmentTree, self).__init__( - capacity=capacity, operation=operator.add, neutral_element=0.0) + capacity=capacity, operation=operator.add) def sum(self, start=0, end=None): - """Returns arr[start] + ... + arr[end]""" - return super(SumSegmentTree, self).reduce(start, end) + """Returns the sum over a sub-segment of the tree.""" + return self.reduce(start, end) def find_prefixsum_idx(self, prefixsum): - """Find the highest index `i` in the array such that - sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum + """Finds highest i, for which: sum(arr[0]+..+arr[i - i]) <= prefixsum. - if array values are probabilities, this function - allows to sample indexes according to the discrete - probability efficiently. + Args: + prefixsum (float): `prefixsum` upper bound in above constraint. - Parameters - ---------- - perfixsum: float - upperbound on the sum of array prefix - - Returns - ------- - idx: int - highest index satisfying the prefixsum constraint + Returns: + int: Largest possible index (i) satisfying above constraint. """ assert 0 <= prefixsum <= self.sum() + 1e-5 + # Global sum node. idx = 1 - while idx < self._capacity: # while non-leaf - if self._value[2 * idx] > prefixsum: - idx = 2 * idx + + # While non-leaf (first half of tree). + while idx < self.capacity: + update_idx = 2 * idx + if self.value[update_idx] > prefixsum: + idx = update_idx else: - prefixsum -= self._value[2 * idx] - idx = 2 * idx + 1 - return idx - self._capacity + prefixsum -= self.value[update_idx] + idx = update_idx + 1 + return idx - self.capacity class MinSegmentTree(SegmentTree): def __init__(self, capacity): - super(MinSegmentTree, self).__init__( - capacity=capacity, operation=min, neutral_element=float("inf")) + super(MinSegmentTree, self).__init__(capacity=capacity, operation=min) def min(self, start=0, end=None): """Returns min(arr[start], ..., arr[end])""" - return super(MinSegmentTree, self).reduce(start, end) + return self.reduce(start, end) diff --git a/rllib/optimizers/tests/old_segment_tree.py b/rllib/optimizers/tests/old_segment_tree.py new file mode 100644 index 000000000..41123330e --- /dev/null +++ b/rllib/optimizers/tests/old_segment_tree.py @@ -0,0 +1,143 @@ +import operator + + +class OldSegmentTree(object): + def __init__(self, capacity, operation, neutral_element): + """Build a Segment Tree data structure. + + https://en.wikipedia.org/wiki/Segment_tree + + Can be used as regular array, but with two + important differences: + + a) setting item's value is slightly slower. + It is O(lg capacity) instead of O(1). + b) user has access to an efficient `reduce` + operation which reduces `operation` over + a contiguous subsequence of items in the + array. + + Paramters + --------- + capacity: int + Total size of the array - must be a power of two. + operation: lambda obj, obj -> obj + and operation for combining elements (eg. sum, max) + must for a mathematical group together with the set of + possible values for array elements. + neutral_element: obj + neutral element for the operation above. eg. float('-inf') + for max and 0 for sum. + """ + + assert capacity > 0 and capacity & (capacity - 1) == 0, \ + "capacity must be positive and a power of 2." + self._capacity = capacity + self._value = [neutral_element for _ in range(2 * capacity)] + self._operation = operation + + def _reduce_helper(self, start, end, node, node_start, node_end): + if start == node_start and end == node_end: + return self._value[node] + mid = (node_start + node_end) // 2 + if end <= mid: + return self._reduce_helper(start, end, 2 * node, node_start, mid) + else: + if mid + 1 <= start: + return self._reduce_helper(start, end, 2 * node + 1, mid + 1, + node_end) + else: + return self._operation( + self._reduce_helper(start, mid, 2 * node, node_start, mid), + self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, + node_end)) + + def reduce(self, start=0, end=None): + """Returns result of applying `self.operation` + to a contiguous subsequence of the array. + + self.operation( + arr[start], operation(arr[start+1], operation(... arr[end]))) + + Parameters + ---------- + start: int + beginning of the subsequence + end: int + end of the subsequences + + Returns + ------- + reduced: obj + result of reducing self.operation over the specified range of array + elements. + """ + if end is None: + end = self._capacity + if end < 0: + end += self._capacity + end -= 1 + return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + + def __setitem__(self, idx, val): + # index of the leaf + idx += self._capacity + self._value[idx] = val + idx //= 2 + while idx >= 1: + self._value[idx] = self._operation(self._value[2 * idx], + self._value[2 * idx + 1]) + idx //= 2 + + def __getitem__(self, idx): + assert 0 <= idx < self._capacity + return self._value[self._capacity + idx] + + +class OldSumSegmentTree(OldSegmentTree): + def __init__(self, capacity): + super(OldSumSegmentTree, self).__init__( + capacity=capacity, operation=operator.add, neutral_element=0.0) + + def sum(self, start=0, end=None): + """Returns arr[start] + ... + arr[end]""" + return super(OldSumSegmentTree, self).reduce(start, end) + + def find_prefixsum_idx(self, prefixsum): + """Find the highest index `i` in the array such that + sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum + + if array values are probabilities, this function + allows to sample indexes according to the discrete + probability efficiently. + + Parameters + ---------- + perfixsum: float + upperbound on the sum of array prefix + + Returns + ------- + idx: int + highest index satisfying the prefixsum constraint + """ + assert 0 <= prefixsum <= self.sum() + 1e-5 + idx = 1 + while idx < self._capacity: # while non-leaf + if self._value[2 * idx] > prefixsum: + idx = 2 * idx + else: + prefixsum -= self._value[2 * idx] + idx = 2 * idx + 1 + return idx - self._capacity + + +class OldMinSegmentTree(OldSegmentTree): + def __init__(self, capacity): + super(OldMinSegmentTree, self).__init__( + capacity=capacity, operation=min, neutral_element=float("inf")) + + def min(self, start=0, end=None): + """Returns min(arr[start], ..., arr[end])""" + + return super(OldMinSegmentTree, self).reduce(start, end) diff --git a/rllib/tests/test_optimizers.py b/rllib/optimizers/tests/test_optimizers.py similarity index 100% rename from rllib/tests/test_optimizers.py rename to rllib/optimizers/tests/test_optimizers.py diff --git a/rllib/optimizers/tests/test_prioritized_replay_buffer.py b/rllib/optimizers/tests/test_prioritized_replay_buffer.py new file mode 100644 index 000000000..cb1a2b1a8 --- /dev/null +++ b/rllib/optimizers/tests/test_prioritized_replay_buffer.py @@ -0,0 +1,180 @@ +from collections import Counter +import numpy as np +import unittest + +from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer +from ray.rllib.utils.test_utils import check + + +class TestPrioritizedReplayBuffer(unittest.TestCase): + """ + Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer. + """ + + capacity = 10 + alpha = 1.0 + beta = 1.0 + max_priority = 1.0 + + def _generate_data(self): + return ( + np.random.random((4, )), # obs_t + np.random.choice([0, 1]), # action + np.random.rand(), # reward + np.random.random((4, )), # obs_tp1 + np.random.choice([False, True]), # done + ) + + def test_add(self): + memory = PrioritizedReplayBuffer( + size=2, + alpha=self.alpha, + ) + + # Assert indices 0 before insert. + self.assertEqual(len(memory), 0) + self.assertEqual(memory._next_idx, 0) + + # Insert single record. + data = self._generate_data() + memory.add(*data, weight=0.5) + self.assertTrue(len(memory) == 1) + self.assertTrue(memory._next_idx == 1) + + # Insert single record. + data = self._generate_data() + memory.add(*data, weight=0.1) + self.assertTrue(len(memory) == 2) + self.assertTrue(memory._next_idx == 0) + + # Insert over capacity. + data = self._generate_data() + memory.add(*data, weight=1.0) + self.assertTrue(len(memory) == 2) + self.assertTrue(memory._next_idx == 1) + + def test_update_priorities(self): + memory = PrioritizedReplayBuffer(size=self.capacity, alpha=self.alpha) + + # Insert n samples. + num_records = 5 + for i in range(num_records): + data = self._generate_data() + memory.add(*data, weight=1.0) + self.assertTrue(len(memory) == i + 1) + self.assertTrue(memory._next_idx == i + 1) + + # Fetch records, their indices and weights. + _, _, _, _, _, weights, indices = \ + memory.sample(3, beta=self.beta) + check(weights, np.ones(shape=(3, ))) + self.assertEqual(3, len(indices)) + self.assertTrue(len(memory) == num_records) + self.assertTrue(memory._next_idx == num_records) + + # Update weight of indices 0, 2, 3, 4 to very small. + memory.update_priorities( + np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01])) + # Expect to sample almost only index 1 + # (which still has a weight of 1.0). + for _ in range(10): + _, _, _, _, _, weights, indices = memory.sample( + 1000, beta=self.beta) + self.assertTrue(970 < np.sum(indices) < 1100) + + # Update weight of indices 0 and 1 to >> 0.01. + # Expect to sample 0 and 1 equally (and some 2s, 3s, and 4s). + for _ in range(10): + rand = np.random.random() + 0.2 + memory.update_priorities(np.array([0, 1]), np.array([rand, rand])) + _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + # Expect biased to higher values due to some 2s, 3s, and 4s. + # print(np.sum(indices)) + self.assertTrue(400 < np.sum(indices) < 800) + + # Update weights to be 1:2. + # Expect to sample double as often index 1 over index 0 + # plus very few times indices 2, 3, or 4. + for _ in range(10): + rand = np.random.random() + 0.2 + memory.update_priorities( + np.array([0, 1]), np.array([rand, rand * 2])) + _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + # print(np.sum(indices)) + self.assertTrue(600 < np.sum(indices) < 850) + + # Update weights to be 1:4. + # Expect to sample quadruple as often index 1 over index 0 + # plus very few times indices 2, 3, or 4. + for _ in range(10): + rand = np.random.random() + 0.2 + memory.update_priorities( + np.array([0, 1]), np.array([rand, rand * 4])) + _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + # print(np.sum(indices)) + self.assertTrue(750 < np.sum(indices) < 950) + + # Update weights to be 1:9. + # Expect to sample 9 times as often index 1 over index 0. + # plus very few times indices 2, 3, or 4. + for _ in range(10): + rand = np.random.random() + 0.2 + memory.update_priorities( + np.array([0, 1]), np.array([rand, rand * 9])) + _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + # print(np.sum(indices)) + self.assertTrue(850 < np.sum(indices) < 1100) + + # Insert n more samples. + num_records = 5 + for i in range(num_records): + data = self._generate_data() + memory.add(*data, weight=1.0) + self.assertTrue(len(memory) == i + 6) + self.assertTrue(memory._next_idx == (i + 6) % self.capacity) + + # Update all weights to be 1.0 to 10.0 and sample a >100 batch. + memory.update_priorities( + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.])) + counts = Counter() + for _ in range(10): + _, _, _, _, _, _, indices = memory.sample( + np.random.randint(100, 600), beta=self.beta) + for i in indices: + counts[i] += 1 + print(counts) + # Expect an approximately correct distribution of indices. + self.assertTrue( + counts[9] >= counts[8] >= counts[7] >= counts[6] >= counts[5] >= + counts[4] >= counts[3] >= counts[2] >= counts[1] >= counts[0]) + + def test_alpha_parameter(self): + # Test sampling from a PR with a very small alpha (should behave just + # like a regular ReplayBuffer). + memory = PrioritizedReplayBuffer(size=self.capacity, alpha=0.01) + + # Insert n samples. + num_records = 5 + for i in range(num_records): + data = self._generate_data() + memory.add(*data, weight=np.random.rand()) + self.assertTrue(len(memory) == i + 1) + self.assertTrue(memory._next_idx == i + 1) + + # Fetch records, their indices and weights. + _, _, _, _, _, weights, indices = \ + memory.sample(1000, beta=self.beta) + counts = Counter() + for i in indices: + counts[i] += 1 + print(counts) + # Expect an approximately uniform distribution of indices. + for i in counts.values(): + self.assertTrue(100 < i < 300) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/optimizers/tests/test_segment_tree.py b/rllib/optimizers/tests/test_segment_tree.py index 731f25655..1f70adc0e 100644 --- a/rllib/optimizers/tests/test_segment_tree.py +++ b/rllib/optimizers/tests/test_segment_tree.py @@ -1,4 +1,5 @@ import numpy as np +import timeit import unittest from ray.rllib.optimizers.segment_tree import SumSegmentTree, MinSegmentTree @@ -17,6 +18,7 @@ class TestSegmentTree(unittest.TestCase): assert np.isclose(tree.sum(2, 3), 1.0) assert np.isclose(tree.sum(2, -1), 1.0) assert np.isclose(tree.sum(2, 4), 4.0) + assert np.isclose(tree.sum(2), 4.0) def test_tree_set_overlap(self): tree = SumSegmentTree(4) @@ -28,6 +30,7 @@ class TestSegmentTree(unittest.TestCase): assert np.isclose(tree.sum(2, 3), 3.0) assert np.isclose(tree.sum(2, -1), 3.0) assert np.isclose(tree.sum(2, 4), 3.0) + assert np.isclose(tree.sum(2), 3.0) assert np.isclose(tree.sum(1, 2), 0.0) def test_prefixsum_idx(self): @@ -92,6 +95,37 @@ class TestSegmentTree(unittest.TestCase): assert np.isclose(tree.min(2, -1), 4.0) assert np.isclose(tree.min(3, 4), 3.0) + def test_microbenchmark_vs_old_version(self): + """ + Results from March 2020 (capacity=1048576): + + New tree: + 0.049599366000000256s + results = timeit.timeit("tree.sum(5, 60000)", + setup="from ray.rllib.optimizers.segment_tree import + SumSegmentTree; tree = SumSegmentTree({})".format(capacity), + number=10000) + + Old tree: + 0.13390400999999974s + results = timeit.timeit("tree.sum(5, 60000)", + setup="from ray.rllib.optimizers.tests.old_segment_tree import + OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity), + number=10000) + """ + capacity = 2**20 + new = timeit.timeit( + "tree.sum(5, 60000)", + setup="from ray.rllib.optimizers.segment_tree import " + "SumSegmentTree; tree = SumSegmentTree({})".format(capacity), + number=10000) + old = timeit.timeit( + "tree.sum(5, 60000)", + setup="from ray.rllib.optimizers.tests.old_segment_tree import " + "OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity), + number=10000) + self.assertGreater(old, new) + if __name__ == "__main__": import pytest