mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:53:14 +08:00
[RLlib] Fix bugs and speed up SegmentTree
This commit is contained in:
+19
-15
@@ -852,15 +852,26 @@ py_test(
|
||||
# Tag: optimizers
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
# This has bugs: See PR https://github.com/ray-project/ray/pull/7534
|
||||
# which fixes these and re-adds this test.
|
||||
py_test(
|
||||
name = "test_optimizers",
|
||||
tags = ["optimizers"],
|
||||
size = "large",
|
||||
srcs = ["optimizers/tests/test_optimizers.py"]
|
||||
)
|
||||
|
||||
# py_test(
|
||||
# name = "test_segment_tree",
|
||||
# tags = ["optimizers"],
|
||||
# size = "small",
|
||||
# srcs = ["optimizers/tests/test_segment_tree.py"]
|
||||
# )
|
||||
py_test(
|
||||
name = "test_segment_tree",
|
||||
tags = ["optimizers"],
|
||||
size = "small",
|
||||
srcs = ["optimizers/tests/test_segment_tree.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_prioritized_replay_buffer",
|
||||
tags = ["optimizers"],
|
||||
size = "small",
|
||||
srcs = ["optimizers/tests/test_prioritized_replay_buffer.py"]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# Policies
|
||||
@@ -1044,13 +1055,6 @@ py_test(
|
||||
srcs = ["tests/test_nested_spaces.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_optimizers",
|
||||
tags = ["tests_dir", "tests_dir_O"],
|
||||
size = "large",
|
||||
srcs = ["tests/test_optimizers.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_exec_api",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
|
||||
@@ -227,7 +227,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
Array of shape (batch_size,) and dtype np.int32
|
||||
indexes in buffer of sampled experiences
|
||||
"""
|
||||
assert beta > 0
|
||||
assert beta >= 0.0
|
||||
self._num_sampled += batch_size
|
||||
|
||||
idxes = self._sample_proportional(batch_size)
|
||||
|
||||
@@ -2,140 +2,195 @@ import operator
|
||||
|
||||
|
||||
class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as regular array, but with two important differences:

      a) Setting an item's value is slightly slower. It is O(lg capacity),
         instead of O(1).
      b) Offers efficient `reduce` operation which reduces the tree's values
         over some specified contiguous subsequence of items in the array.
         Operation could be e.g. min/max/sum.

    The data is stored in a list, where the length is 2 * capacity.
    The second half of the list stores the actual values for each index, so if
    capacity=8, values are stored at indices 8 to 15. The first half of the
    array contains the reduced-values of the different (binary divided)
    segments, e.g. (capacity=4):
    0=not used
    1=reduced-value over all elements (array indices 4 to 7).
    2=reduced-value over array indices (4 and 5).
    3=reduced-value over array indices (6 and 7).
    4-7: values of the tree.
    NOTE that the values of the tree are accessed by indices starting at 0, so
    `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(self, capacity, operation, neutral_element=None):
        """Initializes a Segment Tree object.

        Args:
            capacity (int): Total size of the array - must be a power of two.
            operation (operation): Lambda obj, obj -> obj
                The operation for combining elements (eg. sum, max).
                Must be associative and commutative (`reduce` combines
                partial results from both ends of the range, not strictly
                left-to-right) and have `neutral_element` as its identity.
            neutral_element (Optional[obj]): The neutral element for
                `operation`. Use None for automatically finding a value:
                `operator.add`: 0.0, `max`: float("-inf"), anything else
                (e.g. `min`): float("inf").
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "Capacity must be positive and a power of 2!"
        self.capacity = capacity
        if neutral_element is None:
            # Identity checks (not equality): only the exact `operator.add`
            # and builtin `max` objects are recognized; everything else
            # falls back to +inf (the neutral element of `min`).
            if operation is operator.add:
                neutral_element = 0.0
            elif operation is max:
                neutral_element = float("-inf")
            else:
                neutral_element = float("inf")
        self.neutral_element = neutral_element
        self.value = [self.neutral_element for _ in range(2 * capacity)]
        self.operation = operation

    def reduce(self, start=0, end=None):
        """Applies `self.operation` to subsequence of our values.

        Subsequence is contiguous, includes `start` and excludes `end`.

          self.operation(
              arr[start], operation(arr[start+1], operation(... arr[end - 1])))

        Args:
            start (int): Start index to apply reduction to. Negative values
                count from the end of the array (like `end`).
            end (Optional[int]): End index to apply reduction to (excluded).
                None means "up to and including the last element"; negative
                values count from the end of the array.

        Returns:
            any: The result of reducing self.operation over the specified
                range of `self.value` elements. An empty range
                (start >= end) yields `self.neutral_element`.
        """
        if end is None:
            end = self.capacity
        elif end < 0:
            end += self.capacity
        # A negative `start` would otherwise be mapped into the first half of
        # the internal list (the reduction nodes) and silently produce a
        # wrong result, so normalize it the same way as `end`.
        if start < 0:
            start += self.capacity
        assert 0 <= start and 0 <= end <= self.capacity, \
            "start/end out of range for this tree's capacity!"

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        start += self.capacity
        end += self.capacity

        # Example:
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3
        #
        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.
        #
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=3
        #    NOTE: This adds the sum of indices 4 and 5 to the result.

        # Iterate as long as the (half-open) range is non-empty.
        while start < end:
            # If start is odd: Add its value to result and move start to
            # next even value.
            if start & 1:
                result = self.operation(result, self.value[start])
                start += 1

            # If end is odd: Move end to previous even value, then add its
            # value to result. NOTE: This takes care of excluding `end` in any
            # situation.
            if end & 1:
                end -= 1
                result = self.operation(result, self.value[end])

            # Divide both start and end by 2 to make them "jump" into the
            # next upper level reduce-index space.
            start //= 2
            end //= 2

        return result

    def __setitem__(self, idx, val):
        """Inserts/overwrites a value in/into the tree.

        Args:
            idx (int): The index to insert to. Must be in [0, `self.capacity`[
            val (float): The value to insert.
        """
        assert 0 <= idx < self.capacity

        # Index of the leaf to insert into (always insert in "second half"
        # of the tree, the first half is reserved for already calculated
        # reduction-values).
        idx += self.capacity
        self.value[idx] = val

        # Recalculate all affected reduction values (in "first half" of tree),
        # walking from the leaf's parent up to the root at index 1.
        idx >>= 1
        while idx >= 1:
            left_child = 2 * idx  # calculate only once
            self.value[idx] = self.operation(self.value[left_child],
                                             self.value[left_child + 1])
            idx >>= 1

    def __getitem__(self, idx):
        """Returns the value stored at (0-based) position `idx`."""
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]
|
||||
|
||||
|
||||
class SumSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=operator.add."""

    def __init__(self, capacity):
        """Creates a sum tree with `capacity` slots (must be a power of 2)."""
        super(SumSegmentTree, self).__init__(
            capacity=capacity, operation=operator.add)

    def sum(self, start=0, end=None):
        """Returns the sum over a sub-segment of the tree.

        Args:
            start (int): First index of the segment (included).
            end (Optional[int]): End index of the segment (excluded).
                Passed through to `reduce` unchanged.
        """
        return self.reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Finds highest i, for which: sum(arr[0]+..+arr[i - 1]) <= prefixsum.

        If the array values are (unnormalized) probabilities, this allows
        sampling indices according to that discrete distribution.

        Args:
            prefixsum (float): `prefixsum` upper bound in above constraint.
                Must lie in [0, self.sum() + 1e-5].

        Returns:
            int: Largest possible index (i) satisfying above constraint.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        # Global sum node.
        idx = 1

        # While non-leaf (first half of tree): descend into the left child if
        # its subtree sum exceeds the remaining prefixsum, otherwise consume
        # the left subtree's sum and descend into the right child.
        while idx < self.capacity:
            left_child = 2 * idx
            if self.value[left_child] > prefixsum:
                idx = left_child
            else:
                prefixsum -= self.value[left_child]
                idx = left_child + 1
        # Convert the leaf's internal position back to a 0-based index.
        return idx - self.capacity
|
||||
|
||||
|
||||
class MinSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=min."""

    def __init__(self, capacity):
        """Creates a min tree with `capacity` slots (must be a power of 2)."""
        super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end - 1]) (`end` is excluded)."""
        return self.reduce(start, end)
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
import operator
|
||||
|
||||
|
||||
class OldSegmentTree(object):
    # NOTE: Previous (recursive) implementation; its logic is intentionally
    # left unchanged here.

    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.

        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as regular array, but with two
        important differences:

            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient `reduce`
               operation which reduces `operation` over
               a contiguous subsequence of items in the
               array.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            An operation for combining elements (eg. sum, max).
            `neutral_element` must be its identity.
        neutral_element: obj
            Neutral element for the operation above, eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Slots [capacity, 2 * capacity) hold the leaves; slot i < capacity
        # holds the reduction of its children 2 * i and 2 * i + 1.
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Recursively reduces the INCLUSIVE range [start, end].

        `node` is the internal index of the tree node that covers the
        inclusive leaf range [node_start, node_end].
        """
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Range lies entirely in the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                # Range lies entirely in the right child.
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
                                           node_end)
            else:
                # Range straddles both children: split at `mid`.
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
                                        node_end))

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(
                arr[start], operation(arr[start+1], operation(... arr[end-1])))

        Parameters
        ----------
        start: int
            beginning of the subsequence (included)
        end: int
            end of the subsequence (excluded); None means the whole
            remainder, negative values count from the end of the array.
            NOTE: The range is expected to be non-empty (start < end); the
            recursive helper has no base case for an empty range.

        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of array
            elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        # Convert the exclusive `end` into the inclusive bound that
        # `_reduce_helper` works with.
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """Stores `val` at position `idx` and updates all ancestor nodes."""
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(self._value[2 * idx],
                                               self._value[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx):
        """Returns the value stored at (0-based) position `idx`."""
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
|
||||
|
||||
|
||||
class OldSumSegmentTree(OldSegmentTree):
    """An OldSegmentTree using `operator.add` with neutral element 0.0."""

    def __init__(self, capacity):
        super(OldSumSegmentTree, self).__init__(
            capacity=capacity, operation=operator.add, neutral_element=0.0)

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end - 1] (`end` is excluded)."""
        return super(OldSumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum

        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.

        Parameters
        ----------
        prefixsum: float
            upperbound on the sum of array prefix

        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        idx = 1
        while idx < self._capacity:  # while non-leaf
            if self._value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                # Consume the left subtree's sum and go right.
                prefixsum -= self._value[2 * idx]
                idx = 2 * idx + 1
        return idx - self._capacity
|
||||
|
||||
|
||||
class OldMinSegmentTree(OldSegmentTree):
    """An OldSegmentTree using `min` with neutral element float("inf")."""

    def __init__(self, capacity):
        super(OldMinSegmentTree, self).__init__(
            capacity=capacity, operation=min, neutral_element=float("inf"))

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end - 1]) (`end` is excluded)."""
        return super(OldMinSegmentTree, self).reduce(start, end)
