class Reservoir:
    """
    Collects a uniform random sample of at most k of the observations
    passed to add().

    Each observation is tagged with an independent random score, and only
    the k highest-scoring observations seen so far are retained.  Because
    the scores are i.i.d. and independent of the data, every size-k subset
    of the observations is equally likely to be the one retained.

    Attributes:
        k: maximum number of observations kept in the sample.
        count: total number of observations passed to add().
        top_data: buffer of candidate observations (None until the first
            add()); only the first `next` rows are meaningful.
        top_score: buffer of the random scores matching top_data rows.
        next: number of buffer rows currently in use.
    """

    def __init__(self, k=100, state=None):
        """
        Creates an empty reservoir of capacity k, or restores one from
        `state`, a mapping with the keys produced by state_dict().
        """
        if state is not None:
            # This class has no base class, so super().__init__(state)
            # would reach object.__init__ and raise TypeError.  Restore
            # directly from the state mapping instead.
            self.load_state_dict(state)
            return
        self.k = k
        self.count = 0
        # Buffers are allocated lazily on the first add(), once the shape,
        # dtype and device of the observations are known.
        self.top_data = None
        self.top_score = None
        self.next = 0

    def add(self, data, index=None):
        """
        Adds a batch of data to be considered for the running sample.
        The zeroth dimension enumerates the observations.  All other
        dimensions enumerate different features.

        `index` is accepted for interface compatibility and is unused.
        """
        if self.top_score is None:
            # Allocation: a buffer of size 5*k, at least 10, so that
            # several batches can be appended between compressions.
            capacity = max(10, 5 * self.k)
            self.top_score = torch.zeros(capacity)
            self.top_data = torch.zeros(
                (capacity,) + data.shape[1:],
                dtype=data.dtype,
                device=data.device,
            )
        size = data.shape[0]
        # At most k items of any single batch can survive into the sample.
        sk = min(size, self.k)
        scores = torch.randn(size)
        if len(self.top_score) < self.next + sk:
            # Compression: if the batch would not fit, keep the top k only.
            self._compress()
        # Pick: copy the sk best-scoring items of the batch into the buffer.
        ts, ti = scores.topk(sk, sorted=False)
        self.top_score[self.next:self.next + sk] = ts
        self.top_data[self.next:self.next + sk] = data[ti]
        self.next += sk
        self.count += size

    def size(self):
        """Returns the total number of observations passed to add()."""
        return self.count

    def _compress(self):
        """
        Compacts the buffer down to the k highest-scoring rows and returns
        the number of rows that remain in use.
        """
        k = min(self.k, self.next)
        if k < self.next:
            kept_score, kept_index = self.top_score[:self.next].topk(
                k, sorted=False)
            self.top_score[:k] = kept_score
            # Advanced indexing already yields a fresh tensor, so the
            # assignment cannot read rows it has just overwritten.
            self.top_data[:k] = self.top_data[kept_index]
            self.next = k
        return k

    def sample(self):
        """
        Returns the current sample: a tensor holding up to k observations
        along the zeroth dimension, in no particular order.  If nothing
        has been added yet, an empty tensor of shape (0,) is returned.

        The result is a copy, so it is not altered by later add() or
        sample() calls, which rewrite the internal buffer in place.
        """
        if self.top_data is None:
            return torch.zeros(0)
        k = self._compress()
        return self.top_data[:k].clone()

    def to_(self, device):
        """
        Moves the stored observations to `device`.  The scores are left
        where they are; add() always generates them with torch.randn(size)
        without a device argument.
        """
        if self.top_data is not None:
            self.top_data = self.top_data.to(device)

    def state_dict(self):
        """
        Returns the state as a dict of plain values and numpy arrays.
        The two buffers are None if nothing has been added yet.
        """
        allocated = self.top_data is not None
        return dict(
            constructor=self.__module__ + "." + self.__class__.__name__ + "()",
            k=self.k,
            count=self.count,
            top_data=(
                self.top_data.cpu().detach().numpy() if allocated else None),
            top_score=(
                self.top_score.cpu().detach().numpy() if allocated else None),
            next=self.next,
        )

    def load_state_dict(self, state):
        """
        Restores the state produced by state_dict().  Buffers saved as
        None are restored as unallocated.
        """
        self.k = int(state["k"])
        self.count = int(state["count"])
        top_data = state["top_data"]
        top_score = state["top_score"]
        self.top_data = (
            None if top_data is None else torch.from_numpy(top_data))
        self.top_score = (
            None if top_score is None else torch.from_numpy(top_score))
        self.next = int(state["next"])
@@ -1729,6 +1812,7 @@ def _unit_test(): s=SecondMoment(), t=TopK(), i=IoU(), + r=Reservoir(), ) # Feed data in little batches i = 0 @@ -1753,6 +1837,7 @@ def _unit_test(): s=SecondMoment(), t=TopK(), i=IoU(), + r=Reservoir(), state=saved, ) # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz')) @@ -1765,6 +1850,7 @@ def _unit_test(): assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5) assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5) assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs2.r.sample().mean(0)) / alldata.mean() < 1e-2) # print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0)) assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3) assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2) @@ -1790,6 +1876,7 @@ def _unit_test(): s=SecondMoment(), t=TopK(), i=IoU(), + r=Reservoir(), ) cs.load(f"{testdir}/saved.npz") assert not cs.qc.device.type == "cuda" @@ -1800,6 +1887,7 @@ def _unit_test(): assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5) assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5) assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs.r.sample().mean(0)) / alldata.mean() < 1e-2) # print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0)) assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3) assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2)