Add Reservoir.

This commit is contained in:
David Bau
2022-03-23 23:12:45 -04:00
parent e89b3c88e3
commit a22269a72d
+88
View File
@@ -16,6 +16,7 @@ Built-in runningstats objects include:
TopK - topk() returns (values, indexes).
Bincount - bincount() histograms nonnegative integer data.
IoU - intersection(), union(), iou() tally binary co-occurrences.
Reservoir - sample() returns uniform size-k sample of values.
History - history() returns concatenation of data.
CrossCovariance - covariance between two signals, without self-covariance.
CrossIoU - iou between two signals, without self-IoU.
@@ -1286,6 +1287,88 @@ class TopK:
)
class Reservoir:
    """
    Collects a uniform random sample of at most k of the observations added.

    Each observation added is given an independent random score, and the
    k observations with the largest scores are the ones retained.  New
    picks are appended to a buffer of max(10, 5 * k) rows; when the buffer
    cannot hold the next batch of picks it is compressed back down to the
    k highest-scoring rows.

    Attributes:
        k: maximum number of observations in the sample.
        count: total number of observations passed to add().
        top_data: buffer of retained observations (None until first add()).
        top_score: buffer of the random scores for the rows of top_data.
        next: number of rows of the buffers currently in use.
    """

    def __init__(self, k=100, state=None):
        """
        Creates an empty reservoir of capacity k, or restores one from state.

        Args:
            k: maximum sample size; ignored when state is given.
            state: optional dict in the format produced by state_dict().
        """
        if state is not None:
            # This class has no base class, so there is no parent
            # initializer that accepts a state; restore it directly.
            # NOTE(review): sibling stats such as History derive from Stat;
            # consider deriving Reservoir from Stat too - confirm with author.
            self.load_state_dict(state)
            return
        self.k = k
        self.count = 0
        # The buffers are allocated lazily on the first add(), because the
        # feature shape, dtype and device are taken from the first batch.
        self.top_data = None
        self.top_score = None
        self.next = 0

    def add(self, data, index=None):
        """
        Adds a batch of observations to be considered for the sample.

        The zeroth dimension of data enumerates the observations.  All other
        dimensions are treated as features and are stored unchanged.

        Args:
            data: tensor whose first dimension indexes observations.
            index: accepted but unused.
        """
        if self.top_score is None:
            # Allocation: allocate a buffer of size 5*k, at least 10, for each.
            bk = max(10, 5 * self.k)
            self.top_score = torch.zeros(bk)
            self.top_data = torch.zeros(
                (bk,) + data.shape[1:], dtype=data.dtype, device=data.device
            )
        size = data.shape[0]
        # No more than k rows of a single batch can survive in the sample.
        sk = min(size, self.k)
        scores = torch.randn(size)
        if len(self.top_score) < self.next + sk:
            # Compression: if the picks would not fit, keep the top k only.
            self.sample()
        # Pick: copy the sk highest-scoring rows of the batch into the buffer.
        ts, ti = scores.topk(sk, sorted=False)
        self.top_score[self.next : self.next + sk] = ts
        self.top_data[self.next : self.next + sk] = data[ti]
        self.next += sk
        self.count += size

    def size(self):
        """Returns the total number of observations passed to add()."""
        return self.count

    def sample(self):
        """
        Returns the current sample, compressing the buffer as a side effect.

        The result has min(k, rows retained so far) rows in its zeroth
        dimension, followed by the feature dimensions of the added data.
        It is a slice of the internal buffer rather than a copy.  It must
        be called only after at least one add(), since the buffers do not
        exist before then.
        """
        k = min(self.k, self.next)
        if k < self.next:
            # More than k candidates are buffered: keep the k best scores.
            td, ti = self.top_score[: self.next].topk(k, sorted=False)
            self.top_score[:k] = td
            self.top_data[:k] = self.top_data[ti].clone()
            self.next = k
        return self.top_data[:k]

    def to_(self, device):
        """
        Moves the retained data to the given device.

        Only top_data is moved; top_score is left where it was allocated.
        """
        if self.top_data is not None:
            self.top_data = self.top_data.to(device)

    def state_dict(self):
        """
        Returns the state as a dict of plain values and numpy arrays.

        Requires at least one prior add(), since the buffers are converted
        unconditionally.  For tensors already on the CPU the returned arrays
        may share memory with the live buffers.
        """
        return dict(
            constructor=self.__module__ + "." + self.__class__.__name__ + "()",
            k=self.k,
            count=self.count,
            top_data=self.top_data.cpu().detach().numpy(),
            top_score=self.top_score.cpu().detach().numpy(),
            next=self.next,
        )

    def load_state_dict(self, state):
        """
        Restores the state from a dict in the format of state_dict().

        Args:
            state: mapping with keys k, count, top_data, top_score and next,
                where top_data and top_score are numpy arrays.
        """
        self.k = int(state["k"])
        self.count = int(state["count"])
        # torch.from_numpy shares memory with its argument; clone so that
        # later add() calls never write into the caller's arrays (or into
        # the buffers of the object the state was taken from).
        self.top_data = torch.from_numpy(state["top_data"]).clone()
        self.top_score = torch.from_numpy(state["top_score"]).clone()
        self.next = int(state["next"])
class History(Stat):
"""
Accumulates the concatenation of all the added data.
@@ -1729,6 +1812,7 @@ def _unit_test():
s=SecondMoment(),
t=TopK(),
i=IoU(),
r=Reservoir(),
)
# Feed data in little batches
i = 0
@@ -1753,6 +1837,7 @@ def _unit_test():
s=SecondMoment(),
t=TopK(),
i=IoU(),
r=Reservoir(),
state=saved,
)
# saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
@@ -1765,6 +1850,7 @@ def _unit_test():
assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs2.r.sample().mean(0)) / alldata.mean() < 1e-2)
# print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0))
assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3)
assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2)
@@ -1790,6 +1876,7 @@ def _unit_test():
s=SecondMoment(),
t=TopK(),
i=IoU(),
r=Reservoir(),
)
cs.load(f"{testdir}/saved.npz")
assert not cs.qc.device.type == "cuda"
@@ -1800,6 +1887,7 @@ def _unit_test():
assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5)
assert all(abs(alldata.mean(0) - cs.r.sample().mean(0)) / alldata.mean() < 1e-2)
# print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0))
assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3)
assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2)