|
||||
@@ -0,0 +1,180 @@
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestPrioritizedReplayBuffer(unittest.TestCase):
    """
    Tests insertion and (weighted) sampling of the PrioritizedReplayBuffer.
    """

    capacity = 10
    alpha = 1.0
    beta = 1.0
    max_priority = 1.0

    def _generate_data(self):
        """Returns one random (obs_t, action, reward, obs_tp1, done) tuple."""
        return (
            np.random.random((4, )),  # obs_t
            np.random.choice([0, 1]),  # action
            np.random.rand(),  # reward
            np.random.random((4, )),  # obs_tp1
            np.random.choice([False, True]),  # done
        )

    def test_add(self):
        """Checks length and write-cursor bookkeeping, incl. wrap-around."""
        memory = PrioritizedReplayBuffer(
            size=2,
            alpha=self.alpha,
        )

        # Assert indices 0 before insert.
        self.assertEqual(len(memory), 0)
        self.assertEqual(memory._next_idx, 0)

        # Insert single record.
        data = self._generate_data()
        memory.add(*data, weight=0.5)
        self.assertEqual(len(memory), 1)
        self.assertEqual(memory._next_idx, 1)

        # Insert single record (buffer is now full, cursor wraps to 0).
        data = self._generate_data()
        memory.add(*data, weight=0.1)
        self.assertEqual(len(memory), 2)
        self.assertEqual(memory._next_idx, 0)

        # Insert over capacity.
        data = self._generate_data()
        memory.add(*data, weight=1.0)
        self.assertEqual(len(memory), 2)
        self.assertEqual(memory._next_idx, 1)

    def test_update_priorities(self):
        """Checks that sampling frequencies follow updated priorities."""
        memory = PrioritizedReplayBuffer(size=self.capacity, alpha=self.alpha)

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(*data, weight=1.0)
            self.assertEqual(len(memory), i + 1)
            self.assertEqual(memory._next_idx, i + 1)

        # Fetch records, their indices and weights.
        _, _, _, _, _, weights, indices = \
            memory.sample(3, beta=self.beta)
        check(weights, np.ones(shape=(3, )))
        self.assertEqual(3, len(indices))
        # Sampling must not change the buffer's size or write cursor.
        self.assertEqual(len(memory), num_records)
        self.assertEqual(memory._next_idx, num_records)

        # Update weight of indices 0, 2, 3, 4 to very small.
        memory.update_priorities(
            np.array([0, 2, 3, 4]), np.array([0.01, 0.01, 0.01, 0.01]))
        # Expect to sample almost only index 1
        # (which still has a weight of 1.0).
        for _ in range(10):
            _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta)
            self.assertTrue(970 < np.sum(indices) < 1100)

        # Set the priorities of indices 0 and 1 to `rand` and `rand * ratio`
        # (both >> 0.01) and check the sum of 1000 sampled indices:
        # - ratio 1: 0 and 1 sampled equally (plus some 2s, 3s, and 4s, which
        #   bias the sum towards higher values).
        # - ratio 2/4/9: index 1 sampled 2/4/9 times as often as index 0,
        #   plus very few times indices 2, 3, or 4.
        expected_index_sums = (
            (1, 400, 800),
            (2, 600, 850),
            (4, 750, 950),
            (9, 850, 1100),
        )
        for ratio, lower, upper in expected_index_sums:
            for _ in range(10):
                rand = np.random.random() + 0.2
                memory.update_priorities(
                    np.array([0, 1]), np.array([rand, rand * ratio]))
                _, _, _, _, _, _, indices = memory.sample(
                    1000, beta=self.beta)
                self.assertTrue(lower < np.sum(indices) < upper)

        # Insert n more samples (fills the buffer, cursor wraps around).
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(*data, weight=1.0)
            self.assertEqual(len(memory), i + 6)
            self.assertEqual(memory._next_idx, (i + 6) % self.capacity)

        # Assign strongly increasing priorities (0.001 ... 512) to indices
        # 0-9 and sample batches of random size in [100, 600).
        memory.update_priorities(
            np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
            np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.]))
        counts = Counter()
        for _ in range(10):
            _, _, _, _, _, _, indices = memory.sample(
                np.random.randint(100, 600), beta=self.beta)
            counts.update(indices)
        print(counts)
        # Expect an approximately correct distribution of indices: a higher
        # priority must never be sampled less often than a lower one.
        for i in range(9):
            self.assertGreaterEqual(counts[i + 1], counts[i])

    def test_alpha_parameter(self):
        """Checks that a tiny alpha yields (almost) uniform sampling."""
        # Test sampling from a PR with a very small alpha (should behave just
        # like a regular ReplayBuffer).
        memory = PrioritizedReplayBuffer(size=self.capacity, alpha=0.01)

        # Insert n samples.
        num_records = 5
        for i in range(num_records):
            data = self._generate_data()
            memory.add(*data, weight=np.random.rand())
            self.assertEqual(len(memory), i + 1)
            self.assertEqual(memory._next_idx, i + 1)

        # Fetch records, their indices and weights.
        _, _, _, _, _, _, indices = \
            memory.sample(1000, beta=self.beta)
        counts = Counter(indices)
        print(counts)
        # Expect an approximately uniform distribution of indices.
        for count in counts.values():
            self.assertTrue(100 < count < 300)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import timeit
