diff --git a/README.md b/README.md
new file mode 100644
index 0000000..027add9
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+Some utilities useful for prototyping with pytorch.
diff --git a/torchkit/labwidget.py b/torchkit/labwidget.py
new file mode 100644
index 0000000..116e023
--- /dev/null
+++ b/torchkit/labwidget.py
@@ -0,0 +1,678 @@
+"""
+labwidget by David Bau.
+
+Base class for a lightweight javascript notebook widget framework
+that is portable across Google colab and Jupyter notebooks.
+No use of requirejs: the design uses all inline javascript.
+
+Defines Model, Widget, Trigger, and Property, which set up data binding
+using the communication channels available in either google colab
+environment or jupyter notebook.
+
+This module also defines Label, Textbox, Range, Choice, and Div
+widgets; the code for these are good examples of usage of Widget,
+Trigger, and Property objects.
+
+User interaction should update the javascript model using
+model.set('propname', value); this will propagate to the python
+model and notify any registered python listeners.
+
+TODO: Support jupyterlab also.
+"""
+
+import json, html, re
+from inspect import signature
+
+class Model(object):
+ '''
+ Abstract base class that supports data binding. Within __init__,
+ a model subclass defines databound events and properties using:
+
+ self.evtname = Trigger()
+ self.propname = Property(initval)
+
+ Any Trigger or Property member can be watched by registering a
+ listener with `model.on('propname', callback)`.
+
+ An event can be triggered by `model.evtname.trigger(value)`.
+ A property can be read with `model.propname`, and can be set by
+ `model.propname = value`; this also triggers notifications.
+ In both these cases, any registered listeners will be called
+ with the given value.
+ '''
+ def on(self, name, cb):
+ '''
+ Registers a listener for named events and properties.
+ A space-separated list of names can be provided as `name`.
+ '''
+ for n in name.split():
+ self.prop(n).on(cb)
+ return self
+
+ def off(self, name, cb):
+ '''
+ Unregisters a listener for named events and properties.
+ A space-separated list of names can be provided as `name`.
+ '''
+ for n in name.split():
+ self.prop(n).off(cb)
+ return self
+
+ def prop(self, name):
+ '''
+ Returns the underlying Trigger or Property object for a
+ property, rather than its held value.
+ '''
+ curvalue = super().__getattribute__(name)
+ if not isinstance(curvalue, Trigger):
+ raise AttributeError('%s not a property or trigger but %s'
+ % (name, str(type(curvalue))))
+ return curvalue
+
+ def _initprop_(self, name, value):
+ '''
+ To be overridden in base classes. Handles initialization of
+ a new Trigger or Property member.
+ '''
+ return
+
+ def __setattr__(self, name, value):
+ '''
+ When a member is an Trigger or Property, then assignment notation
+ is delegated to the Trigger or Property so that notifications
+ and reparenting can be handled. That is, `model.name = value`
+ turns into `prop(name).set(value)`.
+ '''
+ if hasattr(self, name):
+ curvalue = super().__getattribute__(name)
+ if isinstance(curvalue, Trigger):
+ # Delegte "set" to the underlying Property.
+ curvalue.set(value)
+ else:
+ super().__setattr__(name, value)
+ else:
+ super().__setattr__(name, value)
+ if isinstance(value, Trigger):
+ self._initprop_(name, value)
+
+ def __getattribute__(self, name):
+ '''
+ When a member is a Property, then property getter
+ notation is delegated to the peoperty object.
+ '''
+ curvalue = super().__getattribute__(name)
+ if isinstance(curvalue, Property):
+ return curvalue.value
+ return curvalue
+
+class Widget(Model):
+ '''
+ Base class for an HTML widget that uses a Javascript model object
+ to syncrhonize HTML view state with the backend Python model state.
+ Each widget subclass overrides widget_js to provide Javascript code
+ that defines the widget's behavior. This javascript will be wrapped
+ in an immediately-invoked function and included in the widget's HTML
+ representation (_repr_html_) when the widget is viewed.
+
+ A widget's javascript is provided with two local variables:
+
+ element - the widget's root HTML element. By default this is
+ a
but can be overridden in widget_html.
+ model - the object representing the data model for the widget.
+ within javascript.
+
+ The model object provides the following javascript API:
+
+ model.get('propname') obtains a current property value.
+ model.set('propname', 'value') requests a change in value.
+ model.on('propname', callback) listens for property changes.
+ model.trigger('evtname', value) triggers an event.
+
+ Note that model.set just requests a change but does not change the
+ value immediately: model.get will not reflect the change until the
+ python backend has handled it and notified the javascript of the new
+ value, which will trigger any callbacks previously registered using
+ .on('propname', callback). Thus Widget impelements a V-shaped
+ notification protocol:
+
+ User entry -> | -> User-visible feedback
+ js model.set -> | -> js.model.on callback
+ python prop.trigger -> | -> python prop.notify
+ python prop.handle
+ '''
+
+ def __init__(self):
+ # In the jupyter case, there can be some delay between js injection
+ # and comm creation, so we need to queue some initial messages.
+ if WIDGET_ENV == 'jupyter':
+ self._comms = []
+ self._queue = []
+ # Each call to _repr_html_ creates a unique view instance.
+ self._viewcount = 0
+ # Python notification is handled by Property objects.
+ def handle_remote_set(name, value):
+ self.prop(name).trigger(value)
+ self._recv_from_js_(handle_remote_set)
+
+ def widget_js(self):
+ '''
+ Override to define the javascript logic for the widget. Should
+ render the initial view based on the current model state (if not
+ already rendered using widget_html) and set up listeners to keep
+ the model and the view synchornized.
+ '''
+ return ''
+
+ def widget_html(self):
+ '''
+ Override to define the initial HTML view of the widget. Should
+ define an element with id given by view_id().
+ '''
+ return f''
+
+ def view_id(self):
+ '''
+ Returns an HTML element id for the view currently being rendered.
+ Note that each time _repr_html_ is called, this id will change.
+ '''
+ return f"_{id(self)}_{self._viewcount}"
+
+ def _repr_html_(self):
+ '''
+ Returns the HTML code for the widget.
+ '''
+ self._viewcount += 1
+ json_data = json.dumps({
+ k: v.value for k, v in vars(self).items()
+ if isinstance(v, Property)})
+ json_data = re.sub('', '<\\/', json_data)
+ return f"""
+ {self.widget_html()}
+
+ """
+
+ def _initprop_(self, name, value):
+ if not hasattr(self, '_viewcount'):
+ raise ValueError('base Model __init__ must be called')
+ def notify_js(value):
+ self._send_to_js_(id(self), name, value)
+ if isinstance(value, Trigger):
+ value.on(notify_js)
+
+ def _send_to_js_(self, *args):
+ if self._viewcount > 0:
+ if WIDGET_ENV == 'colab':
+ colab_output.eval_js(f"""
+ (window.send_{id(self)} = window.send_{id(self)} ||
+ new BroadcastChannel("channel_{id(self)}")
+ ).postMessage({json.dumps(args)});
+ """, ignore_result=True)
+ elif WIDGET_ENV == 'jupyter':
+ if not self._comms:
+ self._queue.append(args)
+ return
+ for comm in self._comms:
+ comm.send(args)
+
+ def _recv_from_js_(self, fn):
+ if WIDGET_ENV == 'colab':
+ colab_output.register_callback(f"invoke_{id(self)}", fn)
+ elif WIDGET_ENV == 'jupyter':
+ def handle_comm(msg):
+ fn(*(msg['content']['data']))
+ # TODO: handle closing also.
+ def handle_close(close_msg):
+ comm_id = close_msg['content']['comm_id']
+ self._comms = [c for c in self._comms if c.comm_id != comm_id]
+ def open_comm(comm, open_msg):
+ self._comms.append(comm)
+ comm.on_msg(handle_comm)
+ comm.on_close(handle_close)
+ comm.send('ok')
+ if self._queue:
+ for args in self._queue:
+ comm.send(args)
+ self._queue.clear()
+ if open_msg['content']['data']:
+ handle_comm(open_msg)
+ cname = "comm_" + str(id(self))
+ COMM_MANAGER.register_target(cname, open_comm)
+
+class Trigger(object):
+ """
+ Trigger is the base class for Property and other data-bound
+ field objects. Trigger holds a list of listeners that need to
+ be notified about the event.
+
+ Multple Trigger objects can be tied (typically a parent Model can
+ have Triggers that are triggered by children models). To support
+ this, each Trigger can have a parent.
+
+ Trigger objects provide a notification protocol where view
+ interactions trigger events at a leaf that are sent up to the
+ root Trigger to be handled. By default, the root handler accepts
+ events by notifying all listeners and children in the tree.
+ """
+ def __init__(self):
+ self._listeners = []
+ self.parent = None
+ def handle(self, value):
+ '''
+ Method to override; called at the root when an event has been
+ triggered, and on a child when the parent has notified. By
+ default notifies all listeners.
+ '''
+ self.notify(value)
+ def trigger(self, value=None):
+ '''
+ Triggers an event to be handled by the root. By default, the root
+ handler will accept the event so all the listeners will be notified.
+ '''
+ if self.parent is not None:
+ self.parent.trigger(value)
+ else:
+ self.handle(value)
+ def set(self, value):
+ '''
+ Sets the parent Trigger. Child Triggers trigger events by
+ triggering parents, and in turn they handle notifications
+ that come from parents.
+ '''
+ if self.parent is not None:
+ self.parent.off(self.handle)
+ self.parent = None
+ if isinstance(value, Trigger):
+ ancestor = value.parent
+ while ancestor is not None:
+ if ancestor == self:
+ raise ValueError('bound properties should not form a loop')
+ ancestor = ancestor.parent
+ self.parent = value
+ self.parent.on(self.handle)
+ elif not isinstance(self, Property):
+ raise ValueError('only properties can be set to a value')
+ def notify(self, value=None):
+ '''
+ Notifies listeners and children. If a listener accepts an argument,
+ the value will be passed as a single argument.
+ '''
+ for cb in self._listeners:
+ if len(signature(cb).parameters) == 0:
+ cb() # no-parameter callback.
+ else:
+ cb(value)
+ # TODO: consider adding a two-parameter callback form
+ # where a detailed event object can be added.
+ def on(self, cb):
+ '''
+ Registers a listener. Calling multiple times registers
+ multiple listeners.
+ '''
+ self._listeners.append(cb)
+ def off(self, cb):
+ '''
+ Unregisters a listener.
+ '''
+ self._listeners = [c for c in self._listeners if c != cb]
+
+class Property(Trigger):
+ """
+ A Property is just an Trigger that remembers its last value.
+ """
+ def __init__(self, value=None):
+ '''
+ Can be initialized with a starting value.
+ '''
+ super().__init__()
+ self.set(value)
+ def handle(self, value):
+ '''
+ The default handling for a Property is to store the value,
+ then notify listeners. This method can be overridden,
+ for example to validate values.
+ '''
+ self.value = value
+ self.notify(value)
+ def set(self, value):
+ '''
+ When a Property value is set to an ordinary value, it
+ triggers an event which causes a notification to be
+ sent to update all linked Properties. A Property set
+ to another Property becomes a child of the value.
+ '''
+ # Handle setting a parent Property
+ if isinstance(value, Property):
+ super().set(value)
+ self.handle(value.value)
+ elif isinstance(value, Trigger):
+ raise ValueError('Cannot set a Property to an Trigger')
+ else:
+ self.trigger(value)
+
+
+##########################################################################
+## Specific widgets
+##########################################################################
+
+class Button(Widget):
+ def __init__(self, label='button'):
+ super().__init__()
+ self.click = Trigger()
+ self.label = Property(label)
+ def widget_js(self):
+ return '''
+ element.addEventListener('click', (e) => {
+ model.trigger('click');
+ })
+ model.on('label', (v) => {
+ element.value = v;
+ })
+ '''
+ def widget_html(self):
+ return f'''
+
+ '''
+
+class Label(Widget):
+ def __init__(self, value=''):
+ super().__init__()
+ # databinding is defined using Property objects.
+ self.value = Property(value)
+
+ def widget_js(self):
+ # Both "model" and "element" objects are defined within the scope
+ # where the js is run. "element" looks for the element with id
+ # self.view_id(); if widget_html is overridden, this id should be used.
+ return '''
+ model.on('value', (value) => {
+ element.innerText = model.get('value');
+ });
+ '''
+ def widget_html(self):
+ return f'''
+
+ '''
+
+class Textbox(Widget):
+ def __init__(self, value='', size=20):
+ super().__init__()
+ # databinding is defined using Property objects.
+ self.value = Property(value)
+ self.size = Property(size)
+
+ def widget_js(self):
+ # Both "model" and "element" objects are defined within the scope
+ # where the js is run. "element" looks for the element with id
+ # self.view_id(); if widget_html is overridden, this id should be used.
+ return '''
+ element.value = model.get('value');
+ element.size = model.get('size');
+ element.addEventListener('keydown', (e) => {
+ if (e.code == 'Enter') {
+ model.set('value', element.value);
+ }
+ });
+ model.on('value', (value) => {
+ element.value = model.get('value');
+ });
+ model.on('size', (value) => {
+ element.size = model.get('size');
+ });
+ '''
+ def widget_html(self):
+ return f'''
+
+ '''
+
+class Range(Widget):
+ def __init__(self, value=50, min=0, max=100):
+ super().__init__()
+ # databinding is defined using Property objects.
+ self.value = Property(value)
+ self.min = Property(min)
+ self.max = Property(max)
+
+ def widget_js(self):
+ # Note that the 'input' event would enable during-drag feedback,
+ # but this is pretty slow on google colab.
+ return '''
+ element.addEventListener('change', (e) => {
+ model.set('value', element.value);
+ });
+ model.on('value', (value) => {
+ if (!element.matches(':active')) {
+ element.value = value;
+ }
+ })
+ '''
+ def widget_html(self):
+ return f'''
+
+ '''
+
+class Choice(Widget):
+ def __init__(self, choices=None, selection=None, horizontal=False):
+ super().__init__()
+ if choices is None:
+ choices = []
+ self.choices = Property(choices)
+ self.horizontal = Property(horizontal)
+ self.selection = Property(selection)
+ def widget_js(self):
+ # Note that the 'input' event would enable during-drag feedback,
+ # but this is pretty slow on google colab.
+ return '''
+ function esc(unsafe) {
+ return unsafe.replace(/&/g, "&").replace(//g, ">").replace(/"/g, """);
+ }
+ function render() {
+ console.log('rendering');
+ var lines = model.get('choices').map((c) => {
+ return ''
+ });
+ element.innerHTML = lines.join(model.get('horizontal')?' ':' ');
+ }
+ model.on('choices horizontal', render);
+ model.on('selection', (selection) => {
+ [...element.querySelectorAll('input')].forEach((e) => {
+ e.checked = (e.value == selection);
+ })
+ });
+ element.addEventListener('change', (e) => {
+ model.set('selection', element.choice.value);
+ });
+ '''
+ def widget_html(self):
+ radios = [
+ f""""""
+ for value in self.choices ]
+ sep = " " if self.horizontal else " "
+ return f''
+
+
+class Div(Widget):
+ def __init__(self, innerHTML='', style=None, data=None):
+ super().__init__()
+ style = {} if style is None else style
+ data = {} if data is None else data
+ # TODO: convert non-string innerHTML objects to html; unify
+ # with the show() library.
+ self.innerHTML = Property(innerHTML)
+ self.style = Property(style)
+ self.click = Trigger()
+
+ def print(self, text, replace=False):
+ newHTML = '
%s
' % html.escape(str(text));
+ if replace:
+ self.innerHTML = newHTML
+ else:
+ self.innerHTML += newHTML
+
+ def widget_js(self):
+ # Note that if we want innerHTML to support script execution,
+ # we need to do it explicitly, like this.
+ return '''
+ function updater(attr) {
+ return (val) => { for (k in val) { element[attr][k] = val[k]; } }
+ }
+ model.on('style', updater('style'));
+ updater('style')();
+ model.on('data', updater('dataset'));
+ updater('dataset')();
+ model.on('innerHTML', (innerHTML) => {
+ element.innerHTML = innerHTML;
+ Array.from(element.querySelectorAll("script")).forEach(old=>{
+ const newScript = document.createElement("script");
+ Array.from(old.attributes).forEach(attr =>
+ newScript.setAttribute(attr.name, attr.value));
+ newScript.appendChild(document.createTextNode(old.innerHTML));
+ old.parentNode.replaceChild(newScript, old);
+ });
+ });
+ '''
+ def widget_html(self):
+ return f'''
+
{self.innerHTML}
+ '''
+
+class ClickDiv(Div):
+ '''
+ A Div that triggers click events when anything inside them is clicked.
+ If a clicked element contains a data-click value, then that value is
+ sent as the click event value.
+ '''
+ def __init__(self, innerHTML='', style=None, data=None):
+ super().__init__(innertHTML, style, data)
+ self.click = Trigger()
+
+ def widget_js(self):
+ return super().widget_js() + '''
+ element.addEventListener('click', (ev) => {
+ var target = ev.target;
+ while (target && target != element && !target.dataset.click) {
+ target = target.parentElement;
+ }
+ var value = target.dataset.click;
+ model.trigger('click', value);
+ });
+ '''
+
+##########################################################################
+## Implementation Details
+##########################################################################
+
+WIDGET_ENV = None
+if WIDGET_ENV is None:
+ try:
+ from google.colab import output as colab_output
+ WIDGET_ENV = 'colab'
+ except:
+ pass
+if WIDGET_ENV is None:
+ try:
+ from ipykernel.comm import Comm as jupyter_comm
+ COMM_MANAGER = get_ipython().kernel.comm_manager
+ WIDGET_ENV = 'jupyter'
+ except:
+ pass
+
+SEND_RECV_JS = """
+function recvFromPython(obj_id, fn) {
+ var recvname = "recv_" + obj_id;
+ if (window[recvname] === undefined) {
+ window[recvname] = new BroadcastChannel("channel_" + obj_id);
+ }
+ window[recvname].addEventListener("message", (ev) => {
+ if (ev.data == 'ok') {
+ window[recvname].ok = true;
+ return;
+ }
+ fn.apply(null, ev.data.slice(1));
+ });
+}
+function sendToPython(obj_id, ...args) {
+ google.colab.kernel.invokeFunction('invoke_' + obj_id, args, {})
+}
+""" if WIDGET_ENV == 'colab' else """
+function getChan(obj_id) {
+ var cname = "comm_" + obj_id;
+ if (!window[cname]) { window[cname] = []; }
+ var chan = window[cname];
+ if (!chan.comm && Jupyter.notebook.kernel) {
+ chan.comm = Jupyter.notebook.kernel.comm_manager.new_comm(cname, {});
+ chan.comm.on_msg((ev) => {
+ if (chan.retry) { clearInterval(chan.retry); chan.retry = null; }
+ if (ev.content.data == 'ok') { return; }
+ var args = ev.content.data.slice(1);
+ for (fn of chan) { fn.apply(null, args); }
+ });
+ chan.retry = setInterval(() => { chan.comm.open(); }, 2000);
+ }
+ return chan;
+}
+function recvFromPython(obj_id, fn) {
+ getChan(obj_id).push(fn);
+}
+function sendToPython(obj_id, ...args) {
+ var comm = getChan(obj_id).comm;
+ if (comm) { comm.send(args); }
+}
+"""
+
+
+WIDGET_MODEL_JS = SEND_RECV_JS + """
+class Model {
+ constructor(obj_id, init) {
+ this._id = obj_id;
+ this._listeners = {};
+ this._data = Object.assign({}, init)
+ recvFromPython(this._id, (name, value) => {
+ this._data[name] = value;
+ if (this._listeners.hasOwnProperty(name)) {
+ this._listeners[name].forEach((fn) => { fn(value); });
+ }
+ })
+ }
+ trigger(name, value) {
+ sendToPython(this._id, name, value);
+ }
+ get(name) {
+ return this._data[name];
+ }
+ set(name, value) {
+ this.trigger(name, value);
+ }
+ on(name, fn) {
+ name.split(/\s+/).forEach((n) => {
+ if (!this._listeners.hasOwnProperty(n)) {
+ this._listeners[n] = [];
+ }
+ this._listeners[n].push(fn);
+ });
+ }
+ off(name, fn) {
+ name.split(/\s+/).forEach((n) => {
+ if (!fn) {
+ delete this._listeners[n];
+ } else if (this._listeners.hasOwnProperty(n)) {
+ this._listeners[n] = this._listeners[n].filter(
+ (e) => { return e !== fn; });
+ }
+ });
+ }
+}
+"""
diff --git a/torchkit/nethook.py b/torchkit/nethook.py
new file mode 100644
index 0000000..12177ca
--- /dev/null
+++ b/torchkit/nethook.py
@@ -0,0 +1,451 @@
+"""
+Utilities for instrumenting a torch model.
+
+Trace will hook one layer at a time.
+TraceDict will hook multiple layers at once.
+subsequence slices intervals from Sequential modules.
+get_module, replace_module, get_parameter resolve dotted names.
+set_requires_grad recursively sets requires_grad in module parameters.
+"""
+
+import contextlib
+import copy
+import inspect
+from collections import OrderedDict
+
+import torch
+
+
+class Trace(contextlib.AbstractContextManager):
+ """
+ To retain the output of the named layer during the computation of
+ the given network:
+
+ with Trace(net, 'layer.name') as ret:
+ _ = net(inp)
+ representation = ret.output
+
+ A layer module can be passed directly without a layer name, and
+ its output will be retained. By default, a direct reference to
+ the output object is returned, but options can control this:
+
+ clone=True - retains a copy of the output, which can be
+ useful if you want to see the output before it might
+ be modified by the network in-place later.
+ detach=True - retains a detached reference or copy. (By
+ default the value would be left attached to the graph.)
+ retain_grad=True - request gradient to be retained on the
+ output. After backward(), ret.output.grad is populated.
+
+ retain_input=True - also retains the input.
+ retain_output=False - can disable retaining the output.
+ edit_output=fn - calls the function to modify the output
+ of the layer before passing it the rest of the model.
+ fn can optionally accept (output, layer) arguments
+ for the original output and the layer name.
+ stop=True - throws a StopForward exception after the layer
+ is run, which allows running just a portion of a model.
+ """
+
+ def __init__(
+ self,
+ module,
+ layer=None,
+ retain_output=True,
+ retain_input=False,
+ clone=False,
+ detach=False,
+ retain_grad=False,
+ edit_output=None,
+ stop=False,
+ ):
+ """
+ Method to replace a forward method with a closure that
+ intercepts the call, and tracks the hook so that it can be reverted.
+ """
+ retainer = self
+ self.layer = layer
+ if layer is not None:
+ module = get_module(module, layer)
+
+ def retain_hook(m, inputs, output):
+ if retain_input:
+ retainer.input = recursive_copy(
+ inputs[0] if len(inputs) == 1 else inputs,
+ clone=clone,
+ detach=detach,
+ retain_grad=False,
+ ) # retain_grad applies to output only.
+ if edit_output:
+ output = invoke_with_optional_args(
+ edit_output, output=output, layer=self.layer
+ )
+ if retain_output:
+ retainer.output = recursive_copy(
+ output, clone=clone, detach=detach, retain_grad=retain_grad
+ )
+ # When retain_grad is set, also insert a trivial
+ # copy operation. That allows in-place operations
+ # to follow without error.
+ if retain_grad:
+ output = recursive_copy(retainer.output, clone=True, detach=False)
+ if stop:
+ raise StopForward()
+ return output
+
+ self.registered_hook = module.register_forward_hook(retain_hook)
+ self.stop = stop
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+ if self.stop and issubclass(type, StopForward):
+ return True
+
+ def close(self):
+ self.registered_hook.remove()
+
+
+class TraceDict(OrderedDict, contextlib.AbstractContextManager):
+ """
+ To retain the output of multiple named layers during the computation
+ of the given network:
+
+ with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
+ _ = net(inp)
+ representation = ret['layer1.name1'].output
+
+ If edit_output is provided, it should be a function that takes
+ two arguments: output, and the layer name; and then it returns the
+ modified output.
+
+ Other arguments are the same as Trace. If stop is True, then the
+ execution of the network will be stopped after the last layer
+ listed (even if it would not have been the last to be executed).
+ """
+
+ def __init__(
+ self,
+ module,
+ layers=None,
+ retain_output=True,
+ retain_input=False,
+ clone=False,
+ detach=False,
+ retain_grad=False,
+ edit_output=None,
+ stop=False,
+ ):
+ self.stop = stop
+
+ def flag_last_unseen(it):
+ try:
+ it = iter(it)
+ prev = next(it)
+ seen = set([prev])
+ except StopIteration:
+ return
+ for item in it:
+ if item not in seen:
+ yield False, prev
+ seen.add(item)
+ prev = item
+ yield True, prev
+
+ for is_last, layer in flag_last_unseen(layers):
+ self[layer] = Trace(
+ module=module,
+ layer=layer,
+ retain_output=retain_output,
+ retain_input=retain_input,
+ clone=clone,
+ detach=detach,
+ retain_grad=retain_grad,
+ edit_output=edit_output,
+ stop=stop and is_last,
+ )
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+ if self.stop and issubclass(type, StopForward):
+ return True
+
+ def close(self):
+ for layer, trace in reversed(self.items()):
+ trace.close()
+
+
+class StopForward(Exception):
+ """
+ If the only output needed from running a network is the retained
+ submodule then Trace(submodule, stop=True) will stop execution
+ immediately after the retained submodule by raising the StopForward()
+ exception. When Trace is used as context manager, it catches that
+ exception and can be used as follows:
+
+ with Trace(net, layername, stop=True) as tr:
+ net(inp) # Only runs the network up to layername
+ print(tr.output)
+ """
+
+ pass
+
+
+def recursive_copy(x, clone=None, detach=None, retain_grad=None):
+ """
+ Copies a reference to a tensor, or an object that contains tensors,
+ optionally detaching and cloning the tensor(s). If retain_grad is
+ true, the original tensors are marked to have grads retained.
+ """
+ if not clone and not detach and not retain_grad:
+ return x
+ if isinstance(x, torch.Tensor):
+ if retain_grad:
+ if not x.requires_grad:
+ x.requires_grad = True
+ x.retain_grad()
+ elif detach:
+ x = x.detach()
+ if clone:
+ x = x.clone()
+ return x
+ # Only dicts, lists, and tuples (and subclasses) can be copied.
+ if isinstance(x, dict):
+ return type(x)({k: recursive_copy(v) for k, v in x.items()})
+ elif isinstance(x, (list, tuple)):
+ return type(x)([recursive_copy(v) for v in x])
+ else:
+ assert False, f"Unknown type {type(x)} cannot be broken into tensors."
+
+
+def subsequence(
+ sequential,
+ first_layer=None,
+ last_layer=None,
+ after_layer=None,
+ upto_layer=None,
+ single_layer=None,
+ share_weights=False,
+):
+ """
+ Creates a subsequence of a pytorch Sequential model, copying over
+ modules together with parameters for the subsequence. Only
+ modules from first_layer to last_layer (inclusive) are included,
+ or modules between after_layer and upto_layer (exclusive).
+ Handles descent into dotted layer names as long as all references
+ are within nested Sequential models.
+
+ If share_weights is True, then references the original modules
+ and their parameters without copying them. Otherwise, by default,
+ makes a separate brand-new copy.
+ """
+ assert (single_layer is None) or (
+ first_layer is last_layer is after_layer is upto_layer is None
+ )
+ if single_layer is not None:
+ first_layer = single_layer
+ last_layer = single_layer
+ first, last, after, upto = [
+ None if d is None else d.split(".")
+ for d in [first_layer, last_layer, after_layer, upto_layer]
+ ]
+ return hierarchical_subsequence(
+ sequential,
+ first=first,
+ last=last,
+ after=after,
+ upto=upto,
+ share_weights=share_weights,
+ )
+
+
+def hierarchical_subsequence(
+ sequential, first, last, after, upto, share_weights=False, depth=0
+):
+ """
+ Recursive helper for subsequence() to support descent into dotted
+ layer names. In this helper, first, last, after, and upto are
+ arrays of names resulting from splitting on dots. Can only
+ descend into nested Sequentials.
+ """
+ assert (last is None) or (upto is None)
+ assert (first is None) or (after is None)
+ if first is last is after is upto is None:
+ return sequential if share_weights else copy.deepcopy(sequential)
+ assert isinstance(sequential, torch.nn.Sequential), (
+ ".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
+ )
+ including_children = (first is None) and (after is None)
+ included_children = OrderedDict()
+ # A = current level short name of A.
+ # AN = full name for recursive descent if not innermost.
+ (F, FN), (L, LN), (A, AN), (U, UN) = [
+ (d[depth], (None if len(d) == depth + 1 else d))
+ if d is not None
+ else (None, None)
+ for d in [first, last, after, upto]
+ ]
+ for name, layer in sequential._modules.items():
+ if name == F:
+ first = None
+ including_children = True
+ if name == A and AN is not None: # just like F if not a leaf.
+ after = None
+ including_children = True
+ if name == U and UN is None:
+ upto = None
+ including_children = False
+ if including_children:
+ # AR = full name for recursive descent if name matches.
+ FR, LR, AR, UR = [
+ n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
+ ]
+ chosen = hierarchical_subsequence(
+ layer,
+ first=FR,
+ last=LR,
+ after=AR,
+ upto=UR,
+ share_weights=share_weights,
+ depth=depth + 1,
+ )
+ if chosen is not None:
+ included_children[name] = chosen
+ if name == L:
+ last = None
+ including_children = False
+ if name == U and UN is not None: # just like L if not a leaf.
+ upto = None
+ including_children = False
+ if name == A and AN is None:
+ after = None
+ including_children = True
+ for name in [first, last, after, upto]:
+ if name is not None:
+ raise ValueError("Layer %s not found" % ".".join(name))
+ # Omit empty subsequences except at the outermost level,
+ # where we should not return None.
+ if not len(included_children) and depth > 0:
+ return None
+ result = torch.nn.Sequential(included_children)
+ result.training = sequential.training
+ return result
+
+
+def set_requires_grad(requires_grad, *models):
+ """
+ Sets requires_grad true or false for all parameters within the
+ models passed.
+ """
+ for model in models:
+ if isinstance(model, torch.nn.Module):
+ for param in model.parameters():
+ param.requires_grad = requires_grad
+ elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
+ model.requires_grad = requires_grad
+ else:
+ assert False, "unknown type %r" % type(model)
+
+
+def get_module(model, name):
+ """
+ Finds the named module within the given model.
+ """
+ for n, m in model.named_modules():
+ if n == name:
+ return m
+ raise LookupError(name)
+
+
+def get_parameter(model, name):
+ """
+ Finds the named parameter within the given model.
+ """
+ for n, p in model.named_parameters():
+ if n == name:
+ return p
+ raise LookupError(name)
+
+
+def replace_module(model, name, new_module):
+ """
+ Replaces the named module within the given model.
+ """
+ if "." in name:
+ parent_name, attr_name = name.rsplit(".", 1)
+ model = get_module(model, parent_name)
+ # original_module = getattr(model, attr_name)
+ setattr(model, attr_name, new_module)
+
+
+def invoke_with_optional_args(fn, *args, **kwargs):
+ """
+ Invokes a function with only the arguments that it
+ is written to accept, giving priority to arguments
+ that match by-name, using the following rules.
+ (1) arguments with matching names are passed by name.
+ (2) remaining non-name-matched args are passed by order.
+ (3) extra caller arguments that the function cannot
+ accept are not passed.
+ (4) extra required function arguments that the caller
+ cannot provide cause a TypeError to be raised.
+ Ordinary python calling conventions are helpful for
+ supporting a function that might be revised to accept
+ extra arguments in a newer version, without requiring the
+ caller to pass those new arguments. This function helps
+ support function callers that might be revised to supply
+ extra arguments, without requiring the callee to accept
+ those new arguments.
+ """
+ argspec = inspect.getfullargspec(fn)
+ pass_args = []
+ used_kw = set()
+ unmatched_pos = []
+ used_pos = 0
+ defaulted_pos = len(argspec.args) - (
+ 0 if not argspec.defaults else len(argspec.defaults)
+ )
+ # Pass positional args that match name first, then by position.
+ for i, n in enumerate(argspec.args):
+ if n in kwargs:
+ pass_args.append(kwargs[n])
+ used_kw.add(n)
+ elif used_pos < len(args):
+ pass_args.append(args[used_pos])
+ used_pos += 1
+ else:
+ unmatched_pos.append(len(pass_args))
+ pass_args.append(
+ None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
+ )
+ # Fill unmatched positional args with unmatched keyword args in order.
+ if len(unmatched_pos):
+ for k, v in kwargs.items():
+ if k in used_kw or k in argspec.kwonlyargs:
+ continue
+ pass_args[unmatched_pos[0]] = v
+ used_kw.add(k)
+ unmatched_pos = unmatched_pos[1:]
+ if len(unmatched_pos) == 0:
+ break
+ else:
+ if unmatched_pos[0] < defaulted_pos:
+ unpassed = ", ".join(
+ argspec.args[u] for u in unmatched_pos if u < defaulted_pos
+ )
+ raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
+ # Pass remaining kw args if they can be accepted.
+ pass_kw = {
+ k: v
+ for k, v in kwargs.items()
+ if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
+ }
+ # Pass remaining positional args if they can be accepted.
+ if argspec.varargs is not None:
+ pass_args += list(args[used_pos:])
+ return fn(*pass_args, **pass_kw)
diff --git a/torchkit/paintwidget.py b/torchkit/paintwidget.py
new file mode 100644
index 0000000..cef3744
--- /dev/null
+++ b/torchkit/paintwidget.py
@@ -0,0 +1,147 @@
+from .labwidget import Widget, Property
+import html
+
+class PaintWidget(Widget):
+ def __init__(self,
+ width=256, height=256,
+ image='', mask='', brushsize=10.0, oneshot=False, disabled=False):
+ super().__init__()
+ self.mask = Property(mask)
+ self.image = Property(image)
+ self.brushsize = Property(brushsize)
+ self.erase = Property(False)
+ self.oneshot = Property(oneshot)
+ self.disabled = Property(disabled)
+ self.width = Property(width)
+ self.height = Property(height)
+
+ def widget_js(self):
+ return f'''
+ {PAINT_WIDGET_JS}
+ var pw = new PaintWidget(element, model);
+ '''
+ def widget_html(self):
+ v = self.view_id()
+ return f'''
+
+
+ '''
+
+PAINT_WIDGET_JS = """
+class PaintWidget {
+ constructor(el, model) {
+ this.el = el;
+ this.model = model;
+ this.size_changed();
+ this.model.on('mask', this.mask_changed.bind(this));
+ this.model.on('image', this.image_changed.bind(this));
+ this.model.on('width', this.size_changed.bind(this));
+ this.model.on('height', this.size_changed.bind(this));
+ }
+ mouse_stroke(first_event) {
+ var self = this;
+ if (self.model.get('disabled')) { return; }
+ if (self.model.get('oneshot')) {
+ var canvas = self.mask_canvas;
+ var ctx = canvas.getContext('2d');
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
+ }
+ function track_mouse(evt) {
+ if (evt.type == 'keydown' || self.model.get('disabled')) {
+ if (self.model.get('disabled') || evt.key === "Escape") {
+ window.removeEventListener('mousemove', track_mouse);
+ window.removeEventListener('mouseup', track_mouse);
+ window.removeEventListener('keydown', track_mouse, true);
+ self.mask_changed();
+ }
+ return;
+ }
+ if (evt.type == 'mouseup' ||
+ (typeof evt.buttons != 'undefined' && evt.buttons == 0)) {
+ window.removeEventListener('mousemove', track_mouse);
+ window.removeEventListener('mouseup', track_mouse);
+ window.removeEventListener('keydown', track_mouse, true);
+ self.model.set('mask', self.mask_canvas.toDataURL());
+ return;
+ }
+ var p = self.cursor_position();
+ self.fill_circle(p.x, p.y,
+ self.model.get('brushsize'),
+ self.model.get('erase'));
+ }
+ this.mask_canvas.focus();
+ window.addEventListener('mousemove', track_mouse);
+ window.addEventListener('mouseup', track_mouse);
+ window.addEventListener('keydown', track_mouse, true);
+ track_mouse(first_event);
+ }
+ mask_changed(val) {
+ this.draw_data_url(this.mask_canvas, this.model.get('mask'));
+ }
+ image_changed() {
+ this.draw_data_url(this.image_canvas, this.model.get('image'));
+ }
+ size_changed() {
+ this.mask_canvas = document.createElement('canvas');
+ this.image_canvas = document.createElement('canvas');
+ this.mask_canvas.className = "paintmask";
+ this.image_canvas.className = "paintimage";
+ for (var attr of ['width', 'height']) {
+ this.mask_canvas[attr] = this.model.get(attr);
+ this.image_canvas[attr] = this.model.get(attr);
+ }
+
+ this.el.innerHTML = '';
+ this.el.appendChild(this.image_canvas);
+ this.el.appendChild(this.mask_canvas);
+ this.mask_canvas.addEventListener('mousedown',
+ this.mouse_stroke.bind(this));
+ this.mask_changed();
+ this.image_changed();
+ }
+
+ cursor_position(evt) {
+ const rect = this.mask_canvas.getBoundingClientRect();
+ const x = event.clientX - rect.left;
+ const y = event.clientY - rect.top;
+ return {x: x, y: y};
+ }
+
+ fill_circle(x, y, r, erase, blur) {
+ var ctx = this.mask_canvas.getContext('2d');
+ ctx.save();
+ if (blur) {
+ ctx.filter = 'blur(' + blur + 'px)';
+ }
+ ctx.globalCompositeOperation = (
+ erase ? "destination-out" : 'source-over');
+ ctx.fillStyle = '#fff';
+ ctx.beginPath();
+ ctx.arc(x, y, r, 0, 2 * Math.PI);
+ ctx.fill();
+ ctx.restore()
+ }
+
+ draw_data_url(canvas, durl) {
+ var ctx = canvas.getContext('2d');
+ var img = new Image;
+ canvas.pendingImg = img;
+ function imgdone() {
+ if (canvas.pendingImg == img) {
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
+ ctx.drawImage(img, 0, 0);
+ canvas.pendingImg = null;
+ }
+ }
+ img.addEventListener('load', imgdone);
+ img.addEventListener('error', imgdone);
+ img.src = durl;
+ }
+}
+"""
diff --git a/torchkit/pbar.py b/torchkit/pbar.py
new file mode 100644
index 0000000..79264ee
--- /dev/null
+++ b/torchkit/pbar.py
@@ -0,0 +1,212 @@
+'''
+Utilities for showing progress bars, controlling default verbosity, etc.
+'''
+
+# If the tqdm package is not available, then do not show progress bars;
+# just connect print_progress to print.
+import sys
+import types
+import builtins
+try:
+ from tqdm import tqdm
+ try:
+ from tqdm.notebook import tqdm as tqdm_nb
+ except:
+ from tqdm import tqdm_notebook as tqdm_nb
+except:
+ tqdm = None
+
+default_verbosity = True
+next_description = None
+python_print = builtins.print
+
+
+def post(**kwargs):
+ '''
+ When within a progress loop, pbar.post(k=str) will display
+ the given k=str status on the right-hand-side of the progress
+ status bar. If not within a visible progress bar, does nothing.
+ '''
+ innermost = innermost_tqdm()
+ if innermost is not None:
+ innermost.set_postfix(**kwargs)
+
+
+def desc(desc):
+ '''
+ When within a progress loop, pbar.desc(str) changes the
+ left-hand-side description of the loop toe the given description.
+ '''
+ innermost = innermost_tqdm()
+ if innermost is not None:
+ innermost.set_description(str(desc))
+
+
+def descnext(desc):
+ '''
+ Called before starting a progress loop, pbar.descnext(str)
+ sets the description text that will be used in the following loop.
+ '''
+ global next_description
+ if not default_verbosity or tqdm is None:
+ return
+ next_description = desc
+
+
+def print(*args):
+ '''
+ When within a progress loop, will print above the progress loop.
+ '''
+ global next_description
+ next_description = None
+ if default_verbosity:
+ msg = ' '.join(str(s) for s in args)
+ if tqdm is None:
+ python_print(msg)
+ else:
+ tqdm.write(msg)
+
+
+def tqdm_terminal(it, *args, **kwargs):
+ '''
+ Some settings for tqdm that make it run better in resizable terminals.
+ '''
+ return tqdm(it, *args, dynamic_ncols=True, ascii=True,
+ leave=(innermost_tqdm() is not None), **kwargs)
+
+
+def in_notebook():
+ '''
+ True if running inside a Jupyter notebook.
+ '''
+ # From https://stackoverflow.com/a/39662359/265298
+ try:
+ shell = get_ipython().__class__.__name__
+ if shell == 'ZMQInteractiveShell':
+ return True # Jupyter notebook or qtconsole
+ elif shell == 'TerminalInteractiveShell':
+ return False # Terminal running IPython
+ else:
+ return False # Other type (?)
+ except NameError:
+ return False # Probably standard Python interpreter
+
+
+def innermost_tqdm():
+ '''
+ Returns the innermost active tqdm progress loop on the stack.
+ '''
+ if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0:
+ return max(tqdm._instances, key=lambda x: x.pos)
+ else:
+ return None
+
+
+def reporthook(*args, **kwargs):
+ '''
+ For use with urllib.request.urlretrieve.
+
+ with pbar.reporthook() as hook:
+ urllib.request.urlretrieve(url, filename, reporthook=hook)
+ '''
+ kwargs2 = dict(unit_scale=True, miniters=1)
+ kwargs2.update(kwargs)
+ bar = __call__(None, *args, **kwargs2)
+
+ class ReportHook(object):
+ def __init__(self, t):
+ self.t = t
+
+ def __call__(self, b=1, bsize=1, tsize=None):
+ if hasattr(self.t, 'total'):
+ if tsize is not None:
+ self.t.total = tsize
+ if hasattr(self.t, 'update'):
+ self.t.update(b * bsize - self.t.n)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc):
+ if hasattr(self.t, '__exit__'):
+ self.t.__exit__(*exc)
+ return ReportHook(bar)
+
+
+def __call__(x, *args, **kwargs):
+ '''
+ Invokes a progress function that can wrap iterators to print
+ progress messages, if verbose is True.
+
+ If verbose is False or tqdm is unavailable, then a quiet
+ non-printing identity function is used.
+
+ verbose can also be set to a spefific progress function rather
+ than True, and that function will be used.
+ '''
+ global default_verbosity, next_description
+ if not default_verbosity or tqdm is None:
+ return x
+ if default_verbosity == True:
+ fn = tqdm_nb if in_notebook() else tqdm_terminal
+ else:
+ fn = default_verbosity
+ if next_description is not None:
+ kwargs = dict(kwargs)
+ kwargs['desc'] = next_description
+ next_description = None
+ return fn(x, *args, **kwargs)
+
+
+class VerboseContextManager():
+ def __init__(self, v, entered=False):
+ self.v, self.entered, self.saved = v, False, []
+ if entered:
+ self.__enter__()
+ self.entered = True
+
+ def __enter__(self):
+ global default_verbosity
+ if self.entered:
+ self.entered = False
+ else:
+ self.saved.append(default_verbosity)
+ default_verbosity = self.v
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ global default_verbosity
+ default_verbosity = self.saved.pop()
+
+ def __call__(self, v=True):
+ '''
+ Calling the context manager makes a new context that is
+ pre-entered, so it works as both a plain function and as a
+ factory for a context manager.
+ '''
+ new_v = v if self.v else not v
+ cm = VerboseContextManager(new_v, entered=True)
+ default_verbosity = new_v
+ return cm
+
+
+# Use as either "with pbar.verbose:" or "pbar.verbose(False)", or also
+# "with pbar.verbose(False):"
+verbose = VerboseContextManager(True)
+
+# Use as either "with @pbar.quiet" or "pbar.quiet(True)". or also
+# "with pbar.quiet(True):"
+quiet = VerboseContextManager(False)
+
+
+class CallableModule(types.ModuleType):
+ def __init__(self):
+ # or super().__init__(__name__) for Python 3
+ types.ModuleType.__init__(self, __name__)
+ self.__dict__.update(sys.modules[__name__].__dict__)
+
+ def __call__(self, x, *args, **kwargs):
+ return __call__(x, *args, **kwargs)
+
+
+sys.modules[__name__] = CallableModule()
diff --git a/torchkit/pidfile.py b/torchkit/pidfile.py
new file mode 100644
index 0000000..3224b4d
--- /dev/null
+++ b/torchkit/pidfile.py
@@ -0,0 +1,129 @@
+'''
+Utility for simple distribution of work on multiple processes, by
+making sure only one process is working on a job at once.
+'''
+
+import os
+import errno
+import socket
+import atexit
+import time
+import sys
+
+
+def reserve_dir(*args):
+ '''
+ Convenience function to get exclusive access to an unfinished
+ experiment directory. Exits the program if the directory is
+ already done or busy (using exit_of_job_done). Otherwise,
+ returns a function creates filenames within that directory.
+ '''
+ directory = os.path.join(*args)
+ exit_if_job_done(directory)
+
+ def dirfn(*fn):
+ return os.path.join(directory, *fn)
+ dirfn.dir = directory
+
+ def done():
+ mark_job_done(directory)
+ dirfn.done = done
+ print('Working in %s' % directory)
+ return dirfn
+
+
+# Old function name.
+exclusive_dirfn = reserve_dir
+
+
+def exit_if_job_done(directory, redo=False, force=False, verbose=True):
+ if pidfile_taken(os.path.join(directory, 'lockfile.pid'),
+ force=force, verbose=verbose):
+ sys.exit(0)
+ donefile = os.path.join(directory, 'done.txt')
+ if os.path.isfile(donefile):
+ with open(donefile) as f:
+ msg = f.read()
+ if redo or force:
+ if verbose:
+ print('Removing %s %s' % (donefile, msg))
+ os.remove(donefile)
+ else:
+ if verbose:
+ print('%s %s' % (donefile, msg))
+ sys.exit(0)
+
+
+def mark_job_done(directory):
+ with open(os.path.join(directory, 'done.txt'), 'w') as f:
+ f.write('done by %d@%s %s at %s' %
+ (os.getpid(), socket.gethostname(),
+ os.getenv('STY', ''),
+ time.strftime('%c')))
+
+
+def pidfile_taken(path, verbose=False, force=False):
+ '''
+ Usage. To grab an exclusive lock for the remaining duration of the
+ current process (and exit if another process already has the lock),
+ do this:
+
+ if pidfile_taken('job_423/lockfile.pid', verbose=True):
+ sys.exit(0)
+
+ To do a batch of jobs, just run a script that does them all on
+ each available machine, sharing a network filesystem. When each
+ job grabs a lock, then this will automatically distribute the
+ jobs so that each one is done just once on one machine.
+ '''
+
+ # Try to create the file exclusively and write my pid into it.
+ try:
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ # If we cannot because there was a race, yield the conflicter.
+ conflicter = 'race'
+ try:
+ with open(path, 'r') as lockfile:
+ conflicter = lockfile.read().strip() or 'empty'
+ except:
+ pass
+ # Force is for manual one-time use, for deleting stale lockfiles.
+ if force:
+ if verbose:
+ print('Removing %s from %s' % (path, conflicter))
+ os.remove(path)
+ return pidfile_taken(path, verbose=verbose, force=False)
+ if verbose:
+ print('%s held by %s' % (path, conflicter))
+ return conflicter
+ else:
+ # Other problems get an exception.
+ raise
+ # Register to delete this file on exit.
+ lockfile = os.fdopen(fd, 'r+')
+ atexit.register(delete_pidfile, lockfile, path)
+ # Write my pid into the open file.
+ lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
+ os.getenv('STY', '')))
+ lockfile.flush()
+ os.fsync(lockfile)
+ # Return 'None' to say there was not a conflict.
+ return None
+
+
+def delete_pidfile(lockfile, path):
+ '''
+ Runs at exit after pidfile_taken succeeds.
+ '''
+ if lockfile is not None:
+ try:
+ lockfile.close()
+ except:
+ pass
+ try:
+ os.unlink(path)
+ except:
+ pass
diff --git a/torchkit/runningstats.py b/torchkit/runningstats.py
new file mode 100644
index 0000000..5cc575c
--- /dev/null
+++ b/torchkit/runningstats.py
@@ -0,0 +1,1871 @@
+"""
+To use a runningstats object,
+
+ 1. Create the the desired stat object, e.g., `m = Mean()`
+ 2. Feed it batches via the add method, e.g., `m.add(batch)`
+ 3. Repeat step 2 any number of times.
+ 4. Read out the statistic of interest, e.g., `m.mean()`
+
+Built-in runningstats objects include:
+
+ Mean - produces mean().
+ Variance - mean() and variance() and stdev().
+ Covariance - mean(), covariance(), correlation(), variance(), stdev().
+ SecondMoment - moment() is the non-mean-centered covariance, E[x x^T].
+ Quantile - quantile(), min(), max(), median(), mean(), variance(), stdev().
+ TopK - topk() returns (values, indexes).
+ Bincount - bincount() histograms nonnegative integer data.
+ IoU - intersection(), union(), iou() tally binary co-occurrences.
+ History - history() returns concatenation of data.
+ CrossCovariance - covariance between two signals, without self-covariance.
+ CrossIoU - iou between two signals, without self-IoU.
+ CombinedStat - aggregates any set of stats.
+
+Add more running stats by subclassing the Stat class.
+
+These statistics are vectorized along dim>=1, so stat.add()
+should supply a two-dimensional input where the zeroth
+dimension is the batch/sampling dimension and the first
+dimension is the feature dimension.
+
+The data type and device used matches the data passed to add();
+for example, for higher-precision covariances, convert to double
+before calling add().
+
+It is common to want to compute and remember a statistic sampled
+over a Dataset, computed in batches, possibly caching the computed
+statistic in a file. The tally(stat, dataset, cache) handles
+this pattern. It takes a statistic, a dataset, and a cache filename
+and sets up a data loader that can be run (or not, if cached) to
+compute the statistic, adopting the convention that cached stats are
+saved to and loaded from numpy npz files.
+"""
+
+import math
+import os
+import random
+import struct
+
+import numpy
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+def tally(stat, dataset, cache=None, quiet=False, **kwargs):
+ """
+ To use tally, write code like the following.
+
+ stat = Mean()
+ ds = MyDataset()
+ for batch in tally(stat, ds, cache='mymean.npz', batch_size=50):
+ stat.add(batch)
+ mean = stat.mean()
+
+ The first argument should be the Stat being computed. After the
+ loader is exhausted, tally will bring this stat to the cpu and
+ cache it (if a cache is specified).
+
+ The dataset can be a torch Dataset or a plain Tensor, or it can
+ be a callable that returns one of those.
+
+ Details on caching via the cache= argument:
+
+ If the given filename cannot be loaded, tally will leave the
+ statistic object empty and set up a DataLoader object so that
+ the loop can be run. After the last iteration of the loop, the
+ completed statistic will be moved to the cpu device and also
+ saved in the cache file.
+
+ If the cached statistic can be loaded from the given file, tally
+ will not set up the data loader and instead will return a fully
+ loaded statistic object (on the cpu device) and an empty list as
+ the loader.
+
+ The `with cache_load_enabled(False):` context manager can
+ be used to disable loading from the cache.
+
+ If needed, a DataLoader will be created to wrap the dataset:
+
+ Keyword arguments of tally are passed to the DataLoader,
+ so batch_size, num_workers, pin_memory, etc. can be specified.
+
+ Subsampling is supported via sample_size= and random_sample=:
+
+ If sample_size=N is specified, rather than loading the whole
+ dataset, only the first N items are sampled. If additionally
+ random_sample=S is specified, the pseudorandom seed S will be
+ used to select a fixed psedorandom sample of size N to sample.
+ """
+ assert isinstance(stat, Stat)
+ args = {}
+ for k in ["sample_size"]:
+ if k in kwargs:
+ args[k] = kwargs[k]
+ cached_state = load_cached_state(cache, args, quiet=quiet)
+ if cached_state is not None:
+ stat.load_state_dict(cached_state)
+
+ def empty_loader():
+ return
+ yield
+
+ return empty_loader()
+ loader = make_loader(dataset, **kwargs)
+
+ def wrapped_loader():
+ yield from loader
+ stat.to_(device="cpu")
+ if cache is not None:
+ save_cached_state(cache, stat, args)
+
+ return wrapped_loader()
+
+
+class cache_load_enabled:
+ """
+ When used as a context manager, cache_load_enabled(False) will prevent
+ tally from loading cached statsitics, forcing them to be recomputed.
+ """
+
+ def __init__(self, enabled=True):
+ self.prev = False
+ self.enabled = enabled
+
+ def __enter__(self):
+ global global_load_cache_enabled
+ self.prev = global_load_cache_enabled
+ global_load_cache_enabled = self.enabled
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ global global_load_cache_enabled
+ global_load_cache_enabled = self.prev
+
+
+class Stat:
+ """
+ Abstract base class for a running pytorch statistic.
+ """
+
+ def __init__(self, state):
+ """
+ By convention, all Stat subclasses can be initialized by passing
+ state=; and then they will initialize by calling load_state_dict.
+ """
+ self.load_state_dict(resolve_state_dict(state))
+
+ def add(self, x, *args, **kwargs):
+ """
+ Observes a batch of samples to be incorporated into the statistic.
+ Dimension 0 should be the batch dimension, and dimension 1 should
+ be the feature dimension of the pytorch tensor x.
+ """
+ pass
+
+ def load_state_dict(self, d):
+ """
+ Loads this Stat from a dictionary of numpy arrays as saved
+ by state_dict.
+ """
+ pass
+
+ def state_dict(self):
+ """
+ Saves this Stat as a dictionary of numpy arrays that can be
+ stored in an npz or reloaded later using load_state_dict.
+ """
+ return {}
+
+ def save(self, filename):
+ """
+ Saves this stat as an npz file containing the state_dict.
+ """
+ save_cached_state(filename, self, {})
+
+ def load(self, filename):
+ """
+ Loads this stat from an npz file containing a saved state_dict.
+ """
+ self.load_state_dict(load_cached_state(filename, {}, quiet=True, throw=True))
+
+ def to_(self, device):
+ """
+ Moves this Stat to the given device.
+ """
+ pass
+
+ def cpu_(self):
+ """
+ Moves this Stat to the cpu device.
+ """
+ self.to_("cpu")
+
+ def cuda_(self):
+ """
+ Moves this Stat to the default cuda device.
+ """
+ self.to_("cuda")
+
+ def _normalize_add_shape(self, x, attr="data_shape"):
+ """
+ Flattens input data to 2d.
+ """
+ if not torch.is_tensor(x):
+ x = torch.tensor(x)
+ if len(x.shape) < 1:
+ x = x.view(-1)
+ data_shape = getattr(self, attr, None)
+ if data_shape is None:
+ data_shape = x.shape[1:]
+ setattr(self, attr, data_shape)
+ else:
+ assert x.shape[1:] == data_shape
+ return x.view(x.shape[0], int(numpy.prod(data_shape)))
+
+ def _restore_result_shape(self, x, attr="data_shape"):
+ """
+ Restores output data to input data shape.
+ """
+ data_shape = getattr(self, attr, None)
+ if data_shape is None:
+ return x
+ return x.view(data_shape * len(x.shape))
+
+
+class Mean(Stat):
+ """
+ Running mean.
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self.batchcount = 0
+ self._mean = None
+ self.data_shape = None
+
+ def add(self, a):
+ a = self._normalize_add_shape(a)
+ if len(a) == 0:
+ return
+ batch_count = a.shape[0]
+ batch_mean = a.sum(0) / batch_count
+ self.batchcount += 1
+ # Initial batch.
+ if self._mean is None:
+ self.count = batch_count
+ self._mean = batch_mean
+ return
+ # Update a batch using Chan-style update for numerical stability.
+ self.count += batch_count
+ new_frac = float(batch_count) / self.count
+ # Update the mean according to the batch deviation from the old mean.
+ delta = batch_mean.sub_(self._mean).mul_(new_frac)
+ self._mean.add_(delta)
+
+ def size(self):
+ return self.count
+
+ def mean(self):
+ return self._restore_result_shape(self._mean)
+
+ def to_(self, device):
+ if self._mean is not None:
+ self._mean = self._mean.to(device)
+
+ def load_state_dict(self, state):
+ self.count = state["count"]
+ self.batchcount = state["batchcount"]
+ self._mean = torch.from_numpy(state["mean"])
+ self.data_shape = (
+ None if state["data_shape"] is None else tuple(state["data_shape"])
+ )
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ data_shape=self.data_shape and tuple(self.data_shape),
+ batchcount=self.batchcount,
+ mean=self._mean.cpu().numpy(),
+ )
+
+
+class Variance(Stat):
+ """
+ Running computation of mean and variance. Use this when you just need
+ basic stats without covariance.
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self.batchcount = 0
+ self._mean = None
+ self.v_cmom2 = None
+ self.data_shape = None
+
+ def add(self, a):
+ a = self._normalize_add_shape(a)
+ if len(a) == 0:
+ return
+ batch_count = a.shape[0]
+ batch_mean = a.sum(0) / batch_count
+ centered = a - batch_mean
+ self.batchcount += 1
+ # Initial batch.
+ if self._mean is None:
+ self.count = batch_count
+ self._mean = batch_mean
+ self.v_cmom2 = centered.pow(2).sum(0)
+ return
+ # Update a batch using Chan-style update for numerical stability.
+ oldcount = self.count
+ self.count += batch_count
+ new_frac = float(batch_count) / self.count
+ # Update the mean according to the batch deviation from the old mean.
+ delta = batch_mean.sub_(self._mean).mul_(new_frac)
+ self._mean.add_(delta)
+ # Update the variance using the batch deviation
+ self.v_cmom2.add_(centered.pow(2).sum(0))
+ self.v_cmom2.add_(delta.pow_(2).mul_(new_frac * oldcount))
+
+ def size(self):
+ return self.count
+
+ def mean(self):
+ return self._restore_result_shape(self._mean)
+
+ def variance(self, unbiased=True):
+ return self._restore_result_shape(
+ self.v_cmom2 / (self.count - (1 if unbiased else 0))
+ )
+
+ def stdev(self, unbiased=True):
+ return self.variance(unbiased=unbiased).sqrt()
+
+ def to_(self, device):
+ if self._mean is not None:
+ self._mean = self._mean.to(device)
+ if self.v_cmom2 is not None:
+ self.v_cmom2 = self.v_cmom2.to(device)
+
+ def load_state_dict(self, state):
+ self.count = state["count"]
+ self.batchcount = state["batchcount"]
+ self._mean = torch.from_numpy(state["mean"])
+ self.v_cmom2 = torch.from_numpy(state["cmom2"])
+ self.data_shape = (
+ None if state["data_shape"] is None else tuple(state["data_shape"])
+ )
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ data_shape=self.data_shape and tuple(self.data_shape),
+ batchcount=self.batchcount,
+ mean=self._mean.cpu().numpy(),
+ cmom2=self.v_cmom2.cpu().numpy(),
+ )
+
+
+class Covariance(Stat):
+ """
+ Running computation. Use this when the entire covariance matrix is needed,
+ and when the whole covariance matrix fits in the GPU.
+
+ Chan-style numerically stable update of mean and full covariance matrix.
+ Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self._mean = None
+ self.cmom2 = None
+ self.data_shape = None
+
+ def add(self, a):
+ a = self._normalize_add_shape(a)
+ if len(a) == 0:
+ return
+ batch_count = a.shape[0]
+ # Initial batch.
+ if self._mean is None:
+ self.count = batch_count
+ self._mean = a.sum(0) / batch_count
+ centered = a - self._mean
+ self.cmom2 = centered.t().mm(centered)
+ return
+ # Update a batch using Chan-style update for numerical stability.
+ self.count += batch_count
+ # Update the mean according to the batch deviation from the old mean.
+ delta = a - self._mean
+ self._mean.add_(delta.sum(0) / self.count)
+ delta2 = a - self._mean
+ # Update the variance using the batch deviation
+ self.cmom2.addmm_(mat1=delta.t(), mat2=delta2)
+
+ def to_(self, device):
+ if self._mean is not None:
+ self._mean = self._mean.to(device)
+ if self.cmom2 is not None:
+ self.cmom2 = self.cmom2.to(device)
+
+ def mean(self):
+ return self._restore_result_shape(self._mean)
+
+ def covariance(self, unbiased=True):
+ return self._restore_result_shape(
+ self.cmom2 / (self.count - (1 if unbiased else 0))
+ )
+
+ def correlation(self, unbiased=True):
+ cov = self.cmom2 / (self.count - (1 if unbiased else 0))
+ rstdev = cov.diag().sqrt().reciprocal()
+ return self._restore_result_shape(rstdev[:, None] * cov * rstdev[None, :])
+
+ def variance(self, unbiased=True):
+ return self._restore_result_shape(
+ self.cmom2.diag() / (self.count - (1 if unbiased else 0))
+ )
+
+ def stdev(self, unbiased=True):
+ return self.variance(unbiased=unbiased).sqrt()
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ data_shape=self.data_shape and tuple(self.data_shape),
+ mean=self._mean.cpu().numpy(),
+ cmom2=self.cmom2.cpu().numpy(),
+ )
+
+ def load_state_dict(self, state):
+ self.count = state["count"]
+ self._mean = torch.from_numpy(state["mean"])
+ self.cmom2 = torch.from_numpy(state["cmom2"])
+ self.data_shape = (
+ None if state["data_shape"] is None else tuple(state["data_shape"])
+ )
+
+
+class SecondMoment(Stat):
+ """
+ Running computation. Use this when the entire non-centered 2nd-moment
+ 'covariance-like' matrix is needed, and when the whole matrix fits
+ in the GPU.
+ """
+
+ def __init__(self, split_batch=True, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self.mom2 = None
+ self.split_batch = split_batch
+
+ def add(self, a):
+ a = self._normalize_add_shape(a)
+ if len(a) == 0:
+ return
+ # Initial batch reveals the shape of the data.
+ if self.count == 0:
+ self.mom2 = a.new(a.shape[1], a.shape[1]).zero_()
+ batch_count = a.shape[0]
+ # Update the covariance using the batch deviation
+ self.count += batch_count
+ self.mom2 += a.t().mm(a)
+
+ def to_(self, device):
+ if self.mom2 is not None:
+ self.mom2 = self.mom2.to(device)
+
+ def moment(self):
+ return self.mom2 / self.count
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ mom2=self.mom2.cpu().numpy(),
+ )
+
+ def load_state_dict(self, state):
+ self.count = int(state["count"])
+ self.mom2 = torch.from_numpy(state["mom2"])
+
+
+class Bincount(Stat):
+ """
+ Running bincount. The counted array should be an integer type with
+ non-negative integers.
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self._bincount = None
+
+ def add(self, a, size=None):
+ a = a.view(-1)
+ bincount = a.bincount()
+ if self._bincount is None:
+ self._bincount = bincount
+ elif len(self._bincount) < len(bincount):
+ bincount[: len(self._bincount)] += self._bincount
+ self._bincount = bincount
+ else:
+ self._bincount[: len(bincount)] += bincount
+ if size is None:
+ self.count += len(a)
+ else:
+ self.count += size
+
+ def to_(self, device):
+ self._bincount = self._bincount.to(device)
+
+ def size(self):
+ return self.count
+
+ def bincount(self):
+ return self._bincount
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ bincount=self._bincount.cpu().numpy(),
+ )
+
+ def load_state_dict(self, dic):
+ self.count = int(dic["count"])
+ self._bincount = torch.from_numpy(dic["bincount"])
+
+
+class CrossCovariance(Stat):
+ """
+ Covariance. Use this when an off-diagonal block of the covariance
+ matrix is needed (e.g., when the whole covariance matrix does
+ not fit in the GPU, this could use a quarter of the memory).
+
+ Chan-style numerically stable update of mean and full covariance matrix.
+ Chan, Golub. LeVeque. 1983. http://www.jstor.org/stable/2683386
+ """
+
+ def __init__(self, split_batch=True, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self._mean = None
+ self.cmom2 = None
+ self.v_cmom2 = None
+ self.split_batch = split_batch
+
+ def add(self, a, b):
+ if len(a.shape) == 1:
+ a = a[None, :]
+ b = b[None, :]
+ assert a.shape[0] == b.shape[0]
+ if len(a.shape) > 2:
+ a, b = [
+ d.view(d.shape[0], d.shape[1], -1)
+ .permute(0, 2, 1)
+ .reshape(-1, d.shape[1])
+ for d in [a, b]
+ ]
+ batch_count = a.shape[0]
+ # Initial batch.
+ if self._mean is None:
+ self.count = batch_count
+ self._mean = [d.sum(0) / batch_count for d in [a, b]]
+ centered = [d - bm for d, bm in zip([a, b], self._mean)]
+ self.v_cmom2 = [c.pow(2).sum(0) for c in centered]
+ self.cmom2 = centered[0].t().mm(centered[1])
+ return
+ # Update a batch using Chan-style update for numerical stability.
+ self.count += batch_count
+ # Update the mean according to the batch deviation from the old mean.
+ delta = [(d - bm) for d, bm in zip([a, b], self._mean)]
+ for m, d in zip(self._mean, delta):
+ m.add_(d.sum(0) / self.count)
+ delta2 = [(d - bm) for d, bm in zip([a, b], self._mean)]
+ # Update the cross-covariance using the batch deviation
+ self.cmom2.addmm_(mat1=delta[0].t(), mat2=delta2[1])
+ # Update the variance using the batch deviation
+ for vc2, d, d2 in zip(self.v_cmom2, delta, delta2):
+ vc2.add_((d * d2).sum(0))
+
+ def mean(self):
+ return self._mean
+
+ def variance(self, unbiased=True):
+ return [vc2 / (self.count - (1 if unbiased else 0)) for vc2 in self.v_cmom2]
+
+ def stdev(self, unbiased=True):
+ return [v.sqrt() for v in self.variance(unbiased=unbiased)]
+
+ def covariance(self, unbiased=True):
+ return self.cmom2 / (self.count - (1 if unbiased else 0))
+
+ def correlation(self):
+ covariance = self.covariance(unbiased=False)
+ rstdev = [s.reciprocal() for s in self.stdev(unbiased=False)]
+ cor = rstdev[0][:, None] * covariance * rstdev[1][None, :]
+ # Remove NaNs
+ cor[torch.isnan(cor)] = 0
+ return cor
+
+ def to_(self, device):
+ self._mean = [m.to(device) for m in self._mean]
+ self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2]
+ self.cmom2 = self.cmom2.to(device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ mean_a=self._mean[0].cpu().numpy(),
+ mean_b=self._mean[1].cpu().numpy(),
+ cmom2_a=self.v_cmom2[0].cpu().numpy(),
+ cmom2_b=self.v_cmom2[1].cpu().numpy(),
+ cmom2=self.cmom2.cpu().numpy(),
+ )
+
+ def load_state_dict(self, state):
+ self.count = int(state["count"])
+ self._mean = [torch.from_numpy(state[f"mean_{k}"]) for k in "ab"]
+ self.v_cmom2 = [torch.from_numpy(state[f"cmom2_{k}"]) for k in "ab"]
+ self.cmom2 = torch.from_numpy(state["cmom2"])
+
+
+def _float_from_bool(a):
+ """
+ Since pytorch only supports matrix multiplication on float,
+ IoU computations are done using floating point types.
+
+ This function binarizes the input (positive to True and
+ nonpositive to False), and converts from bool to float.
+ If the data is already a floating-point type, it leaves
+ it keeps the same type; otherwise it uses float.
+ """
+ if a.dtype == torch.bool:
+ return a.float()
+ if a.dtype.is_floating_point:
+ return a.sign().clamp_(0)
+ return (a > 0).float()
+
+
+class IoU(Stat):
+ """
+ Running computation of intersections and unions of all features.
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self._intersection = None
+
+ def add(self, a):
+ assert len(a.shape) == 2
+ a = _float_from_bool(a)
+ if self._intersection is None:
+ self._intersection = torch.mm(a.t(), a)
+ else:
+ self._intersection.addmm_(a.t(), a)
+ self.count += len(a)
+
+ def size(self):
+ return self.count
+
+ def intersection(self):
+ return self._intersection
+
+ def union(self):
+ total = self._intersection.diagonal(0)
+ return total[:, None] + total[None, :] - self._intersection
+
+ def iou(self):
+ return self.intersection() / (self.union() + 1e-20)
+
+ def to_(self, _device):
+ self._intersection = self._intersection.to(_device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ intersection=self._intersection.cpu().numpy(),
+ )
+
+ def load_state_dict(self, state):
+ self.count = int(state["count"])
+ self._intersection = torch.tensor(state["intersection"])
+
+
+class CrossIoU(Stat):
+ """
+ Running computation of intersections and unions of two binary vectors.
+ """
+
+ def __init__(self, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.count = 0
+ self._intersection = None
+ self.total_a = None
+ self.total_b = None
+
+ def add(self, a, b):
+ assert len(a.shape) == 2 and len(b.shape) == 2
+ assert len(a) == len(b), f"{len(a)} vs {len(b)}"
+ a = _float_from_bool(a) # CUDA only supports mm on float...
+ b = _float_from_bool(b) # otherwise we would use integers.
+ intersection = torch.mm(a.t(), b)
+ asum = a.sum(0)
+ bsum = b.sum(0)
+ if self._intersection is None:
+ self._intersection = intersection
+ self.total_a = asum
+ self.total_b = bsum
+ else:
+ self._intersection += intersection
+ self.total_a += asum
+ self.total_b += bsum
+ self.count += len(a)
+
+ def size(self):
+ return self.count
+
+ def intersection(self):
+ return self._intersection
+
+ def union(self):
+ return self.total_a[:, None] + self.total_b[None, :] - self._intersection
+
+ def iou(self):
+ return self.intersection() / (self.union() + 1e-20)
+
+ def to_(self, _device):
+ self.total_a = self.total_a.to(_device)
+ self.total_b = self.total_b.to(_device)
+ self._intersection = self._intersection.to(_device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ count=self.count,
+ total_a=self.total_a.cpu().numpy(),
+ total_b=self.total_b.cpu().numpy(),
+ intersection=self._intersection.cpu().numpy(),
+ )
+
+ def load_state_dict(self, state):
+ self.count = int(state["count"])
+ self.total_a = torch.tensor(state["total_a"])
+ self.total_b = torch.tensor(state["total_b"])
+ self._intersection = torch.tensor(state["intersection"])
+
+
+class Quantile(Stat):
+ """
+ Streaming randomized quantile computation for torch.
+
+ Add any amount of data repeatedly via add(data). At any time,
+ quantile estimates be read out using quantile(q).
+
+ Implemented as a sorted sample that retains at least r samples
+ (by default r = 3072); the number of retained samples will grow to
+ a finite ceiling as the data is accumulated. Accuracy scales according
+ to r: the default is to set resolution to be accurate to better than about
+ 0.1%, while limiting storage to about 50,000 samples.
+
+ Good for computing quantiles of huge data without using much memory.
+ Works well on arbitrary data with probability near 1.
+
+ Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty
+ from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf
+ """
+
+ def __init__(self, r=3 * 1024, buffersize=None, seed=None, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.depth = None
+ self.dtype = None
+ self.device = None
+ resolution = r * 2 # sample array is at least half full before discard
+ self.resolution = resolution
+ # Default buffersize: 128 samples (and smaller than resolution).
+ if buffersize is None:
+ buffersize = min(128, (resolution + 7) // 8)
+ self.buffersize = buffersize
+ self.samplerate = 1.0
+ self.data = None
+ self.firstfree = [0]
+ self.randbits = torch.ByteTensor(resolution)
+ self.currentbit = len(self.randbits) - 1
+ self.extremes = None
+ self.count = 0
+ self.batchcount = 0
+
+ def size(self):
+ return self.count
+
+ def _lazy_init(self, incoming):
+ self.depth = incoming.shape[1]
+ self.dtype = incoming.dtype
+ self.device = incoming.device
+ self.data = [
+ torch.zeros(
+ self.depth, self.resolution, dtype=self.dtype, device=self.device
+ )
+ ]
+ self.extremes = torch.zeros(self.depth, 2, dtype=self.dtype, device=self.device)
+ self.extremes[:, 0] = float("inf")
+ self.extremes[:, -1] = -float("inf")
+
+ def to_(self, device):
+ """Switches internal storage to specified device."""
+ if device != self.device:
+ old_data = self.data
+ old_extremes = self.extremes
+ self.data = [d.to(device) for d in self.data]
+ self.extremes = self.extremes.to(device)
+ self.device = self.extremes.device
+ del old_data
+ del old_extremes
+
+ def add(self, incoming):
+ if self.depth is None:
+ self._lazy_init(incoming)
+ assert len(incoming.shape) == 2
+ assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth)
+ self.count += incoming.shape[0]
+ self.batchcount += 1
+ # Convert to a flat torch array.
+ if self.samplerate >= 1.0:
+ self._add_every(incoming)
+ return
+ # If we are sampling, then subsample a large chunk at a time.
+ self._scan_extremes(incoming)
+ chunksize = int(math.ceil(self.buffersize / self.samplerate))
+ for index in range(0, len(incoming), chunksize):
+ batch = incoming[index : index + chunksize]
+ sample = sample_portion(batch, self.samplerate)
+ if len(sample):
+ self._add_every(sample)
+
+ def _add_every(self, incoming):
+ supplied = len(incoming)
+ index = 0
+ while index < supplied:
+ ff = self.firstfree[0]
+ available = self.data[0].shape[1] - ff
+ if available == 0:
+ if not self._shift():
+ # If we shifted by subsampling, then subsample.
+ incoming = incoming[index:]
+ if self.samplerate >= 0.5:
+ # First time sampling - the data source is very large.
+ self._scan_extremes(incoming)
+ incoming = sample_portion(incoming, self.samplerate)
+ index = 0
+ supplied = len(incoming)
+ ff = self.firstfree[0]
+ available = self.data[0].shape[1] - ff
+ copycount = min(available, supplied - index)
+ self.data[0][:, ff : ff + copycount] = torch.t(
+ incoming[index : index + copycount, :]
+ )
+ self.firstfree[0] += copycount
+ index += copycount
+
+ def _shift(self):
+ index = 0
+ # If remaining space at the current layer is less than half prev
+ # buffer size (rounding up), then we need to shift it up to ensure
+ # enough space for future shifting.
+ while self.data[index].shape[1] - self.firstfree[index] < (
+ -(-self.data[index - 1].shape[1] // 2) if index else 1
+ ):
+ if index + 1 >= len(self.data):
+ return self._expand()
+ data = self.data[index][:, 0 : self.firstfree[index]]
+ data = data.sort()[0]
+ if index == 0 and self.samplerate >= 1.0:
+ self._update_extremes(data[:, 0], data[:, -1])
+ offset = self._randbit()
+ position = self.firstfree[index + 1]
+ subset = data[:, offset::2]
+ self.data[index + 1][:, position : position + subset.shape[1]] = subset
+ self.firstfree[index] = 0
+ self.firstfree[index + 1] += subset.shape[1]
+ index += 1
+ return True
+
+ def _scan_extremes(self, incoming):
+ # When sampling, we need to scan every item still to get extremes
+ self._update_extremes(
+ torch.min(incoming, dim=0)[0], torch.max(incoming, dim=0)[0]
+ )
+
+ def _update_extremes(self, minr, maxr):
+ self.extremes[:, 0] = torch.min(
+ torch.stack([self.extremes[:, 0], minr]), dim=0
+ )[0]
+ self.extremes[:, -1] = torch.max(
+ torch.stack([self.extremes[:, -1], maxr]), dim=0
+ )[0]
+
+ def _randbit(self):
+ self.currentbit += 1
+ if self.currentbit >= len(self.randbits):
+ self.randbits.random_(to=2)
+ self.currentbit = 0
+ return self.randbits[self.currentbit]
+
+ def state_dict(self):
+ state = dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ resolution=self.resolution,
+ depth=self.depth,
+ buffersize=self.buffersize,
+ samplerate=self.samplerate,
+ sizes=numpy.array([d.shape[1] for d in self.data]),
+ extremes=self.extremes.cpu().detach().numpy(),
+ size=self.count,
+ batchcount=self.batchcount,
+ )
+ for i, (d, f) in enumerate(zip(self.data, self.firstfree)):
+ state[f"data.{i}"] = d.cpu().detach().numpy()[:, :f].T
+ return state
+
+ def load_state_dict(self, state):
+ self.resolution = int(state["resolution"])
+ self.randbits = torch.ByteTensor(self.resolution)
+ self.currentbit = len(self.randbits) - 1
+ self.depth = int(state["depth"])
+ self.buffersize = int(state["buffersize"])
+ self.samplerate = float(state["samplerate"])
+ firstfree = []
+ buffers = []
+ for i, s in enumerate(state["sizes"]):
+ d = state[f"data.{i}"]
+ firstfree.append(d.shape[0])
+ buf = numpy.zeros((d.shape[1], s), dtype=d.dtype)
+ buf[:, : d.shape[0]] = d.T
+ buffers.append(torch.from_numpy(buf))
+ self.firstfree = firstfree
+ self.data = buffers
+ self.extremes = torch.from_numpy((state["extremes"]))
+ self.count = int(state["size"])
+ self.batchcount = int(state.get("batchcount", 0))
+ self.dtype = self.extremes.dtype
+ self.device = self.extremes.device
+
+ def min(self):
+ return self.minmax()[0]
+
+ def max(self):
+ return self.minmax()[-1]
+
+ def minmax(self):
+ if self.firstfree[0]:
+ self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
+ return self.extremes.clone()
+
+ def median(self):
+ return self.quantiles(0.5)
+
+ def mean(self):
+ return self.integrate(lambda x: x) / self.count
+
+ def variance(self, unbiased=True):
+ mean = self.mean()[:, None]
+ return self.integrate(lambda x: (x - mean).pow(2)) / (
+ self.count - (1 if unbiased else 0)
+ )
+
+ def stdev(self, unbiased=True):
+ return self.variance(unbiased=unbiased).sqrt()
+
+ def _expand(self):
+ cap = self._next_capacity()
+ if cap > 0:
+ # First, make a new layer of the proper capacity.
+ self.data.insert(
+ 0, torch.zeros(self.depth, cap, dtype=self.dtype, device=self.device)
+ )
+ self.firstfree.insert(0, 0)
+ else:
+ # Unless we're so big we are just subsampling.
+ assert self.firstfree[0] == 0
+ self.samplerate *= 0.5
+ for index in range(1, len(self.data)):
+ # Scan for existing data that needs to be moved down a level.
+ amount = self.firstfree[index]
+ if amount == 0:
+ continue
+ position = self.firstfree[index - 1]
+ # Move data down if it would leave enough empty space there
+ # This is the key invariant: enough empty space to fit half
+ # of the previous level's buffer size (rounding up)
+ if self.data[index - 1].shape[1] - (amount + position) >= (
+ -(-self.data[index - 2].shape[1] // 2) if (index - 1) else 1
+ ):
+ self.data[index - 1][:, position : position + amount] = self.data[
+ index
+ ][:, :amount]
+ self.firstfree[index - 1] += amount
+ self.firstfree[index] = 0
+ else:
+ # Scrunch the data if it would not.
+ data = self.data[index][:, :amount]
+ data = data.sort()[0]
+ if index == 1:
+ self._update_extremes(data[:, 0], data[:, -1])
+ offset = self._randbit()
+ scrunched = data[:, offset::2]
+ self.data[index][:, : scrunched.shape[1]] = scrunched
+ self.firstfree[index] = scrunched.shape[1]
+ return cap > 0
+
+ def _next_capacity(self):
+ cap = int(math.ceil(self.resolution * (0.67 ** len(self.data))))
+ if cap < 2:
+ return 0
+ # Round up to the nearest multiple of 8 for better GPU alignment.
+ cap = -8 * (-cap // 8)
+ return max(self.buffersize, cap)
+
+ def _weighted_summary(self, sort=True):
+ if self.firstfree[0]:
+ self._scan_extremes(self.data[0][:, : self.firstfree[0]].t())
+ size = sum(self.firstfree)
+ weights = torch.FloatTensor(size) # Floating point
+ summary = torch.zeros(self.depth, size, dtype=self.dtype, device=self.device)
+ index = 0
+ for level, ff in enumerate(self.firstfree):
+ if ff == 0:
+ continue
+ summary[:, index : index + ff] = self.data[level][:, :ff]
+ weights[index : index + ff] = 2.0**level
+ index += ff
+ assert index == summary.shape[1]
+ if sort:
+ summary, order = torch.sort(summary, dim=-1)
+ weights = weights[order.view(-1).cpu()].view(order.shape)
+ summary = torch.cat(
+ [self.extremes[:, :1], summary, self.extremes[:, 1:]], dim=-1
+ )
+ weights = torch.cat(
+ [
+ torch.zeros(weights.shape[0], 1),
+ weights,
+ torch.zeros(weights.shape[0], 1),
+ ],
+ dim=-1,
+ )
+ return (summary, weights)
+
+ def quantiles(self, quantiles):
+ if not hasattr(quantiles, "cpu"):
+ quantiles = torch.tensor(quantiles)
+ qshape = quantiles.shape
+ if self.count == 0:
+ return torch.full((self.depth,) + qshape, torch.nan)
+ summary, weights = self._weighted_summary()
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
+ result = torch.zeros(
+ self.depth, quantiles.numel(), dtype=self.dtype, device=self.device
+ )
+ # numpy is needed for interpolation
+ nq = quantiles.view(-1).cpu().detach().numpy()
+ ncw = cumweights.cpu().detach().numpy()
+ nsm = summary.cpu().detach().numpy()
+ for d in range(self.depth):
+ result[d] = torch.tensor(
+ numpy.interp(nq, ncw[d], nsm[d]), dtype=self.dtype, device=self.device
+ )
+ return result.view((self.depth,) + qshape)
+
+ def integrate(self, fun):
+ result = []
+ for level, ff in enumerate(self.firstfree):
+ if ff == 0:
+ continue
+ result.append(
+ torch.sum(fun(self.data[level][:, :ff]) * (2.0**level), dim=-1)
+ )
+ if len(result) == 0:
+ return None
+ return torch.stack(result).sum(dim=0) / self.samplerate
+
+ def readout(self, count=1001):
+ return self.quantiles(torch.linspace(0.0, 1.0, count))
+
+ def normalize(self, data):
+ """
+ Given input data as taken from the training distirbution,
+ normalizes every channel to reflect quantile values,
+ uniformly distributed, within [0, 1].
+ """
+ assert self.count > 0
+ assert data.shape[0] == self.depth
+ summary, weights = self._weighted_summary()
+ cumweights = torch.cumsum(weights, dim=-1) - weights / 2
+ cumweights /= torch.sum(weights, dim=-1, keepdim=True)
+ result = torch.zeros_like(data).float()
+ # numpy is needed for interpolation
+ ndata = data.cpu().numpy().reshape((data.shape[0], -1))
+ ncw = cumweights.cpu().numpy()
+ nsm = summary.cpu().numpy()
+ for d in range(self.depth):
+ normed = torch.tensor(
+ numpy.interp(ndata[d], nsm[d], ncw[d]),
+ dtype=torch.float,
+ device=data.device,
+ ).clamp_(0.0, 1.0)
+ if len(data.shape) > 1:
+ normed = normed.view(*(data.shape[1:]))
+ result[d] = normed
+ return result
+
+
+def sample_portion(vec, p=0.5):
+ """
+ Subsamples a fraction (given by p) of the given batch. Used by
+ Quantile when the data gets very very large.
+ """
+ bits = torch.bernoulli(
+ torch.zeros(vec.shape[0], dtype=torch.uint8, device=vec.device), p
+ )
+ return vec[bits]
+
+
+class TopK:
+ """
+ A class to keep a running tally of the the top k values (and indexes)
+ of any number of torch feature components. Will work on the GPU if
+ the data is on the GPU. Tracks largest by default, but tracks smallest
+ if largest=False is passed.
+
+ This version flattens all arrays to avoid crashes.
+ """
+
+ def __init__(self, k=100, largest=True, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self.k = k
+ self.count = 0
+ # This version flattens all data internally to 2-d tensors,
+ # to avoid crashes with the current pytorch topk implementation.
+ # The data is puffed back out to arbitrary tensor shapes on ouput.
+ self.data_shape = None
+ self.top_data = None
+ self.top_index = None
+ self.next = 0
+ self.linear_index = 0
+ self.perm = None
+ self.largest = largest
+
+ def add(self, data, index=None):
+ """
+ Adds a batch of data to be considered for the running top k.
+ The zeroth dimension enumerates the observations. All other
+ dimensions enumerate different features.
+ """
+ if self.top_data is None:
+ # Allocation: allocate a buffer of size 5*k, at least 10, for each.
+ self.data_shape = data.shape[1:]
+ feature_size = int(numpy.prod(self.data_shape))
+ self.top_data = torch.zeros(
+ feature_size, max(10, self.k * 5), out=data.new()
+ )
+ self.top_index = self.top_data.clone().long()
+ self.linear_index = (
+ 0
+ if len(data.shape) == 1
+ else torch.arange(feature_size, out=self.top_index.new()).mul_(
+ self.top_data.shape[-1]
+ )[:, None]
+ )
+ size = data.shape[0]
+ sk = min(size, self.k)
+ if self.top_data.shape[-1] < self.next + sk:
+ # Compression: if full, keep topk only.
+ self.top_data[:, : self.k], self.top_index[:, : self.k] = self.topk(
+ sorted=False, flat=True
+ )
+ self.next = self.k
+ # Pick: copy the top sk of the next batch into the buffer.
+ # Currently strided topk is slow. So we clone after transpose.
+ # TODO: remove the clone() if it becomes faster.
+ cdata = data.reshape(size, numpy.prod(data.shape[1:])).t().clone()
+ td, ti = cdata.topk(sk, sorted=False, largest=self.largest)
+ self.top_data[:, self.next : self.next + sk] = td
+ if index is not None:
+ ti = index[ti]
+ else:
+ ti = ti + self.count
+ self.top_index[:, self.next : self.next + sk] = ti
+ self.next += sk
+ self.count += size
+
+ def size(self):
+ return self.count
+
+ def topk(self, sorted=True, flat=False):
+ """
+ Returns top k data items and indexes in each dimension,
+ with channels in the first dimension and k in the last dimension.
+ """
+ k = min(self.k, self.next)
+ # bti are top indexes relative to buffer array.
+ td, bti = self.top_data[:, : self.next].topk(
+ k, sorted=sorted, largest=self.largest
+ )
+ # we want to report top indexes globally, which is ti.
+ ti = self.top_index.view(-1)[(bti + self.linear_index).view(-1)].view(
+ *bti.shape
+ )
+ if flat:
+ return td, ti
+ else:
+ return (
+ td.view(*(self.data_shape + (-1,))),
+ ti.view(*(self.data_shape + (-1,))),
+ )
+
+ def to_(self, device):
+ if self.top_data is not None:
+ self.top_data = self.top_data.to(device)
+ if self.top_index is not None:
+ self.top_index = self.top_index.to(device)
+ if isinstance(self.linear_index, torch.Tensor):
+ self.linear_index = self.linear_index.to(device)
+
+ def state_dict(self):
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ k=self.k,
+ count=self.count,
+ largest=self.largest,
+ data_shape=self.data_shape and tuple(self.data_shape),
+ top_data=self.top_data.cpu().detach().numpy(),
+ top_index=self.top_index.cpu().detach().numpy(),
+ next=self.next,
+ linear_index=(
+ self.linear_index.cpu().numpy()
+ if isinstance(self.linear_index, torch.Tensor)
+ else self.linear_index
+ ),
+ perm=self.perm,
+ )
+
+ def load_state_dict(self, state):
+ self.k = int(state["k"])
+ self.count = int(state["count"])
+ self.largest = bool(state.get("largest", True))
+ self.data_shape = (
+ None if state["data_shape"] is None else tuple(state["data_shape"])
+ )
+ self.top_data = torch.from_numpy(state["top_data"])
+ self.top_index = torch.from_numpy(state["top_index"])
+ self.next = int(state["next"])
+ self.linear_index = (
+ torch.from_numpy(state["linear_index"])
+ if len(state["linear_index"].shape) > 0
+ else int(state["linear_index"])
+ )
+
+
+class History(Stat):
+ """
+ Accumulates the concatenation of all the added data.
+ """
+
+ def __init__(self, data=None, state=None):
+ if state is not None:
+ return super().__init__(state)
+ self._data = data
+ self._added = []
+
+ def _cat_added(self):
+ if len(self._added):
+ self._data = torch.cat(
+ ([self._data] if self._data is not None else []) + self._added
+ )
+ self._added = []
+
+ def add(self, d):
+ self._added.append(d)
+ if len(self._added) > 100:
+ self._cat_added()
+
+ def history(self):
+ self._cat_added()
+ return self._data
+
+ def load_state_dict(self, state):
+ data = state["data"]
+ self._data = None if data is None else torch.from_numpy(data)
+ self._added = []
+
+ def state_dict(self):
+ self._cat_added()
+ return dict(
+ constructor=self.__module__ + "." + self.__class__.__name__ + "()",
+ data=None if self._data is None else self._data.cpu().numpy(),
+ )
+
+ def to_(self, device):
+ """Switches internal storage to specified device."""
+ self._cat_added()
+ if self._data is not None:
+ self._data = self._data.to(device)
+
+
+class CombinedStat(Stat):
+ """
+ A Stat that bundles together multiple Stat objects.
+ Convenient for loading and saving a state_dict made up of a
+ hierarchy of stats, and for use with the tally() function.
+ Example:
+
+ cs = CombinedStat(m=Mean(), q=Quantile())
+ for [b] in tally(cs, MyDataSet(), cache=fn, batch_size=100):
+ cs.add(b)
+ print(cs.m.mean())
+ print(cs.q.median())
+ """
+
+ def __init__(self, state=None, **kwargs):
+ self._objs = kwargs
+ if state is not None:
+ return super().__init__(state)
+
+ def __getattr__(self, k):
+ if k in self._objs:
+ return self._objs[k]
+ raise AttributeError()
+
+ def add(self, d, *args, **kwargs):
+ for obj in self._objs.values():
+ obj.add(d, *args, **kwargs)
+
+ def load_state_dict(self, state):
+ for prefix, obj in self._objs.items():
+ obj.load_state_dict(pull_key_prefix(prefix, state))
+
+ def state_dict(self):
+ result = {}
+ for prefix, obj in self._objs.items():
+ result.update(push_key_prefix(prefix, obj.state_dict()))
+ return result
+
+ def to_(self, device):
+ """Switches internal storage to specified device."""
+ for v in self._objs.values():
+ v.to_(device)
+
+
+def push_key_prefix(prefix, d):
+ """
+ Returns a dict with the same values as d, but where each key
+ adds the prefix, followed by a dot.
+ """
+ return {prefix + "." + k: v for k, v in d.items()}
+
+
+def pull_key_prefix(prefix, d):
+ """
+ Returns a filtered dict of all the items of d that start with
+ the given key prefix, plus a dot, with that prefix removed.
+ """
+ pd = prefix + "."
+ lpd = len(pd)
+ return {k[lpd:]: v for k, v in d.items() if k.startswith(pd)}
+
+
+# We wish to be able to save None (null) values in numpy npz files,
+# yet do so without setting the unsecure 'allow_pickle' flag. To do
+# that, we will encode null as a special kind of IEEE 754 NaN value.
+# Inspired by https://github.com/zuiderkwast/nanbox/blob/master/nanbox.h
+# we follow the same Nanboxing scheme used in JavaScriptCore
+# (search for JSCJSValue.h#L435), which encodes null values in NaN
+# as the NaN value with hex pattern 0xfff8000000000002.
+
+null_numpy_value = numpy.array(
+ struct.unpack(">d", struct.pack(">Q", 0xFFF8000000000002))[0], dtype=numpy.float64
+)
+
+
+def is_null_numpy_value(v):
+ """
+ True if v is a 64-bit float numpy scalar NaN matching null_numpy_value.
+ """
+ return (
+ isinstance(v, numpy.ndarray)
+ and numpy.ndim(v) == 0
+ and v.dtype == numpy.float64
+ and numpy.isnan(v)
+ and 0xFFF8000000000002 == struct.unpack(">Q", struct.pack(">d", v))[0]
+ )
+
+
+def box_numpy_null(d):
+ """
+ Replaces None with null_numpy_value, leaving non-None values unchanged.
+ Recursively descends into a dictionary replacing None values.
+ """
+ try:
+ return {k: box_numpy_null(v) for k, v in d.items()}
+ except Exception:
+ return null_numpy_value if d is None else d
+
+
+def unbox_numpy_null(d):
+ """
+ Reverses box_numpy_null, replacing null_numpy_value with None.
+ Recursively descends into a dictionary replacing None values.
+ """
+ try:
+ return {k: unbox_numpy_null(v) for k, v in d.items()}
+ except Exception:
+ return None if is_null_numpy_value(d) else d
+
+
+def resolve_state_dict(s):
+ """
+ Resolves a state, which can be a filename or a dict-like object.
+ """
+ if isinstance(s, str):
+ return unbox_numpy_null(numpy.load(s))
+ return s
+
+
+global_load_cache_enabled = True
+
+
+def load_cached_state(cachefile, args, quiet=False, throw=False):
+ """
+ Resolves a state, which can be a filename or a dict-like object.
+ """
+ if not global_load_cache_enabled or cachefile is None:
+ return None
+ try:
+ if isinstance(cachefile, dict):
+ dat = cachefile
+ cachefile = "state" # for printed messages
+ else:
+ dat = unbox_numpy_null(numpy.load(cachefile))
+ for a, v in args.items():
+ if a not in dat or dat[a] != v:
+ if not quiet:
+ print("%s %s changed from %s to %s" % (cachefile, a, dat[a], v))
+ return None
+ except (FileNotFoundError, ValueError) as e:
+ if throw:
+ raise e
+ return None
+ else:
+ if not quiet:
+ print("Loading cached %s" % cachefile)
+ return dat
+
+
+def save_cached_state(cachefile, obj, args):
+ """
+ Saves the state_dict of the given object in a dict or npz file.
+ """
+ if cachefile is None:
+ return
+ dat = obj.state_dict()
+ for a, v in args.items():
+ if a in dat:
+ assert dat[a] == v
+ dat[a] = v
+ if isinstance(cachefile, dict):
+ cachefile.clear()
+ cachefile.update(dat)
+ else:
+ os.makedirs(os.path.dirname(cachefile), exist_ok=True)
+ numpy.savez(cachefile, **box_numpy_null(dat))
+
+
+class FixedSubsetSampler(Sampler):
+ """Represents a fixed sequence of data set indices.
+ Subsets can be created by specifying a subset of output indexes.
+ """
+
+ def __init__(self, samples):
+ self.samples = samples
+
+ def __iter__(self):
+ return iter(self.samples)
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, key):
+ return self.samples[key]
+
+ def subset(self, new_subset):
+ return FixedSubsetSampler(self.dereference(new_subset))
+
+ def dereference(self, indices):
+ """
+ Translate output sample indices (small numbers indexing the sample)
+ to input sample indices (larger number indexing the original full set)
+ """
+ return [self.samples[i] for i in indices]
+
+
+class FixedRandomSubsetSampler(FixedSubsetSampler):
+ """Samples a fixed number of samples from the dataset, deterministically.
+ Arguments:
+ data_source,
+ sample_size,
+ seed (optional)
+ """
+
+ def __init__(self, data_source, start=None, end=None, seed=1):
+ rng = random.Random(seed)
+ shuffled = list(range(len(data_source)))
+ rng.shuffle(shuffled)
+ self.data_source = data_source
+ super(FixedRandomSubsetSampler, self).__init__(shuffled[start:end])
+
+ def class_subset(self, class_filter):
+ """
+ Returns only the subset matching the given rule.
+ """
+ if isinstance(class_filter, int):
+
+ def rule(d):
+ return d[1] == class_filter
+
+ else:
+ rule = class_filter
+ return self.subset(
+ [i for i, j in enumerate(self.samples) if rule(self.data_source[j])]
+ )
+
+
+def make_loader(
+ dataset, sample_size=None, batch_size=1, sampler=None, random_sample=None, **kwargs
+):
+ """Utility for creating a dataloader on fixed sample subset."""
+ import typing
+
+ if isinstance(dataset, typing.Callable):
+ # To support deferred dataset loading, support passing a factory
+ # that creates the dataset when called.
+ dataset = dataset()
+ if isinstance(dataset, torch.Tensor):
+ # The dataset can be a simple tensor.
+ dataset = torch.utils.data.TensorDataset(dataset)
+ if sample_size is not None:
+ assert sampler is None, "sampler cannot be specified with sample_size"
+ if sample_size > len(dataset):
+ print(
+ "Warning: sample size %d > dataset size %d"
+ % (sample_size, len(dataset))
+ )
+ sample_size = len(dataset)
+ if random_sample is None:
+ sampler = FixedSubsetSampler(list(range(sample_size)))
+ else:
+ sampler = FixedRandomSubsetSampler(
+ dataset, seed=random_sample, end=sample_size
+ )
+ return torch.utils.data.DataLoader(
+ dataset, sampler=sampler, batch_size=batch_size, **kwargs
+ )
+
+
+# Unit Tests
+def _unit_test():
+ import warnings
+
+ warnings.filterwarnings("error")
+ import argparse
+ import random
+ import shutil
+ import tempfile
+ import time
+
+ parser = argparse.ArgumentParser(description="Test things out")
+ parser.add_argument("--mode", default="cpu", help="cpu or cuda")
+ parser.add_argument("--test_size", type=int, default=1000000)
+ args = parser.parse_args()
+ testdir = tempfile.mkdtemp()
+ batch_size = random.randint(500, 1500)
+
+ # Test NaNboxing.
+ assert numpy.isnan(null_numpy_value)
+ assert is_null_numpy_value(null_numpy_value)
+ assert not is_null_numpy_value(numpy.nan)
+
+ # Test Covariance
+ goal = torch.tensor(numpy.random.RandomState(1).standard_normal(10 * 10)).view(
+ 10, 10
+ )
+ data = (
+ torch.tensor(numpy.random.RandomState(2).standard_normal(args.test_size * 10))
+ .view(args.test_size, 10)
+ .mm(goal)
+ )
+ data += torch.randn(1, 10) * 999
+ dcov = data.t().cov()
+ dcorr = data.t().corrcoef()
+ rcov = Covariance()
+ rcov.add(data) # All one batch
+ assert (rcov.covariance() - dcov).abs().max() < 1e-16
+ cs = CombinedStat(cov=Covariance(), xcov=CrossCovariance())
+ ds = torch.utils.data.TensorDataset(data)
+ for [a] in tally(cs, ds, batch_size=9876):
+ cs.cov.add(a)
+ cs.xcov.add(a[:, :3], a[:, 3:])
+ assert (data.mean(0) - cs.cov.mean()).abs().max() < 1e-12
+ assert (dcov - cs.cov.covariance()).abs().max() < 2e-12
+ assert (dcov[:3, 3:] - cs.xcov.covariance()).abs().max() < 1e-12
+ assert (dcov.diagonal() - torch.cat(cs.xcov.variance())).abs().max() < 1e-12
+ assert (dcorr - cs.cov.correlation()).abs().max() < 2e-12
+
+ # Test CrossCovariance and CrossIoU
+ fn = f"{testdir}/cross_cache.npz"
+ ds = torch.utils.data.TensorDataset(
+ (
+ torch.arange(args.test_size)[:, None] % torch.arange(1, 6)[None, :] == 0
+ ).double(),
+ (
+ torch.arange(args.test_size)[:, None] % torch.arange(5, 8)[None, :] == 0
+ ).double(),
+ )
+ c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
+ riou = IoU()
+ count = 0
+ for [a, b] in tally(c, ds, cache=fn, batch_size=100):
+ count += 1
+ c.add(a, b)
+ riou.add(torch.cat([a, b], dim=1))
+ assert count == -(-args.test_size // 100)
+ cor = c.c.correlation()
+ iou = c.iou.iou()
+ assert cor.shape == iou.shape == (5, 3)
+ assert iou[4, 0] == 1.0
+ assert abs(iou[0, 2] + (-args.test_size // 7 / float(args.test_size))) < 1e-6
+ assert abs(cor[4, 0] - 1.0) < 1e-2
+ assert abs(cor[0, 2] - 0.0) < 1e-6
+ assert all((riou.iou()[:5, -3:] == iou).view(-1))
+ assert all(riou.iou().diagonal(0) == 1)
+ c = CombinedStat(c=CrossCovariance(), iou=CrossIoU())
+ count = 0
+ for [a, b] in tally(c, ds, cache=fn, batch_size=10):
+ count += 1
+ c.add(a, b)
+ assert count == 0
+ assert all((c.c.correlation() == cor).view(-1))
+ assert all((c.iou.iou() == iou).view(-1))
+
+ # Test Concatantaion, Mean, Bincount and tally.
+ fn = f"{testdir}/series_cache.npz"
+ count = 0
+ ds = torch.utils.data.TensorDataset(torch.arange(args.test_size))
+ c = CombinedStat(s=History(), m=Mean(), b=Bincount())
+ for [b] in tally(c, ds, cache=fn, batch_size=batch_size):
+ count += 1
+ c.add(b)
+ assert count == -(-args.test_size // batch_size)
+ assert len(c.s.history()) == args.test_size
+ assert c.s.history()[-1] == args.test_size - 1
+ assert all(c.s.history() == ds.tensors[0])
+ assert all(c.b.bincount() == torch.ones(args.test_size))
+ assert c.m.mean() == float(args.test_size - 1) / 2.0
+ c2 = CombinedStat(s=History(), m=Mean(), b=Bincount())
+ batches = tally(c2, ds, cache=fn)
+ assert len(c2.s.history()) == args.test_size
+ assert all(c2.s.history() == c.s.history())
+ assert all(c2.b.bincount() == torch.ones(args.test_size))
+ assert c2.m.mean() == c.m.mean()
+ count = 0
+ for b in batches:
+ count += 1
+ assert count == 0 # Shouldn't do anything when it's cached
+
+ # An adverarial case: we keep finding more numbers in the middle
+ # as the stream goes on.
+ amount = args.test_size
+ quantiles = 1000
+ data = numpy.arange(float(amount))
+ data[1::2] = data[-1::-2] + (len(data) - 1)
+ data /= 2
+ depth = 50
+ alldata = data[:, None] + (numpy.arange(depth) * amount)[None, :]
+ actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0))
+ amt = amount // depth
+ for r in range(depth):
+ numpy.random.shuffle(alldata[r * amt : r * amt + amt, r])
+ if args.mode == "cuda":
+ alldata = torch.cuda.FloatTensor(alldata)
+ device = torch.device("cuda")
+ else:
+ alldata = torch.FloatTensor(alldata)
+ device = None
+ starttime = time.time()
+ cs = CombinedStat(
+ qc=Quantile(),
+ m=Mean(),
+ v=Variance(),
+ c=Covariance(),
+ s=SecondMoment(),
+ t=TopK(),
+ i=IoU(),
+ )
+ # Feed data in little batches
+ i = 0
+ while i < len(alldata):
+ batch_size = numpy.random.randint(1000)
+ cs.add(alldata[i : i + batch_size])
+ i += batch_size
+ # Test state dict
+ saved = cs.state_dict()
+ # numpy.savez(f'{testdir}/saved.npz', **box_numpy_null(saved))
+ # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
+ cs.save(f"{testdir}/saved.npz")
+ loaded = unbox_numpy_null(numpy.load(f"{testdir}/saved.npz"))
+ assert set(loaded.keys()) == set(saved.keys())
+
+ # Restore using state=saved in constructor.
+ cs2 = CombinedStat(
+ qc=Quantile(),
+ m=Mean(),
+ v=Variance(),
+ c=Covariance(),
+ s=SecondMoment(),
+ t=TopK(),
+ i=IoU(),
+ state=saved,
+ )
+ # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz'))
+ assert not cs2.qc.device.type == "cuda"
+ cs2.to_(device)
+ # alldata = alldata.cpu()
+ cs2.add(alldata)
+ actual_sum *= 2
+ # print(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean())
+ assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5)
+ assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5)
+ assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5)
+ # print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0))
+ assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3)
+ assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2)
+ # print(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0))
+ assert all(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0) < 1e-4)
+ # print(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0))
+ assert all(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0) < 2e-3)
+ moment = (alldata.t() @ alldata) / len(alldata)
+ # print(abs(moment - cs2.s.moment()) / moment.abs())
+ assert all((abs(moment - cs2.s.moment()) / moment.abs()).view(-1) < 1e-2)
+ assert all(alldata.max(dim=0)[0] == cs2.t.topk()[0][:, 0])
+ assert cs2.i.iou()[0, 0] == 1
+ assert all((cs2.i.iou()[1:, 1:] == 1).view(-1))
+ assert all(cs2.i.iou()[1:, 0] < 1)
+ assert all(cs2.i.iou()[1:, 0] == cs2.i.iou()[0, 1:])
+
+ # Restore using cs.load() method.
+ cs = CombinedStat(
+ qc=Quantile(),
+ m=Mean(),
+ v=Variance(),
+ c=Covariance(),
+ s=SecondMoment(),
+ t=TopK(),
+ i=IoU(),
+ )
+ cs.load(f"{testdir}/saved.npz")
+ assert not cs.qc.device.type == "cuda"
+ cs.to_(device)
+ cs.add(alldata)
+ # actual_sum *= 2
+ # print(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean())
+ assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5)
+ assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5)
+ assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5)
+ # print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0))
+ assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3)
+ assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2)
+ # print(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0))
+ assert all(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0) < 1e-4)
+ # print(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0))
+ assert all(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0) < 2e-3)
+ moment = (alldata.t() @ alldata) / len(alldata)
+ # print(abs(moment - cs.s.moment()) / moment.abs())
+ assert all((abs(moment - cs.s.moment()) / moment.abs()).view(-1) < 1e-2)
+ assert all(alldata.max(dim=0)[0] == cs.t.topk()[0][:, 0])
+ assert cs.i.iou()[0, 0] == 1
+ assert all((cs.i.iou()[1:, 1:] == 1).view(-1))
+ assert all(cs.i.iou()[1:, 0] < 1)
+ assert all(cs.i.iou()[1:, 0] == cs.i.iou()[0, 1:])
+
+ # Randomized quantile test
+ qc = cs.qc
+ ro = qc.readout(1001).cpu()
+ endtime = time.time()
+ gt = (
+ torch.linspace(0, amount, quantiles + 1)[None, :]
+ + (torch.arange(qc.depth, dtype=torch.float) * amount)[:, None]
+ )
+ maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles
+ print("Randomized quantile test results:")
+ print("Maximum relative deviation among %d perentiles: %f" % (quantiles, maxreldev))
+ minerr = torch.max(
+ torch.abs(
+ qc.minmax().cpu()[:, 0] - torch.arange(qc.depth, dtype=torch.float) * amount
+ )
+ )
+ maxerr = torch.max(
+ torch.abs(
+ (qc.minmax().cpu()[:, -1] + 1)
+ - (torch.arange(qc.depth, dtype=torch.float) + 1) * amount
+ )
+ )
+ print("Minmax error %f, %f" % (minerr, maxerr))
+ interr = torch.max(
+ torch.abs(qc.integrate(lambda x: x * x).cpu() - actual_sum) / actual_sum
+ )
+ print("Integral error: %f" % interr)
+ medianerr = torch.max(
+ torch.abs(qc.median() - alldata.median(0)[0]) / alldata.median(0)[0]
+ ).cpu()
+ print("Median error: %f" % medianerr)
+ meanerr = torch.max(torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu()
+ print("Mean error: %f" % meanerr)
+ varerr = torch.max(torch.abs(qc.variance() - alldata.var(0)) / alldata.var(0)).cpu()
+ print("Variance error: %f" % varerr)
+ counterr = (
+ (qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu()) - qc.size())
+ / (0.0 + qc.size())
+ ).item()
+ print("Count error: %f" % counterr)
+ print("Time %f" % (endtime - starttime))
+ # Algorithm is randomized, so some of these will fail with low probability.
+ assert maxreldev < 1.0
+ assert minerr == 0.0
+ assert maxerr == 0.0
+ assert interr < 0.01
+ assert abs(counterr) < 0.001
+ shutil.rmtree(testdir, ignore_errors=True)
+ print("OK")
+
+
+if __name__ == "__main__":
+ _unit_test()
diff --git a/torchkit/show.py b/torchkit/show.py
new file mode 100644
index 0000000..7129d99
--- /dev/null
+++ b/torchkit/show.py
@@ -0,0 +1,138 @@
+# show.py
+#
+# An abbreviated way to output simple HTML layout of text and images
+# into a python notebook.
+#
+# - show a PIL image to show an inline HTML .
+# - show an array of items to vertically stack them, centered in a block.
+# - show an array of arrays to horizontally lay them out as inline blocks.
+# - show an array of tuples to create a table.
+
+import PIL.Image, base64, io, IPython, types, sys
+import html as html_module
+from IPython.display import display
+
+g_buffer = None
+
+def blocks(obj, space=''):
+ return IPython.display.HTML(space.join(blocks_tags(obj)))
+
+def rows(obj, space=''):
+ return IPython.display.HTML(space.join(rows_tags(obj)))
+
+def rows_tags(obj):
+ if isinstance(obj, dict):
+ obj = obj.items()
+ results = []
+ results.append('