|
||||
import unittest
|
||||
|
||||
from ray.rllib.optimizers.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
@@ -17,6 +18,7 @@ class TestSegmentTree(unittest.TestCase):
|
||||
assert np.isclose(tree.sum(2, 3), 1.0)
|
||||
assert np.isclose(tree.sum(2, -1), 1.0)
|
||||
assert np.isclose(tree.sum(2, 4), 4.0)
|
||||
assert np.isclose(tree.sum(2), 4.0)
|
||||
|
||||
def test_tree_set_overlap(self):
|
||||
tree = SumSegmentTree(4)
|
||||
@@ -28,6 +30,7 @@ class TestSegmentTree(unittest.TestCase):
|
||||
assert np.isclose(tree.sum(2, 3), 3.0)
|
||||
assert np.isclose(tree.sum(2, -1), 3.0)
|
||||
assert np.isclose(tree.sum(2, 4), 3.0)
|
||||
assert np.isclose(tree.sum(2), 3.0)
|
||||
assert np.isclose(tree.sum(1, 2), 0.0)
|
||||
|
||||
def test_prefixsum_idx(self):
|
||||
@@ -92,6 +95,37 @@ class TestSegmentTree(unittest.TestCase):
|
||||
assert np.isclose(tree.min(2, -1), 4.0)
|
||||
assert np.isclose(tree.min(3, 4), 3.0)
|
||||
|
||||
def test_microbenchmark_vs_old_version(self):
    """Checks that the new SumSegmentTree sums faster than the old one.

    Both trees are timed on 10000 calls of `tree.sum(5, 60000)`.

    Results from March 2020 (capacity=1048576):

    New tree (ray.rllib.optimizers.segment_tree.SumSegmentTree):
        0.049599366000000256s
    Old tree
    (ray.rllib.optimizers.tests.old_segment_tree.OldSumSegmentTree):
        0.13390400999999974s
    """
    capacity = 2**20
    statement = "tree.sum(5, 60000)"
    # Take the minimum of several runs: higher timings are typically caused
    # by other processes interfering, not by the code under test, so a
    # single run makes this comparison needlessly noisy.
    new = min(
        timeit.repeat(
            statement,
            setup="from ray.rllib.optimizers.segment_tree import "
            "SumSegmentTree; tree = SumSegmentTree({})".format(capacity),
            repeat=3,
            number=10000))
    old = min(
        timeit.repeat(
            statement,
            setup="from ray.rllib.optimizers.tests.old_segment_tree import "
            "OldSumSegmentTree; tree = OldSumSegmentTree({})".format(
                capacity),
            repeat=3,
            number=10000))
    self.assertGreater(old, new)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user