class Model(object):
    '''
    Abstract base class providing data binding for its members.

    A subclass declares bound members inside __init__ with:

        self.evtname = Trigger()
        self.propname = Property(initval)

    Listeners are attached with `model.on('propname', callback)`.
    Events fire via `model.evtname.trigger(value)`.  A property is read
    as `model.propname` and written as `model.propname = value`; writing
    also notifies every registered listener with the new value.
    '''

    def on(self, name, cb):
        '''
        Registers `cb` as a listener on each event or property named in
        the whitespace-separated string `name`.  Returns self.
        '''
        for member_name in name.split():
            member = self.prop(member_name)
            member.on(cb)
        return self

    def off(self, name, cb):
        '''
        Unregisters `cb` from each event or property named in the
        whitespace-separated string `name`.  Returns self.
        '''
        for member_name in name.split():
            member = self.prop(member_name)
            member.off(cb)
        return self

    def prop(self, name):
        '''
        Returns the Trigger or Property object stored under `name`
        (not the value it holds).  Raises AttributeError when the
        member is some other kind of object.
        '''
        # Bypass our own __getattribute__, which would unwrap a Property.
        member = super().__getattribute__(name)
        if isinstance(member, Trigger):
            return member
        raise AttributeError('%s not a property or trigger but %s'
                % (name, str(type(member))))

    def _initprop_(self, name, value):
        '''
        Hook invoked when a new Trigger or Property member is first
        assigned.  Does nothing here; subclasses override it.
        '''
        return

    def __setattr__(self, name, value):
        '''
        Assignment to an existing Trigger or Property member is
        forwarded to that member's `set`, so `model.name = value`
        becomes `prop(name).set(value)`.  Any other assignment is
        stored normally; a brand-new Trigger member is additionally
        passed to `_initprop_`.
        '''
        if not hasattr(self, name):
            # First assignment of this attribute.
            super().__setattr__(name, value)
            if isinstance(value, Trigger):
                self._initprop_(name, value)
            return
        existing = super().__getattribute__(name)
        if isinstance(existing, Trigger):
            # Delegate "set" to the underlying Trigger/Property.
            existing.set(value)
            return
        super().__setattr__(name, value)

    def __getattribute__(self, name):
        '''
        Reading a Property member yields the value it holds; every
        other attribute is returned unchanged.
        '''
        member = super().__getattribute__(name)
        return member.value if isinstance(member, Property) else member
This javascript will be wrapped + in an immediately-invoked function and included in the widget's HTML + representation (_repr_html_) when the widget is viewed. + + A widget's javascript is provided with two local variables: + + element - the widget's root HTML element. By default this is + a
but can be overridden in widget_html. + model - the object representing the data model for the widget. + within javascript. + + The model object provides the following javascript API: + + model.get('propname') obtains a current property value. + model.set('propname', 'value') requests a change in value. + model.on('propname', callback) listens for property changes. + model.trigger('evtname', value) triggers an event. + + Note that model.set just requests a change but does not change the + value immediately: model.get will not reflect the change until the + python backend has handled it and notified the javascript of the new + value, which will trigger any callbacks previously registered using + .on('propname', callback). Thus Widget impelements a V-shaped + notification protocol: + + User entry -> | -> User-visible feedback + js model.set -> | -> js.model.on callback + python prop.trigger -> | -> python prop.notify + python prop.handle + ''' + + def __init__(self): + # In the jupyter case, there can be some delay between js injection + # and comm creation, so we need to queue some initial messages. + if WIDGET_ENV == 'jupyter': + self._comms = [] + self._queue = [] + # Each call to _repr_html_ creates a unique view instance. + self._viewcount = 0 + # Python notification is handled by Property objects. + def handle_remote_set(name, value): + self.prop(name).trigger(value) + self._recv_from_js_(handle_remote_set) + + def widget_js(self): + ''' + Override to define the javascript logic for the widget. Should + render the initial view based on the current model state (if not + already rendered using widget_html) and set up listeners to keep + the model and the view synchornized. + ''' + return '' + + def widget_html(self): + ''' + Override to define the initial HTML view of the widget. Should + define an element with id given by view_id(). + ''' + return f'
' + + def view_id(self): + ''' + Returns an HTML element id for the view currently being rendered. + Note that each time _repr_html_ is called, this id will change. + ''' + return f"_{id(self)}_{self._viewcount}" + + def _repr_html_(self): + ''' + Returns the HTML code for the widget. + ''' + self._viewcount += 1 + json_data = json.dumps({ + k: v.value for k, v in vars(self).items() + if isinstance(v, Property)}) + json_data = re.sub(' + (function() {{ + {WIDGET_MODEL_JS} + var model = new Model("{id(self)}", {json_data}); + var element = document.getElementById("{self.view_id()}"); + {self.widget_js()} + }})(); + + """ + + def _initprop_(self, name, value): + if not hasattr(self, '_viewcount'): + raise ValueError('base Model __init__ must be called') + def notify_js(value): + self._send_to_js_(id(self), name, value) + if isinstance(value, Trigger): + value.on(notify_js) + + def _send_to_js_(self, *args): + if self._viewcount > 0: + if WIDGET_ENV == 'colab': + colab_output.eval_js(f""" + (window.send_{id(self)} = window.send_{id(self)} || + new BroadcastChannel("channel_{id(self)}") + ).postMessage({json.dumps(args)}); + """, ignore_result=True) + elif WIDGET_ENV == 'jupyter': + if not self._comms: + self._queue.append(args) + return + for comm in self._comms: + comm.send(args) + + def _recv_from_js_(self, fn): + if WIDGET_ENV == 'colab': + colab_output.register_callback(f"invoke_{id(self)}", fn) + elif WIDGET_ENV == 'jupyter': + def handle_comm(msg): + fn(*(msg['content']['data'])) + # TODO: handle closing also. 
class Trigger(object):
    """
    Trigger is the base class for Property and other data-bound
    field objects.  A Trigger holds a list of listeners that need to
    be notified about the event.

    Multiple Trigger objects can be tied together (typically a parent
    Model can have Triggers that are triggered by children models).
    To support this, each Trigger can have a parent.

    Trigger objects provide a notification protocol where view
    interactions trigger events at a leaf that are sent up to the
    root Trigger to be handled.  By default, the root handler accepts
    events by notifying all listeners and children in the tree.
    """

    def __init__(self):
        self._listeners = []
        self.parent = None

    def handle(self, value):
        '''
        Method to override; called at the root when an event has been
        triggered, and on a child when the parent has notified.  By
        default notifies all listeners.
        '''
        self.notify(value)

    def trigger(self, value=None):
        '''
        Triggers an event to be handled by the root.  By default, the root
        handler will accept the event so all the listeners will be notified.
        '''
        if self.parent is not None:
            self.parent.trigger(value)
        else:
            self.handle(value)

    def set(self, value):
        '''
        Sets the parent Trigger.  Child Triggers trigger events by
        triggering parents, and in turn they handle notifications
        that come from parents.

        Raises ValueError if binding to `value` would create a cycle
        (including binding a Trigger to itself), or if a plain Trigger
        (not a Property) is set to a non-Trigger value.  The existing
        parent binding is left untouched when an error is raised.
        '''
        # Validate before mutating, so a rejected call cannot orphan us.
        if isinstance(value, Trigger):
            # Start the walk at `value` itself: binding a trigger to
            # itself would otherwise make trigger() recurse forever.
            ancestor = value
            while ancestor is not None:
                if ancestor is self:
                    raise ValueError('bound properties should not form a loop')
                ancestor = ancestor.parent
        elif not isinstance(self, Property):
            raise ValueError('only properties can be set to a value')
        # Detach from the previous parent, if any.
        if self.parent is not None:
            self.parent.off(self.handle)
            self.parent = None
        # Attach to the new parent so its notifications reach handle().
        if isinstance(value, Trigger):
            self.parent = value
            self.parent.on(self.handle)

    def notify(self, value=None):
        '''
        Notifies listeners and children.  If a listener accepts an argument,
        the value will be passed as a single argument.
        '''
        for cb in self._listeners:
            try:
                takes_no_args = len(signature(cb).parameters) == 0
            except (TypeError, ValueError):
                # Some builtin/extension callables expose no signature;
                # fall back to the one-argument calling form.
                takes_no_args = False
            if takes_no_args:
                cb()  # no-parameter callback.
            else:
                cb(value)
        # TODO: consider adding a two-parameter callback form
        # where a detailed event object can be added.

    def on(self, cb):
        '''
        Registers a listener.  Calling multiple times registers
        multiple listeners.
        '''
        self._listeners.append(cb)

    def off(self, cb):
        '''
        Unregisters a listener.
        '''
        # Compare with != (not identity): bound methods such as
        # self.handle are re-created on each access but compare equal.
        self._listeners = [c for c in self._listeners if c != cb]
+ ''' + # Handle setting a parent Property + if isinstance(value, Property): + super().set(value) + self.handle(value.value) + elif isinstance(value, Trigger): + raise ValueError('Cannot set a Property to an Trigger') + else: + self.trigger(value) + + +########################################################################## +## Specific widgets +########################################################################## + +class Button(Widget): + def __init__(self, label='button'): + super().__init__() + self.click = Trigger() + self.label = Property(label) + def widget_js(self): + return ''' + element.addEventListener('click', (e) => { + model.trigger('click'); + }) + model.on('label', (v) => { + element.value = v; + }) + ''' + def widget_html(self): + return f''' + + ''' + +class Label(Widget): + def __init__(self, value=''): + super().__init__() + # databinding is defined using Property objects. + self.value = Property(value) + + def widget_js(self): + # Both "model" and "element" objects are defined within the scope + # where the js is run. "element" looks for the element with id + # self.view_id(); if widget_html is overridden, this id should be used. + return ''' + model.on('value', (value) => { + element.innerText = model.get('value'); + }); + ''' + def widget_html(self): + return f''' + + ''' + +class Textbox(Widget): + def __init__(self, value='', size=20): + super().__init__() + # databinding is defined using Property objects. + self.value = Property(value) + self.size = Property(size) + + def widget_js(self): + # Both "model" and "element" objects are defined within the scope + # where the js is run. "element" looks for the element with id + # self.view_id(); if widget_html is overridden, this id should be used. 
+ return ''' + element.value = model.get('value'); + element.size = model.get('size'); + element.addEventListener('keydown', (e) => { + if (e.code == 'Enter') { + model.set('value', element.value); + } + }); + model.on('value', (value) => { + element.value = model.get('value'); + }); + model.on('size', (value) => { + element.size = model.get('size'); + }); + ''' + def widget_html(self): + return f''' + + ''' + +class Range(Widget): + def __init__(self, value=50, min=0, max=100): + super().__init__() + # databinding is defined using Property objects. + self.value = Property(value) + self.min = Property(min) + self.max = Property(max) + + def widget_js(self): + # Note that the 'input' event would enable during-drag feedback, + # but this is pretty slow on google colab. + return ''' + element.addEventListener('change', (e) => { + model.set('value', element.value); + }); + model.on('value', (value) => { + if (!element.matches(':active')) { + element.value = value; + } + }) + ''' + def widget_html(self): + return f''' + + ''' + +class Choice(Widget): + def __init__(self, choices=None, selection=None, horizontal=False): + super().__init__() + if choices is None: + choices = [] + self.choices = Property(choices) + self.horizontal = Property(horizontal) + self.selection = Property(selection) + def widget_js(self): + # Note that the 'input' event would enable during-drag feedback, + # but this is pretty slow on google colab. + return ''' + function esc(unsafe) { + return unsafe.replace(/&/g, "&").replace(//g, ">").replace(/"/g, """); + } + function render() { + console.log('rendering'); + var lines = model.get('choices').map((c) => { + return '' + }); + element.innerHTML = lines.join(model.get('horizontal')?' ':'
'); + } + model.on('choices horizontal', render); + model.on('selection', (selection) => { + [...element.querySelectorAll('input')].forEach((e) => { + e.checked = (e.value == selection); + }) + }); + element.addEventListener('change', (e) => { + model.set('selection', element.choice.value); + }); + ''' + def widget_html(self): + radios = [ + f"""""" + for value in self.choices ] + sep = " " if self.horizontal else "
" + return f'
{sep.join(radios)}
' + + +class Div(Widget): + def __init__(self, innerHTML='', style=None, data=None): + super().__init__() + style = {} if style is None else style + data = {} if data is None else data + # TODO: convert non-string innerHTML objects to html; unify + # with the show() library. + self.innerHTML = Property(innerHTML) + self.style = Property(style) + self.click = Trigger() + + def print(self, text, replace=False): + newHTML = '
%s
' % html.escape(str(text)); + if replace: + self.innerHTML = newHTML + else: + self.innerHTML += newHTML + + def widget_js(self): + # Note that if we want innerHTML to support script execution, + # we need to do it explicitly, like this. + return ''' + function updater(attr) { + return (val) => { for (k in val) { element[attr][k] = val[k]; } } + } + model.on('style', updater('style')); + updater('style')(); + model.on('data', updater('dataset')); + updater('dataset')(); + model.on('innerHTML', (innerHTML) => { + element.innerHTML = innerHTML; + Array.from(element.querySelectorAll("script")).forEach(old=>{ + const newScript = document.createElement("script"); + Array.from(old.attributes).forEach(attr => + newScript.setAttribute(attr.name, attr.value)); + newScript.appendChild(document.createTextNode(old.innerHTML)); + old.parentNode.replaceChild(newScript, old); + }); + }); + ''' + def widget_html(self): + return f''' +
{self.innerHTML}
class ClickDiv(Div):
    '''
    A Div that triggers click events when anything inside it is clicked.
    If a clicked element (or one of its ancestors inside the widget)
    carries a data-click value, then that value is sent as the click
    event value.
    '''
    def __init__(self, innerHTML='', style=None, data=None):
        # Fixed: the argument was previously misspelled `innertHTML`,
        # which raised NameError on every construction.
        super().__init__(innerHTML, style, data)
        # NOTE(review): Div.__init__ already assigns self.click; via
        # Model.__setattr__ this second assignment is delegated to the
        # existing Trigger's set().  Kept as in the original.
        self.click = Trigger()

    def widget_js(self):
        # Appends a click handler to the Div javascript: walk up from the
        # event target until an element with a data-click attribute (or
        # the widget root) is found, then send its value to the model.
        return super().widget_js() + '''
        element.addEventListener('click', (ev) => {
          var target = ev.target;
          while (target && target != element && !target.dataset.click) {
            target = target.parentElement;
          }
          var value = target.dataset.click;
          model.trigger('click', value);
        });
        '''
clearInterval(chan.retry); chan.retry = null; } + if (ev.content.data == 'ok') { return; } + var args = ev.content.data.slice(1); + for (fn of chan) { fn.apply(null, args); } + }); + chan.retry = setInterval(() => { chan.comm.open(); }, 2000); + } + return chan; +} +function recvFromPython(obj_id, fn) { + getChan(obj_id).push(fn); +} +function sendToPython(obj_id, ...args) { + var comm = getChan(obj_id).comm; + if (comm) { comm.send(args); } +} +""" + + +WIDGET_MODEL_JS = SEND_RECV_JS + """ +class Model { + constructor(obj_id, init) { + this._id = obj_id; + this._listeners = {}; + this._data = Object.assign({}, init) + recvFromPython(this._id, (name, value) => { + this._data[name] = value; + if (this._listeners.hasOwnProperty(name)) { + this._listeners[name].forEach((fn) => { fn(value); }); + } + }) + } + trigger(name, value) { + sendToPython(this._id, name, value); + } + get(name) { + return this._data[name]; + } + set(name, value) { + this.trigger(name, value); + } + on(name, fn) { + name.split(/\s+/).forEach((n) => { + if (!this._listeners.hasOwnProperty(n)) { + this._listeners[n] = []; + } + this._listeners[n].push(fn); + }); + } + off(name, fn) { + name.split(/\s+/).forEach((n) => { + if (!fn) { + delete this._listeners[n]; + } else if (this._listeners.hasOwnProperty(n)) { + this._listeners[n] = this._listeners[n].filter( + (e) => { return e !== fn; }); + } + }); + } +} +""" diff --git a/torchkit/nethook.py b/torchkit/nethook.py new file mode 100644 index 0000000..12177ca --- /dev/null +++ b/torchkit/nethook.py @@ -0,0 +1,451 @@ +""" +Utilities for instrumenting a torch model. + +Trace will hook one layer at a time. +TraceDict will hook multiple layers at once. +subsequence slices intervals from Sequential modules. +get_module, replace_module, get_parameter resolve dotted names. +set_requires_grad recursively sets requires_grad in module parameters. 
class Trace(contextlib.AbstractContextManager):
    """
    To retain the output of the named layer during the computation of
    the given network:

        with Trace(net, 'layer.name') as ret:
            _ = net(inp)
            representation = ret.output

    A layer module can be passed directly without a layer name, and
    its output will be retained.  By default, a direct reference to
    the output object is returned, but options can control this:

        clone=True  - retains a copy of the output, which can be
            useful if you want to see the output before it might
            be modified by the network in-place later.
        detach=True - retains a detached reference or copy.  (By
            default the value would be left attached to the graph.)
        retain_grad=True - request gradient to be retained on the
            output.  After backward(), ret.output.grad is populated.

        retain_input=True - also retains the input.
        retain_output=False - can disable retaining the output.
        edit_output=fn - calls the function to modify the output
            of the layer before passing it the rest of the model.
            fn can optionally accept (output, layer) arguments
            for the original output and the layer name.
        stop=True - throws a StopForward exception after the layer
            is run, which allows running just a portion of a model.
    """

    def __init__(
        self,
        module,
        layer=None,
        retain_output=True,
        retain_input=False,
        clone=False,
        detach=False,
        retain_grad=False,
        edit_output=None,
        stop=False,
    ):
        """
        Registers a forward hook on the target module that intercepts
        the call, and keeps the hook handle so that it can be removed.
        """
        retainer = self
        self.layer = layer
        if layer is not None:
            # Resolve a dotted layer name to the submodule to hook.
            module = get_module(module, layer)

        def retain_hook(m, inputs, output):
            if retain_input:
                retainer.input = recursive_copy(
                    inputs[0] if len(inputs) == 1 else inputs,
                    clone=clone,
                    detach=detach,
                    retain_grad=False,
                )  # retain_grad applies to output only.
            if edit_output:
                output = invoke_with_optional_args(
                    edit_output, output=output, layer=self.layer
                )
            if retain_output:
                retainer.output = recursive_copy(
                    output, clone=clone, detach=detach, retain_grad=retain_grad
                )
                # When retain_grad is set, also insert a trivial
                # copy operation.  That allows in-place operations
                # to follow without error.
                if retain_grad:
                    output = recursive_copy(retainer.output, clone=True, detach=False)
            if stop:
                raise StopForward()
            return output

        self.registered_hook = module.register_forward_hook(retain_hook)
        self.stop = stop

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """
        Removes the hook.  When stop=True, swallows the StopForward
        raised by the hook; any other exception propagates.
        """
        self.close()
        # `type` is None when the with-body exits normally; issubclass
        # would raise TypeError on None, so test for it first.
        if self.stop and type is not None and issubclass(type, StopForward):
            return True

    def close(self):
        """Removes the forward hook registered by __init__."""
        self.registered_hook.remove()
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
    """
    Copies a reference to a tensor, or an object that contains tensors,
    optionally detaching and cloning the tensor(s).  If retain_grad is
    true, the original tensors are marked to have grads retained.

    When none of clone, detach, retain_grad is set, x is returned as-is.
    Otherwise x must be a tensor, dict, list, or tuple (or a subclass);
    containers are rebuilt with the same type and the options are
    applied to every tensor nested inside them.  Non-tensor leaves
    inside a container are passed through unchanged.

    Raises AssertionError if x itself is of an unsupported type.
    """
    if not clone and not detach and not retain_grad:
        return x
    # Only tensors, dicts, lists, and tuples (and subclasses) can be copied.
    # Raised explicitly so the check survives `python -O`.
    if not isinstance(x, (torch.Tensor, dict, list, tuple)):
        raise AssertionError(
            f"Unknown type {type(x)} cannot be broken into tensors."
        )

    def copy_item(item):
        # The options are captured from the enclosing call, so nested
        # tensors receive the same treatment as a top-level tensor.
        # (Previously the recursion dropped them and was a no-op.)
        if isinstance(item, torch.Tensor):
            if retain_grad:
                if not item.requires_grad:
                    item.requires_grad = True
                item.retain_grad()
            elif detach:
                item = item.detach()
            if clone:
                item = item.clone()
            return item
        if isinstance(item, dict):
            return type(item)({k: copy_item(v) for k, v in item.items()})
        if isinstance(item, tuple) and hasattr(item, "_fields"):
            # namedtuple constructors take fields positionally.
            return type(item)(*[copy_item(v) for v in item])
        if isinstance(item, (list, tuple)):
            return type(item)([copy_item(v) for v in item])
        return item

    return copy_item(x)
In this helper, first, last, after, and upto are + arrays of names resulting from splitting on dots. Can only + descend into nested Sequentials. + """ + assert (last is None) or (upto is None) + assert (first is None) or (after is None) + if first is last is after is upto is None: + return sequential if share_weights else copy.deepcopy(sequential) + assert isinstance(sequential, torch.nn.Sequential), ( + ".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential" + ) + including_children = (first is None) and (after is None) + included_children = OrderedDict() + # A = current level short name of A. + # AN = full name for recursive descent if not innermost. + (F, FN), (L, LN), (A, AN), (U, UN) = [ + (d[depth], (None if len(d) == depth + 1 else d)) + if d is not None + else (None, None) + for d in [first, last, after, upto] + ] + for name, layer in sequential._modules.items(): + if name == F: + first = None + including_children = True + if name == A and AN is not None: # just like F if not a leaf. + after = None + including_children = True + if name == U and UN is None: + upto = None + including_children = False + if including_children: + # AR = full name for recursive descent if name matches. + FR, LR, AR, UR = [ + n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN] + ] + chosen = hierarchical_subsequence( + layer, + first=FR, + last=LR, + after=AR, + upto=UR, + share_weights=share_weights, + depth=depth + 1, + ) + if chosen is not None: + included_children[name] = chosen + if name == L: + last = None + including_children = False + if name == U and UN is not None: # just like L if not a leaf. + upto = None + including_children = False + if name == A and AN is None: + after = None + including_children = True + for name in [first, last, after, upto]: + if name is not None: + raise ValueError("Layer %s not found" % ".".join(name)) + # Omit empty subsequences except at the outermost level, + # where we should not return None. 
def invoke_with_optional_args(fn, *args, **kwargs):
    """
    Invokes a function with only the arguments that it
    is written to accept, giving priority to arguments
    that match by-name, using the following rules.
    (1) arguments with matching names are passed by name.
    (2) remaining non-name-matched args are passed by order.
    (3) extra caller arguments that the function cannot
        accept are not passed.
    (4) extra required function arguments that the caller
        cannot provide cause a TypeError to be raised.
    Ordinary python calling conventions are helpful for
    supporting a function that might be revised to accept
    extra arguments in a newer version, without requiring the
    caller to pass those new arguments.  This function helps
    support function callers that might be revised to supply
    extra arguments, without requiring the callee to accept
    those new arguments.

    Returns whatever fn returns.
    """
    argspec = inspect.getfullargspec(fn)
    arg_names = argspec.args
    # getfullargspec lists the already-bound first parameter of a bound
    # method; it must not be treated as a slot for the caller to fill.
    if inspect.ismethod(fn) and arg_names:
        arg_names = arg_names[1:]
    defaults = argspec.defaults or ()
    pass_args = []
    used_kw = set()
    unmatched_pos = []
    used_pos = 0
    # Index of the first positional parameter that has a default.
    defaulted_pos = len(arg_names) - len(defaults)
    # Pass positional args that match name first, then by position.
    for i, n in enumerate(arg_names):
        if n in kwargs:
            pass_args.append(kwargs[n])
            used_kw.add(n)
        elif used_pos < len(args):
            pass_args.append(args[used_pos])
            used_pos += 1
        else:
            # Remember the slot; provisionally fill it with its default
            # (or None when the parameter is required).
            unmatched_pos.append(len(pass_args))
            pass_args.append(
                None if i < defaulted_pos else defaults[i - defaulted_pos]
            )
    # Fill unmatched positional args with unmatched keyword args in order.
    if unmatched_pos:
        for k, v in kwargs.items():
            if k in used_kw or k in argspec.kwonlyargs:
                continue
            pass_args[unmatched_pos[0]] = v
            used_kw.add(k)
            unmatched_pos = unmatched_pos[1:]
            if not unmatched_pos:
                break
        else:
            # Ran out of caller arguments; fail if a required
            # (non-defaulted) parameter is still unfilled.
            if unmatched_pos[0] < defaulted_pos:
                unpassed = ", ".join(
                    arg_names[u] for u in unmatched_pos if u < defaulted_pos
                )
                fn_name = getattr(fn, "__name__", repr(fn))
                raise TypeError(f"{fn_name}() cannot be passed {unpassed}.")
    # Pass remaining kw args if they can be accepted: either as a
    # keyword-only parameter or through **kwargs (varkw, not *args).
    pass_kw = {
        k: v
        for k, v in kwargs.items()
        if k not in used_kw and (k in argspec.kwonlyargs or argspec.varkw is not None)
    }
    # Pass remaining positional args if they can be accepted.
    if argspec.varargs is not None:
        pass_args += list(args[used_pos:])
    return fn(*pass_args, **pass_kw)
# Browser-side implementation of the paint widget.  The text is inlined
# into the widget's javascript, which then instantiates
# `new PaintWidget(element, model)`.
PAINT_WIDGET_JS = """
class PaintWidget {
  constructor(el, model) {
    this.el = el;
    this.model = model;
    this.size_changed();
    this.model.on('mask', this.mask_changed.bind(this));
    this.model.on('image', this.image_changed.bind(this));
    this.model.on('width', this.size_changed.bind(this));
    this.model.on('height', this.size_changed.bind(this));
  }
  // Handles one brush stroke, from mousedown until mouseup or Escape.
  mouse_stroke(first_event) {
    var self = this;
    if (self.model.get('disabled')) { return; }
    if (self.model.get('oneshot')) {
      var canvas = self.mask_canvas;
      var ctx = canvas.getContext('2d');
      ctx.clearRect(0, 0, canvas.width, canvas.height);
    }
    function track_mouse(evt) {
      if (evt.type == 'keydown' || self.model.get('disabled')) {
        if (self.model.get('disabled') || evt.key === "Escape") {
          window.removeEventListener('mousemove', track_mouse);
          window.removeEventListener('mouseup', track_mouse);
          window.removeEventListener('keydown', track_mouse, true);
          // Abandon the stroke: redraw the mask from the model.
          self.mask_changed();
        }
        return;
      }
      if (evt.type == 'mouseup' ||
          (typeof evt.buttons != 'undefined' && evt.buttons == 0)) {
        window.removeEventListener('mousemove', track_mouse);
        window.removeEventListener('mouseup', track_mouse);
        window.removeEventListener('keydown', track_mouse, true);
        // Commit the stroke to the model.
        self.model.set('mask', self.mask_canvas.toDataURL());
        return;
      }
      // Use the event being handled rather than the global window.event.
      var p = self.cursor_position(evt);
      self.fill_circle(p.x, p.y,
          self.model.get('brushsize'),
          self.model.get('erase'));
    }
    this.mask_canvas.focus();
    window.addEventListener('mousemove', track_mouse);
    window.addEventListener('mouseup', track_mouse);
    window.addEventListener('keydown', track_mouse, true);
    track_mouse(first_event);
  }
  mask_changed(val) {
    this.draw_data_url(this.mask_canvas, this.model.get('mask'));
  }
  image_changed() {
    this.draw_data_url(this.image_canvas, this.model.get('image'));
  }
  // Rebuilds both canvases at the model's current width and height.
  size_changed() {
    this.mask_canvas = document.createElement('canvas');
    this.image_canvas = document.createElement('canvas');
    this.mask_canvas.className = "paintmask";
    this.image_canvas.className = "paintimage";
    for (var attr of ['width', 'height']) {
      this.mask_canvas[attr] = this.model.get(attr);
      this.image_canvas[attr] = this.model.get(attr);
    }

    this.el.innerHTML = '';
    this.el.appendChild(this.image_canvas);
    this.el.appendChild(this.mask_canvas);
    this.mask_canvas.addEventListener('mousedown',
        this.mouse_stroke.bind(this));
    this.mask_changed();
    this.image_changed();
  }

  // Returns the position of the mouse event relative to the mask canvas.
  cursor_position(evt) {
    // Fall back to window.event for callers that pass no event.
    evt = evt || window.event;
    const rect = this.mask_canvas.getBoundingClientRect();
    const x = evt.clientX - rect.left;
    const y = evt.clientY - rect.top;
    return {x: x, y: y};
  }

  fill_circle(x, y, r, erase, blur) {
    var ctx = this.mask_canvas.getContext('2d');
    ctx.save();
    if (blur) {
      ctx.filter = 'blur(' + blur + 'px)';
    }
    ctx.globalCompositeOperation = (
        erase ? "destination-out" : 'source-over');
    ctx.fillStyle = '#fff';
    ctx.beginPath();
    ctx.arc(x, y, r, 0, 2 * Math.PI);
    ctx.fill();
    ctx.restore()
  }

  draw_data_url(canvas, durl) {
    var ctx = canvas.getContext('2d');
    var img = new Image;
    // Only the most recently requested image may draw itself.
    canvas.pendingImg = img;
    function imgdone() {
      if (canvas.pendingImg == img) {
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        ctx.drawImage(img, 0, 0);
        canvas.pendingImg = null;
      }
    }
    img.addEventListener('load', imgdone);
    img.addEventListener('error', imgdone);
    img.src = durl;
  }
}
"""
+import sys +import types +import builtins +try: + from tqdm import tqdm + try: + from tqdm.notebook import tqdm as tqdm_nb + except: + from tqdm import tqdm_notebook as tqdm_nb +except: + tqdm = None + +default_verbosity = True +next_description = None +python_print = builtins.print + + +def post(**kwargs): + ''' + When within a progress loop, pbar.post(k=str) will display + the given k=str status on the right-hand-side of the progress + status bar. If not within a visible progress bar, does nothing. + ''' + innermost = innermost_tqdm() + if innermost is not None: + innermost.set_postfix(**kwargs) + + +def desc(desc): + ''' + When within a progress loop, pbar.desc(str) changes the + left-hand-side description of the loop toe the given description. + ''' + innermost = innermost_tqdm() + if innermost is not None: + innermost.set_description(str(desc)) + + +def descnext(desc): + ''' + Called before starting a progress loop, pbar.descnext(str) + sets the description text that will be used in the following loop. + ''' + global next_description + if not default_verbosity or tqdm is None: + return + next_description = desc + + +def print(*args): + ''' + When within a progress loop, will print above the progress loop. + ''' + global next_description + next_description = None + if default_verbosity: + msg = ' '.join(str(s) for s in args) + if tqdm is None: + python_print(msg) + else: + tqdm.write(msg) + + +def tqdm_terminal(it, *args, **kwargs): + ''' + Some settings for tqdm that make it run better in resizable terminals. + ''' + return tqdm(it, *args, dynamic_ncols=True, ascii=True, + leave=(innermost_tqdm() is not None), **kwargs) + + +def in_notebook(): + ''' + True if running inside a Jupyter notebook. 
+ ''' + # From https://stackoverflow.com/a/39662359/265298 + try: + shell = get_ipython().__class__.__name__ + if shell == 'ZMQInteractiveShell': + return True # Jupyter notebook or qtconsole + elif shell == 'TerminalInteractiveShell': + return False # Terminal running IPython + else: + return False # Other type (?) + except NameError: + return False # Probably standard Python interpreter + + +def innermost_tqdm(): + ''' + Returns the innermost active tqdm progress loop on the stack. + ''' + if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0: + return max(tqdm._instances, key=lambda x: x.pos) + else: + return None + + +def reporthook(*args, **kwargs): + ''' + For use with urllib.request.urlretrieve. + + with pbar.reporthook() as hook: + urllib.request.urlretrieve(url, filename, reporthook=hook) + ''' + kwargs2 = dict(unit_scale=True, miniters=1) + kwargs2.update(kwargs) + bar = __call__(None, *args, **kwargs2) + + class ReportHook(object): + def __init__(self, t): + self.t = t + + def __call__(self, b=1, bsize=1, tsize=None): + if hasattr(self.t, 'total'): + if tsize is not None: + self.t.total = tsize + if hasattr(self.t, 'update'): + self.t.update(b * bsize - self.t.n) + + def __enter__(self): + return self + + def __exit__(self, *exc): + if hasattr(self.t, '__exit__'): + self.t.__exit__(*exc) + return ReportHook(bar) + + +def __call__(x, *args, **kwargs): + ''' + Invokes a progress function that can wrap iterators to print + progress messages, if verbose is True. + + If verbose is False or tqdm is unavailable, then a quiet + non-printing identity function is used. + + verbose can also be set to a spefific progress function rather + than True, and that function will be used. 
class VerboseContextManager():
    '''
    Sets the module-level default_verbosity, either temporarily (as a
    context manager) or persistently (when called as a function).

    v is the verbosity that is in effect while the context is active.
    entered=True creates the manager already entered, so that the
    setting takes effect immediately and a later `with` statement on
    the same object does not save and apply it a second time.
    '''
    def __init__(self, v, entered=False):
        self.v, self.entered, self.saved = v, False, []
        if entered:
            # Apply the setting now, then flag the manager so that the
            # next __enter__ is a no-op.
            self.__enter__()
            self.entered = True

    def __enter__(self):
        global default_verbosity
        if self.entered:
            # Already applied by __init__; just consume the flag.
            self.entered = False
        else:
            self.saved.append(default_verbosity)
            default_verbosity = self.v
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        global default_verbosity
        # Restore the verbosity that was active before the matching enter.
        default_verbosity = self.saved.pop()

    def __call__(self, v=True):
        '''
        Calling the context manager makes a new context that is
        pre-entered, so it works as both a plain function and as a
        factory for a context manager.
        '''
        # For a "quiet"-style manager (self.v false) the flag is inverted.
        new_v = v if self.v else not v
        # The pre-entered constructor already updates default_verbosity.
        return VerboseContextManager(new_v, entered=True)
or also +# "with pbar.quiet(True):" +quiet = VerboseContextManager(False) + + +class CallableModule(types.ModuleType): + def __init__(self): + # or super().__init__(__name__) for Python 3 + types.ModuleType.__init__(self, __name__) + self.__dict__.update(sys.modules[__name__].__dict__) + + def __call__(self, x, *args, **kwargs): + return __call__(x, *args, **kwargs) + + +sys.modules[__name__] = CallableModule() diff --git a/torchkit/pidfile.py b/torchkit/pidfile.py new file mode 100644 index 0000000..3224b4d --- /dev/null +++ b/torchkit/pidfile.py @@ -0,0 +1,129 @@ +''' +Utility for simple distribution of work on multiple processes, by +making sure only one process is working on a job at once. +''' + +import os +import errno +import socket +import atexit +import time +import sys + + +def reserve_dir(*args): + ''' + Convenience function to get exclusive access to an unfinished + experiment directory. Exits the program if the directory is + already done or busy (using exit_of_job_done). Otherwise, + returns a function creates filenames within that directory. + ''' + directory = os.path.join(*args) + exit_if_job_done(directory) + + def dirfn(*fn): + return os.path.join(directory, *fn) + dirfn.dir = directory + + def done(): + mark_job_done(directory) + dirfn.done = done + print('Working in %s' % directory) + return dirfn + + +# Old function name. 
def pidfile_taken(path, verbose=False, force=False):
    '''
    Usage.  To grab an exclusive lock for the remaining duration of the
    current process (and exit if another process already has the lock),
    do this:

    if pidfile_taken('job_423/lockfile.pid', verbose=True):
        sys.exit(0)

    To do a batch of jobs, just run a script that does them all on
    each available machine, sharing a network filesystem.  When each
    job grabs a lock, then this will automatically distribute the
    jobs so that each one is done just once on one machine.

    Returns None when the lock was acquired by this process.  Otherwise
    returns a non-empty string describing the holder: the contents of
    the existing lockfile, 'empty' if that file is empty, or 'race' if
    it could not be read.  With force=True an existing lockfile is
    removed once and the lock is attempted again.  Errors other than
    "file already exists" are raised as OSError.
    '''

    # Try to create the file exclusively and write my pid into it.
    try:
        directory = os.path.dirname(path)
        # A bare filename has no directory part, and os.makedirs('')
        # would raise, so only create the directory when there is one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
    except OSError as e:
        if e.errno != errno.EEXIST:
            # Other problems get an exception.
            raise
        # If we cannot because there was a race, yield the conflicter.
        conflicter = 'race'
        try:
            with open(path, 'r') as lockfile:
                conflicter = lockfile.read().strip() or 'empty'
        except (OSError, ValueError):
            # The holder may have deleted the file, or it may not be
            # readable as text; keep the 'race' placeholder.
            pass
        # Force is for manual one-time use, for deleting stale lockfiles.
        if force:
            if verbose:
                print('Removing %s from %s' % (path, conflicter))
            os.remove(path)
            return pidfile_taken(path, verbose=verbose, force=False)
        if verbose:
            print('%s held by %s' % (path, conflicter))
        return conflicter
    # Register to delete this file on exit.
    lockfile = os.fdopen(fd, 'r+')
    atexit.register(delete_pidfile, lockfile, path)
    # Write my pid into the open file.
    lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
                                   os.getenv('STY', '')))
    lockfile.flush()
    os.fsync(lockfile.fileno())
    # Return 'None' to say there was not a conflict.
    return None
def tally(stat, dataset, cache=None, quiet=False, **kwargs):
    """
    Returns an iterable of batches for computing `stat` over `dataset`,
    with optional caching of the finished statistic.  Typical use:

        stat = Mean()
        ds = MyDataset()
        for batch in tally(stat, ds, cache='mymean.npz', batch_size=50):
            stat.add(batch)
        mean = stat.mean()

    The first argument must be the Stat being computed.  The dataset
    and all remaining keyword arguments (batch_size, num_workers,
    pin_memory, sample_size, random_sample, ...) are handed to
    make_loader to build the loader.

    Caching via cache=:

        If a cached state can be loaded for the given cache name, it is
        loaded into `stat` and an empty iterator is returned, so the
        loop body above never runs.

        Otherwise the returned iterator yields every batch from the
        loader; once the loader is exhausted, `stat` is moved to the
        cpu and, if a cache name was given, saved to the cache.

        The sample_size keyword, when present, is passed along with the
        cache name when loading and saving the cached state.

    The `with cache_load_enabled(False):` context manager can
    be used to disable loading from the cache.
    """
    assert isinstance(stat, Stat)
    # Only these loader options are recorded alongside the cached state.
    args = {k: kwargs[k] for k in ["sample_size"] if k in kwargs}
    cached_state = load_cached_state(cache, args, quiet=quiet)

    if cached_state is not None:
        stat.load_state_dict(cached_state)

        def no_batches():
            # A generator that finishes without yielding anything.
            yield from ()

        return no_batches()

    loader = make_loader(dataset, **kwargs)

    def batches_then_finalize():
        for batch in loader:
            yield batch
        # Runs only after the caller has consumed every batch.
        stat.to_(device="cpu")
        if cache is not None:
            save_cached_state(cache, stat, args)

    return batches_then_finalize()
+ """ + + def __init__(self, state): + """ + By convention, all Stat subclasses can be initialized by passing + state=; and then they will initialize by calling load_state_dict. + """ + self.load_state_dict(resolve_state_dict(state)) + + def add(self, x, *args, **kwargs): + """ + Observes a batch of samples to be incorporated into the statistic. + Dimension 0 should be the batch dimension, and dimension 1 should + be the feature dimension of the pytorch tensor x. + """ + pass + + def load_state_dict(self, d): + """ + Loads this Stat from a dictionary of numpy arrays as saved + by state_dict. + """ + pass + + def state_dict(self): + """ + Saves this Stat as a dictionary of numpy arrays that can be + stored in an npz or reloaded later using load_state_dict. + """ + return {} + + def save(self, filename): + """ + Saves this stat as an npz file containing the state_dict. + """ + save_cached_state(filename, self, {}) + + def load(self, filename): + """ + Loads this stat from an npz file containing a saved state_dict. + """ + self.load_state_dict(load_cached_state(filename, {}, quiet=True, throw=True)) + + def to_(self, device): + """ + Moves this Stat to the given device. + """ + pass + + def cpu_(self): + """ + Moves this Stat to the cpu device. + """ + self.to_("cpu") + + def cuda_(self): + """ + Moves this Stat to the default cuda device. + """ + self.to_("cuda") + + def _normalize_add_shape(self, x, attr="data_shape"): + """ + Flattens input data to 2d. + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + if len(x.shape) < 1: + x = x.view(-1) + data_shape = getattr(self, attr, None) + if data_shape is None: + data_shape = x.shape[1:] + setattr(self, attr, data_shape) + else: + assert x.shape[1:] == data_shape + return x.view(x.shape[0], int(numpy.prod(data_shape))) + + def _restore_result_shape(self, x, attr="data_shape"): + """ + Restores output data to input data shape. 
class Variance(Stat):
    """
    Running computation of mean and variance. Use this when you just need
    basic stats without covariance.

    Batches are merged with the pairwise update of Chan, Golub and
    LeVeque: for an accumulated set A and a new batch B,

        M2 = M2_A + M2_B + (mean_B - mean_A)^2 * n_A * n_B / (n_A + n_B)

    where M2 is the sum of squared deviations from the mean.
    """

    def __init__(self, state=None):
        if state is not None:
            return super().__init__(state)
        self.count = 0
        self.batchcount = 0
        self._mean = None
        self.v_cmom2 = None
        self.data_shape = None

    def add(self, a):
        """
        Incorporates a batch `a` into the running mean and variance.
        Empty batches are ignored.
        """
        a = self._normalize_add_shape(a)
        if len(a) == 0:
            return
        batch_count = a.shape[0]
        batch_mean = a.sum(0) / batch_count
        centered = a - batch_mean
        self.batchcount += 1
        # Initial batch.
        if self._mean is None:
            self.count = batch_count
            self._mean = batch_mean
            self.v_cmom2 = centered.pow(2).sum(0)
            return
        # Update a batch using Chan-style update for numerical stability.
        old_count = self.count
        self.count += batch_count
        new_frac = float(batch_count) / self.count
        # Unscaled difference between the batch mean and the old mean.
        delta = batch_mean - self._mean
        # Within-batch sum of squared deviations.
        self.v_cmom2.add_(centered.pow(2).sum(0))
        # Between-set term: delta^2 * n_A * n_B / n.  This must use the
        # unscaled delta; new_frac * old_count equals n_B * n_A / n.
        self.v_cmom2.add_(delta.pow(2).mul_(new_frac * old_count))
        # Shift the mean toward the batch mean by the batch's weight.
        self._mean.add_(delta.mul_(new_frac))

    def size(self):
        return self.count

    def mean(self):
        return self._restore_result_shape(self._mean)

    def variance(self, unbiased=True):
        """
        Returns the variance, dividing by count - 1 when unbiased is
        true and by count otherwise.
        """
        return self._restore_result_shape(
            self.v_cmom2 / (self.count - (1 if unbiased else 0))
        )

    def stdev(self, unbiased=True):
        return self.variance(unbiased=unbiased).sqrt()

    def to_(self, device):
        if self._mean is not None:
            self._mean = self._mean.to(device)
        if self.v_cmom2 is not None:
            self.v_cmom2 = self.v_cmom2.to(device)

    def load_state_dict(self, state):
        self.count = state["count"]
        self.batchcount = state["batchcount"]
        self._mean = torch.from_numpy(state["mean"])
        self.v_cmom2 = torch.from_numpy(state["cmom2"])
        self.data_shape = (
            None if state["data_shape"] is None else tuple(state["data_shape"])
        )

    def state_dict(self):
        return dict(
            constructor=self.__module__ + "." + self.__class__.__name__ + "()",
            count=self.count,
            data_shape=self.data_shape and tuple(self.data_shape),
            batchcount=self.batchcount,
            mean=self._mean.cpu().numpy(),
            cmom2=self.v_cmom2.cpu().numpy(),
        )
+ delta = a - self._mean + self._mean.add_(delta.sum(0) / self.count) + delta2 = a - self._mean + # Update the variance using the batch deviation + self.cmom2.addmm_(mat1=delta.t(), mat2=delta2) + + def to_(self, device): + if self._mean is not None: + self._mean = self._mean.to(device) + if self.cmom2 is not None: + self.cmom2 = self.cmom2.to(device) + + def mean(self): + return self._restore_result_shape(self._mean) + + def covariance(self, unbiased=True): + return self._restore_result_shape( + self.cmom2 / (self.count - (1 if unbiased else 0)) + ) + + def correlation(self, unbiased=True): + cov = self.cmom2 / (self.count - (1 if unbiased else 0)) + rstdev = cov.diag().sqrt().reciprocal() + return self._restore_result_shape(rstdev[:, None] * cov * rstdev[None, :]) + + def variance(self, unbiased=True): + return self._restore_result_shape( + self.cmom2.diag() / (self.count - (1 if unbiased else 0)) + ) + + def stdev(self, unbiased=True): + return self.variance(unbiased=unbiased).sqrt() + + def state_dict(self): + return dict( + constructor=self.__module__ + "." + self.__class__.__name__ + "()", + count=self.count, + data_shape=self.data_shape and tuple(self.data_shape), + mean=self._mean.cpu().numpy(), + cmom2=self.cmom2.cpu().numpy(), + ) + + def load_state_dict(self, state): + self.count = state["count"] + self._mean = torch.from_numpy(state["mean"]) + self.cmom2 = torch.from_numpy(state["cmom2"]) + self.data_shape = ( + None if state["data_shape"] is None else tuple(state["data_shape"]) + ) + + +class SecondMoment(Stat): + """ + Running computation. Use this when the entire non-centered 2nd-moment + 'covariance-like' matrix is needed, and when the whole matrix fits + in the GPU. 
class Bincount(Stat):
    """
    Running bincount.  The counted array should be an integer type with
    non-negative integers.
    """

    def __init__(self, state=None):
        if state is not None:
            return super().__init__(state)
        self.count = 0
        self._bincount = None

    def add(self, a, size=None):
        """
        Adds the bincount of the flattened tensor `a` to the running
        totals.  The sample count grows by `size` if given, otherwise
        by the number of elements in `a`.
        """
        a = a.view(-1)
        bincount = a.bincount()
        if self._bincount is None:
            self._bincount = bincount
        elif len(self._bincount) < len(bincount):
            # The new batch has more bins: accumulate into it and keep it.
            bincount[: len(self._bincount)] += self._bincount
            self._bincount = bincount
        else:
            self._bincount[: len(bincount)] += bincount
        if size is None:
            self.count += len(a)
        else:
            self.count += size

    def to_(self, device):
        # Nothing to move until the first batch has been added.
        if self._bincount is not None:
            self._bincount = self._bincount.to(device)

    def size(self):
        return self.count

    def bincount(self):
        return self._bincount

    def state_dict(self):
        return dict(
            constructor=self.__module__ + "." + self.__class__.__name__ + "()",
            count=self.count,
            bincount=self._bincount.cpu().numpy(),
        )

    def load_state_dict(self, dic):
        self.count = int(dic["count"])
        self._bincount = torch.from_numpy(dic["bincount"])
+ delta = [(d - bm) for d, bm in zip([a, b], self._mean)] + for m, d in zip(self._mean, delta): + m.add_(d.sum(0) / self.count) + delta2 = [(d - bm) for d, bm in zip([a, b], self._mean)] + # Update the cross-covariance using the batch deviation + self.cmom2.addmm_(mat1=delta[0].t(), mat2=delta2[1]) + # Update the variance using the batch deviation + for vc2, d, d2 in zip(self.v_cmom2, delta, delta2): + vc2.add_((d * d2).sum(0)) + + def mean(self): + return self._mean + + def variance(self, unbiased=True): + return [vc2 / (self.count - (1 if unbiased else 0)) for vc2 in self.v_cmom2] + + def stdev(self, unbiased=True): + return [v.sqrt() for v in self.variance(unbiased=unbiased)] + + def covariance(self, unbiased=True): + return self.cmom2 / (self.count - (1 if unbiased else 0)) + + def correlation(self): + covariance = self.covariance(unbiased=False) + rstdev = [s.reciprocal() for s in self.stdev(unbiased=False)] + cor = rstdev[0][:, None] * covariance * rstdev[1][None, :] + # Remove NaNs + cor[torch.isnan(cor)] = 0 + return cor + + def to_(self, device): + self._mean = [m.to(device) for m in self._mean] + self.v_cmom2 = [vcs.to(device) for vcs in self.v_cmom2] + self.cmom2 = self.cmom2.to(device) + + def state_dict(self): + return dict( + constructor=self.__module__ + "." + self.__class__.__name__ + "()", + count=self.count, + mean_a=self._mean[0].cpu().numpy(), + mean_b=self._mean[1].cpu().numpy(), + cmom2_a=self.v_cmom2[0].cpu().numpy(), + cmom2_b=self.v_cmom2[1].cpu().numpy(), + cmom2=self.cmom2.cpu().numpy(), + ) + + def load_state_dict(self, state): + self.count = int(state["count"]) + self._mean = [torch.from_numpy(state[f"mean_{k}"]) for k in "ab"] + self.v_cmom2 = [torch.from_numpy(state[f"cmom2_{k}"]) for k in "ab"] + self.cmom2 = torch.from_numpy(state["cmom2"]) + + +def _float_from_bool(a): + """ + Since pytorch only supports matrix multiplication on float, + IoU computations are done using floating point types. 
+ + This function binarizes the input (positive to True and + nonpositive to False), and converts from bool to float. + If the data is already a floating-point type, it leaves + it keeps the same type; otherwise it uses float. + """ + if a.dtype == torch.bool: + return a.float() + if a.dtype.is_floating_point: + return a.sign().clamp_(0) + return (a > 0).float() + + +class IoU(Stat): + """ + Running computation of intersections and unions of all features. + """ + + def __init__(self, state=None): + if state is not None: + return super().__init__(state) + self.count = 0 + self._intersection = None + + def add(self, a): + assert len(a.shape) == 2 + a = _float_from_bool(a) + if self._intersection is None: + self._intersection = torch.mm(a.t(), a) + else: + self._intersection.addmm_(a.t(), a) + self.count += len(a) + + def size(self): + return self.count + + def intersection(self): + return self._intersection + + def union(self): + total = self._intersection.diagonal(0) + return total[:, None] + total[None, :] - self._intersection + + def iou(self): + return self.intersection() / (self.union() + 1e-20) + + def to_(self, _device): + self._intersection = self._intersection.to(_device) + + def state_dict(self): + return dict( + constructor=self.__module__ + "." + self.__class__.__name__ + "()", + count=self.count, + intersection=self._intersection.cpu().numpy(), + ) + + def load_state_dict(self, state): + self.count = int(state["count"]) + self._intersection = torch.tensor(state["intersection"]) + + +class CrossIoU(Stat): + """ + Running computation of intersections and unions of two binary vectors. + """ + + def __init__(self, state=None): + if state is not None: + return super().__init__(state) + self.count = 0 + self._intersection = None + self.total_a = None + self.total_b = None + + def add(self, a, b): + assert len(a.shape) == 2 and len(b.shape) == 2 + assert len(a) == len(b), f"{len(a)} vs {len(b)}" + a = _float_from_bool(a) # CUDA only supports mm on float... 
+ b = _float_from_bool(b) # otherwise we would use integers. + intersection = torch.mm(a.t(), b) + asum = a.sum(0) + bsum = b.sum(0) + if self._intersection is None: + self._intersection = intersection + self.total_a = asum + self.total_b = bsum + else: + self._intersection += intersection + self.total_a += asum + self.total_b += bsum + self.count += len(a) + + def size(self): + return self.count + + def intersection(self): + return self._intersection + + def union(self): + return self.total_a[:, None] + self.total_b[None, :] - self._intersection + + def iou(self): + return self.intersection() / (self.union() + 1e-20) + + def to_(self, _device): + self.total_a = self.total_a.to(_device) + self.total_b = self.total_b.to(_device) + self._intersection = self._intersection.to(_device) + + def state_dict(self): + return dict( + constructor=self.__module__ + "." + self.__class__.__name__ + "()", + count=self.count, + total_a=self.total_a.cpu().numpy(), + total_b=self.total_b.cpu().numpy(), + intersection=self._intersection.cpu().numpy(), + ) + + def load_state_dict(self, state): + self.count = int(state["count"]) + self.total_a = torch.tensor(state["total_a"]) + self.total_b = torch.tensor(state["total_b"]) + self._intersection = torch.tensor(state["intersection"]) + + +class Quantile(Stat): + """ + Streaming randomized quantile computation for torch. + + Add any amount of data repeatedly via add(data). At any time, + quantile estimates be read out using quantile(q). + + Implemented as a sorted sample that retains at least r samples + (by default r = 3072); the number of retained samples will grow to + a finite ceiling as the data is accumulated. Accuracy scales according + to r: the default is to set resolution to be accurate to better than about + 0.1%, while limiting storage to about 50,000 samples. + + Good for computing quantiles of huge data without using much memory. + Works well on arbitrary data with probability near 1. 
+ + Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty + from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf + """ + + def __init__(self, r=3 * 1024, buffersize=None, seed=None, state=None): + if state is not None: + return super().__init__(state) + self.depth = None + self.dtype = None + self.device = None + resolution = r * 2 # sample array is at least half full before discard + self.resolution = resolution + # Default buffersize: 128 samples (and smaller than resolution). + if buffersize is None: + buffersize = min(128, (resolution + 7) // 8) + self.buffersize = buffersize + self.samplerate = 1.0 + self.data = None + self.firstfree = [0] + self.randbits = torch.ByteTensor(resolution) + self.currentbit = len(self.randbits) - 1 + self.extremes = None + self.count = 0 + self.batchcount = 0 + + def size(self): + return self.count + + def _lazy_init(self, incoming): + self.depth = incoming.shape[1] + self.dtype = incoming.dtype + self.device = incoming.device + self.data = [ + torch.zeros( + self.depth, self.resolution, dtype=self.dtype, device=self.device + ) + ] + self.extremes = torch.zeros(self.depth, 2, dtype=self.dtype, device=self.device) + self.extremes[:, 0] = float("inf") + self.extremes[:, -1] = -float("inf") + + def to_(self, device): + """Switches internal storage to specified device.""" + if device != self.device: + old_data = self.data + old_extremes = self.extremes + self.data = [d.to(device) for d in self.data] + self.extremes = self.extremes.to(device) + self.device = self.extremes.device + del old_data + del old_extremes + + def add(self, incoming): + if self.depth is None: + self._lazy_init(incoming) + assert len(incoming.shape) == 2 + assert incoming.shape[1] == self.depth, (incoming.shape[1], self.depth) + self.count += incoming.shape[0] + self.batchcount += 1 + # Convert to a flat torch array. 
+ if self.samplerate >= 1.0: + self._add_every(incoming) + return + # If we are sampling, then subsample a large chunk at a time. + self._scan_extremes(incoming) + chunksize = int(math.ceil(self.buffersize / self.samplerate)) + for index in range(0, len(incoming), chunksize): + batch = incoming[index : index + chunksize] + sample = sample_portion(batch, self.samplerate) + if len(sample): + self._add_every(sample) + + def _add_every(self, incoming): + supplied = len(incoming) + index = 0 + while index < supplied: + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + if available == 0: + if not self._shift(): + # If we shifted by subsampling, then subsample. + incoming = incoming[index:] + if self.samplerate >= 0.5: + # First time sampling - the data source is very large. + self._scan_extremes(incoming) + incoming = sample_portion(incoming, self.samplerate) + index = 0 + supplied = len(incoming) + ff = self.firstfree[0] + available = self.data[0].shape[1] - ff + copycount = min(available, supplied - index) + self.data[0][:, ff : ff + copycount] = torch.t( + incoming[index : index + copycount, :] + ) + self.firstfree[0] += copycount + index += copycount + + def _shift(self): + index = 0 + # If remaining space at the current layer is less than half prev + # buffer size (rounding up), then we need to shift it up to ensure + # enough space for future shifting. 
+ while self.data[index].shape[1] - self.firstfree[index] < ( + -(-self.data[index - 1].shape[1] // 2) if index else 1 + ): + if index + 1 >= len(self.data): + return self._expand() + data = self.data[index][:, 0 : self.firstfree[index]] + data = data.sort()[0] + if index == 0 and self.samplerate >= 1.0: + self._update_extremes(data[:, 0], data[:, -1]) + offset = self._randbit() + position = self.firstfree[index + 1] + subset = data[:, offset::2] + self.data[index + 1][:, position : position + subset.shape[1]] = subset + self.firstfree[index] = 0 + self.firstfree[index + 1] += subset.shape[1] + index += 1 + return True + + def _scan_extremes(self, incoming): + # When sampling, we need to scan every item still to get extremes + self._update_extremes( + torch.min(incoming, dim=0)[0], torch.max(incoming, dim=0)[0] + ) + + def _update_extremes(self, minr, maxr): + self.extremes[:, 0] = torch.min( + torch.stack([self.extremes[:, 0], minr]), dim=0 + )[0] + self.extremes[:, -1] = torch.max( + torch.stack([self.extremes[:, -1], maxr]), dim=0 + )[0] + + def _randbit(self): + self.currentbit += 1 + if self.currentbit >= len(self.randbits): + self.randbits.random_(to=2) + self.currentbit = 0 + return self.randbits[self.currentbit] + + def state_dict(self): + state = dict( + constructor=self.__module__ + "." 
+ self.__class__.__name__ + "()", + resolution=self.resolution, + depth=self.depth, + buffersize=self.buffersize, + samplerate=self.samplerate, + sizes=numpy.array([d.shape[1] for d in self.data]), + extremes=self.extremes.cpu().detach().numpy(), + size=self.count, + batchcount=self.batchcount, + ) + for i, (d, f) in enumerate(zip(self.data, self.firstfree)): + state[f"data.{i}"] = d.cpu().detach().numpy()[:, :f].T + return state + + def load_state_dict(self, state): + self.resolution = int(state["resolution"]) + self.randbits = torch.ByteTensor(self.resolution) + self.currentbit = len(self.randbits) - 1 + self.depth = int(state["depth"]) + self.buffersize = int(state["buffersize"]) + self.samplerate = float(state["samplerate"]) + firstfree = [] + buffers = [] + for i, s in enumerate(state["sizes"]): + d = state[f"data.{i}"] + firstfree.append(d.shape[0]) + buf = numpy.zeros((d.shape[1], s), dtype=d.dtype) + buf[:, : d.shape[0]] = d.T + buffers.append(torch.from_numpy(buf)) + self.firstfree = firstfree + self.data = buffers + self.extremes = torch.from_numpy((state["extremes"])) + self.count = int(state["size"]) + self.batchcount = int(state.get("batchcount", 0)) + self.dtype = self.extremes.dtype + self.device = self.extremes.device + + def min(self): + return self.minmax()[0] + + def max(self): + return self.minmax()[-1] + + def minmax(self): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:, : self.firstfree[0]].t()) + return self.extremes.clone() + + def median(self): + return self.quantiles(0.5) + + def mean(self): + return self.integrate(lambda x: x) / self.count + + def variance(self, unbiased=True): + mean = self.mean()[:, None] + return self.integrate(lambda x: (x - mean).pow(2)) / ( + self.count - (1 if unbiased else 0) + ) + + def stdev(self, unbiased=True): + return self.variance(unbiased=unbiased).sqrt() + + def _expand(self): + cap = self._next_capacity() + if cap > 0: + # First, make a new layer of the proper capacity. 
+ self.data.insert( + 0, torch.zeros(self.depth, cap, dtype=self.dtype, device=self.device) + ) + self.firstfree.insert(0, 0) + else: + # Unless we're so big we are just subsampling. + assert self.firstfree[0] == 0 + self.samplerate *= 0.5 + for index in range(1, len(self.data)): + # Scan for existing data that needs to be moved down a level. + amount = self.firstfree[index] + if amount == 0: + continue + position = self.firstfree[index - 1] + # Move data down if it would leave enough empty space there + # This is the key invariant: enough empty space to fit half + # of the previous level's buffer size (rounding up) + if self.data[index - 1].shape[1] - (amount + position) >= ( + -(-self.data[index - 2].shape[1] // 2) if (index - 1) else 1 + ): + self.data[index - 1][:, position : position + amount] = self.data[ + index + ][:, :amount] + self.firstfree[index - 1] += amount + self.firstfree[index] = 0 + else: + # Scrunch the data if it would not. + data = self.data[index][:, :amount] + data = data.sort()[0] + if index == 1: + self._update_extremes(data[:, 0], data[:, -1]) + offset = self._randbit() + scrunched = data[:, offset::2] + self.data[index][:, : scrunched.shape[1]] = scrunched + self.firstfree[index] = scrunched.shape[1] + return cap > 0 + + def _next_capacity(self): + cap = int(math.ceil(self.resolution * (0.67 ** len(self.data)))) + if cap < 2: + return 0 + # Round up to the nearest multiple of 8 for better GPU alignment. 
+ cap = -8 * (-cap // 8) + return max(self.buffersize, cap) + + def _weighted_summary(self, sort=True): + if self.firstfree[0]: + self._scan_extremes(self.data[0][:, : self.firstfree[0]].t()) + size = sum(self.firstfree) + weights = torch.FloatTensor(size) # Floating point + summary = torch.zeros(self.depth, size, dtype=self.dtype, device=self.device) + index = 0 + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + summary[:, index : index + ff] = self.data[level][:, :ff] + weights[index : index + ff] = 2.0**level + index += ff + assert index == summary.shape[1] + if sort: + summary, order = torch.sort(summary, dim=-1) + weights = weights[order.view(-1).cpu()].view(order.shape) + summary = torch.cat( + [self.extremes[:, :1], summary, self.extremes[:, 1:]], dim=-1 + ) + weights = torch.cat( + [ + torch.zeros(weights.shape[0], 1), + weights, + torch.zeros(weights.shape[0], 1), + ], + dim=-1, + ) + return (summary, weights) + + def quantiles(self, quantiles): + if not hasattr(quantiles, "cpu"): + quantiles = torch.tensor(quantiles) + qshape = quantiles.shape + if self.count == 0: + return torch.full((self.depth,) + qshape, torch.nan) + summary, weights = self._weighted_summary() + cumweights = torch.cumsum(weights, dim=-1) - weights / 2 + cumweights /= torch.sum(weights, dim=-1, keepdim=True) + result = torch.zeros( + self.depth, quantiles.numel(), dtype=self.dtype, device=self.device + ) + # numpy is needed for interpolation + nq = quantiles.view(-1).cpu().detach().numpy() + ncw = cumweights.cpu().detach().numpy() + nsm = summary.cpu().detach().numpy() + for d in range(self.depth): + result[d] = torch.tensor( + numpy.interp(nq, ncw[d], nsm[d]), dtype=self.dtype, device=self.device + ) + return result.view((self.depth,) + qshape) + + def integrate(self, fun): + result = [] + for level, ff in enumerate(self.firstfree): + if ff == 0: + continue + result.append( + torch.sum(fun(self.data[level][:, :ff]) * (2.0**level), dim=-1) + ) + if len(result) == 
def sample_portion(vec, p=0.5):
    """
    Subsamples a fraction (given by p) of the given batch.  Used by
    Quantile when the data gets very very large.

    Each row of `vec` (dimension 0) is kept independently with
    probability p; the kept rows are returned in their original order.
    """
    # The mask is drawn as uint8 exactly as before, then converted to
    # bool: indexing with a uint8 mask is deprecated in PyTorch and emits
    # a warning, which is fatal when warnings are treated as errors.
    bits = torch.bernoulli(
        torch.zeros(vec.shape[0], dtype=torch.uint8, device=vec.device), p
    ).bool()
    return vec[bits]
+ # The data is puffed back out to arbitrary tensor shapes on ouput. + self.data_shape = None + self.top_data = None + self.top_index = None + self.next = 0 + self.linear_index = 0 + self.perm = None + self.largest = largest + + def add(self, data, index=None): + """ + Adds a batch of data to be considered for the running top k. + The zeroth dimension enumerates the observations. All other + dimensions enumerate different features. + """ + if self.top_data is None: + # Allocation: allocate a buffer of size 5*k, at least 10, for each. + self.data_shape = data.shape[1:] + feature_size = int(numpy.prod(self.data_shape)) + self.top_data = torch.zeros( + feature_size, max(10, self.k * 5), out=data.new() + ) + self.top_index = self.top_data.clone().long() + self.linear_index = ( + 0 + if len(data.shape) == 1 + else torch.arange(feature_size, out=self.top_index.new()).mul_( + self.top_data.shape[-1] + )[:, None] + ) + size = data.shape[0] + sk = min(size, self.k) + if self.top_data.shape[-1] < self.next + sk: + # Compression: if full, keep topk only. + self.top_data[:, : self.k], self.top_index[:, : self.k] = self.topk( + sorted=False, flat=True + ) + self.next = self.k + # Pick: copy the top sk of the next batch into the buffer. + # Currently strided topk is slow. So we clone after transpose. + # TODO: remove the clone() if it becomes faster. + cdata = data.reshape(size, numpy.prod(data.shape[1:])).t().clone() + td, ti = cdata.topk(sk, sorted=False, largest=self.largest) + self.top_data[:, self.next : self.next + sk] = td + if index is not None: + ti = index[ti] + else: + ti = ti + self.count + self.top_index[:, self.next : self.next + sk] = ti + self.next += sk + self.count += size + + def size(self): + return self.count + + def topk(self, sorted=True, flat=False): + """ + Returns top k data items and indexes in each dimension, + with channels in the first dimension and k in the last dimension. 
+ """ + k = min(self.k, self.next) + # bti are top indexes relative to buffer array. + td, bti = self.top_data[:, : self.next].topk( + k, sorted=sorted, largest=self.largest + ) + # we want to report top indexes globally, which is ti. + ti = self.top_index.view(-1)[(bti + self.linear_index).view(-1)].view( + *bti.shape + ) + if flat: + return td, ti + else: + return ( + td.view(*(self.data_shape + (-1,))), + ti.view(*(self.data_shape + (-1,))), + ) + + def to_(self, device): + if self.top_data is not None: + self.top_data = self.top_data.to(device) + if self.top_index is not None: + self.top_index = self.top_index.to(device) + if isinstance(self.linear_index, torch.Tensor): + self.linear_index = self.linear_index.to(device) + + def state_dict(self): + return dict( + constructor=self.__module__ + "." + self.__class__.__name__ + "()", + k=self.k, + count=self.count, + largest=self.largest, + data_shape=self.data_shape and tuple(self.data_shape), + top_data=self.top_data.cpu().detach().numpy(), + top_index=self.top_index.cpu().detach().numpy(), + next=self.next, + linear_index=( + self.linear_index.cpu().numpy() + if isinstance(self.linear_index, torch.Tensor) + else self.linear_index + ), + perm=self.perm, + ) + + def load_state_dict(self, state): + self.k = int(state["k"]) + self.count = int(state["count"]) + self.largest = bool(state.get("largest", True)) + self.data_shape = ( + None if state["data_shape"] is None else tuple(state["data_shape"]) + ) + self.top_data = torch.from_numpy(state["top_data"]) + self.top_index = torch.from_numpy(state["top_index"]) + self.next = int(state["next"]) + self.linear_index = ( + torch.from_numpy(state["linear_index"]) + if len(state["linear_index"].shape) > 0 + else int(state["linear_index"]) + ) + + +class History(Stat): + """ + Accumulates the concatenation of all the added data. 
class CombinedStat(Stat):
    """
    A Stat that bundles together multiple Stat objects.
    Convenient for loading and saving a state_dict made up of a
    hierarchy of stats, and for use with the tally() function.
    Example:

        cs = CombinedStat(m=Mean(), q=Quantile())
        for [b] in tally(cs, MyDataSet(), cache=fn, batch_size=100):
            cs.add(b)
        print(cs.m.mean())
        print(cs.q.median())

    Each keyword argument names one child stat.  The child is reachable
    as an attribute of the same name, and its state_dict entries are
    stored under "<name>." prefixed keys.
    """

    def __init__(self, state=None, **kwargs):
        # The children must be registered before any state is loaded,
        # because load_state_dict walks self._objs.
        self._objs = kwargs
        if state is not None:
            # NOTE(review): presumably Stat.__init__ resolves `state` and
            # calls load_state_dict -- confirm against Stat.
            return super().__init__(state)

    def __getattr__(self, k):
        """
        Exposes each child stat as an attribute named by its keyword.
        """
        # __getattr__ only runs when normal lookup has failed.  Fetch
        # _objs through the instance dict: evaluating self._objs on an
        # instance that has no _objs yet (e.g. while being copied or
        # unpickled) would re-enter __getattr__ and recurse without end.
        objs = self.__dict__.get("_objs")
        if objs is not None and k in objs:
            return objs[k]
        raise AttributeError(
            "%r object has no attribute %r" % (type(self).__name__, k)
        )

    def add(self, d, *args, **kwargs):
        """Forwards the same batch (and extra arguments) to every child."""
        for obj in self._objs.values():
            obj.add(d, *args, **kwargs)

    def load_state_dict(self, state):
        """Restores each child from the keys carrying its name prefix."""
        for prefix, obj in self._objs.items():
            obj.load_state_dict(pull_key_prefix(prefix, state))

    def state_dict(self):
        """Returns one flat dict with every child's state, prefixed by name."""
        result = {}
        for prefix, obj in self._objs.items():
            result.update(push_key_prefix(prefix, obj.state_dict()))
        return result

    def to_(self, device):
        """Switches internal storage to specified device."""
        for v in self._objs.values():
            v.to_(device)
global_load_cache_enabled = True


def load_cached_state(cachefile, args, quiet=False, throw=False):
    """
    Loads a cached state if it exists and was saved with the same args.

    `cachefile` is either an npz filename or a dict acting as an
    in-memory cache.  Every (name, value) pair in `args` must be present
    in the cached data with an equal value; otherwise the cache is
    treated as stale.

    Returns the cached mapping on a hit.  Returns None when caching is
    globally disabled, when `cachefile` is None, or when an arg is
    missing or different.  Also returns None when the file is missing
    or unreadable (FileNotFoundError / ValueError), unless `throw` is
    set, in which case that error is re-raised.  Unless `quiet` is set,
    a one-line message reports a hit or the first arg that changed.
    """
    if not global_load_cache_enabled or cachefile is None:
        return None
    try:
        if isinstance(cachefile, dict):
            dat = cachefile
            cachefile = "state"  # for printed messages
        else:
            dat = unbox_numpy_null(numpy.load(cachefile))
        for a, v in args.items():
            if a not in dat or dat[a] != v:
                if not quiet:
                    # Only look the old value up when it exists; a missing
                    # key is an ordinary cache miss, not a KeyError.
                    old = dat[a] if a in dat else "(missing)"
                    print("%s %s changed from %s to %s" % (cachefile, a, old, v))
                return None
    except (FileNotFoundError, ValueError) as e:
        if throw:
            raise e
        return None
    else:
        if not quiet:
            print("Loading cached %s" % cachefile)
        return dat


def save_cached_state(cachefile, obj, args):
    """
    Saves the state_dict of the given object in a dict or npz file.

    The entries of `args` are stored alongside the state so that
    load_cached_state can later check them; an arg whose name already
    appears in the state must carry the same value.  A dict `cachefile`
    is cleared and refilled in place.  A filename is written with
    numpy.savez after creating its parent directory if needed.  Does
    nothing when `cachefile` is None.
    """
    if cachefile is None:
        return
    dat = obj.state_dict()
    for a, v in args.items():
        if a in dat:
            assert dat[a] == v
        dat[a] = v
    if isinstance(cachefile, dict):
        cachefile.clear()
        cachefile.update(dat)
    else:
        # A bare filename has an empty dirname, and os.makedirs("")
        # raises FileNotFoundError; the working directory already exists.
        dirname = os.path.dirname(cachefile)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        numpy.savez(cachefile, **box_numpy_null(dat))
def make_loader(
    dataset, sample_size=None, batch_size=1, sampler=None, random_sample=None, **kwargs
):
    """
    Utility for creating a dataloader on fixed sample subset.

    Arguments:
        dataset: a dataset, a tensor (wrapped in a TensorDataset), or a
            zero-argument factory that returns either of those.
        sample_size: if given, only this many items are loaded (capped,
            with a printed warning, at the dataset size).  Cannot be
            combined with `sampler`.
        batch_size: passed through to the DataLoader.
        sampler: an explicit sampler, passed through to the DataLoader.
        random_sample: if None, the first `sample_size` items are used
            in order; otherwise it seeds a fixed random subset.
        **kwargs: any further DataLoader options.
    """
    if callable(dataset):
        # To support deferred dataset loading, support passing a factory
        # that creates the dataset when called.
        dataset = dataset()
    if isinstance(dataset, torch.Tensor):
        # The dataset can be a simple tensor.
        dataset = torch.utils.data.TensorDataset(dataset)
    if sample_size is not None:
        assert sampler is None, "sampler cannot be specified with sample_size"
        if sample_size > len(dataset):
            print(
                "Warning: sample size %d > dataset size %d"
                % (sample_size, len(dataset))
            )
            sample_size = len(dataset)
        if random_sample is None:
            sampler = FixedSubsetSampler(list(range(sample_size)))
        else:
            sampler = FixedRandomSubsetSampler(
                dataset, seed=random_sample, end=sample_size
            )
    return torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, **kwargs
    )
+ assert numpy.isnan(null_numpy_value) + assert is_null_numpy_value(null_numpy_value) + assert not is_null_numpy_value(numpy.nan) + + # Test Covariance + goal = torch.tensor(numpy.random.RandomState(1).standard_normal(10 * 10)).view( + 10, 10 + ) + data = ( + torch.tensor(numpy.random.RandomState(2).standard_normal(args.test_size * 10)) + .view(args.test_size, 10) + .mm(goal) + ) + data += torch.randn(1, 10) * 999 + dcov = data.t().cov() + dcorr = data.t().corrcoef() + rcov = Covariance() + rcov.add(data) # All one batch + assert (rcov.covariance() - dcov).abs().max() < 1e-16 + cs = CombinedStat(cov=Covariance(), xcov=CrossCovariance()) + ds = torch.utils.data.TensorDataset(data) + for [a] in tally(cs, ds, batch_size=9876): + cs.cov.add(a) + cs.xcov.add(a[:, :3], a[:, 3:]) + assert (data.mean(0) - cs.cov.mean()).abs().max() < 1e-12 + assert (dcov - cs.cov.covariance()).abs().max() < 2e-12 + assert (dcov[:3, 3:] - cs.xcov.covariance()).abs().max() < 1e-12 + assert (dcov.diagonal() - torch.cat(cs.xcov.variance())).abs().max() < 1e-12 + assert (dcorr - cs.cov.correlation()).abs().max() < 2e-12 + + # Test CrossCovariance and CrossIoU + fn = f"{testdir}/cross_cache.npz" + ds = torch.utils.data.TensorDataset( + ( + torch.arange(args.test_size)[:, None] % torch.arange(1, 6)[None, :] == 0 + ).double(), + ( + torch.arange(args.test_size)[:, None] % torch.arange(5, 8)[None, :] == 0 + ).double(), + ) + c = CombinedStat(c=CrossCovariance(), iou=CrossIoU()) + riou = IoU() + count = 0 + for [a, b] in tally(c, ds, cache=fn, batch_size=100): + count += 1 + c.add(a, b) + riou.add(torch.cat([a, b], dim=1)) + assert count == -(-args.test_size // 100) + cor = c.c.correlation() + iou = c.iou.iou() + assert cor.shape == iou.shape == (5, 3) + assert iou[4, 0] == 1.0 + assert abs(iou[0, 2] + (-args.test_size // 7 / float(args.test_size))) < 1e-6 + assert abs(cor[4, 0] - 1.0) < 1e-2 + assert abs(cor[0, 2] - 0.0) < 1e-6 + assert all((riou.iou()[:5, -3:] == iou).view(-1)) + assert 
all(riou.iou().diagonal(0) == 1) + c = CombinedStat(c=CrossCovariance(), iou=CrossIoU()) + count = 0 + for [a, b] in tally(c, ds, cache=fn, batch_size=10): + count += 1 + c.add(a, b) + assert count == 0 + assert all((c.c.correlation() == cor).view(-1)) + assert all((c.iou.iou() == iou).view(-1)) + + # Test Concatantaion, Mean, Bincount and tally. + fn = f"{testdir}/series_cache.npz" + count = 0 + ds = torch.utils.data.TensorDataset(torch.arange(args.test_size)) + c = CombinedStat(s=History(), m=Mean(), b=Bincount()) + for [b] in tally(c, ds, cache=fn, batch_size=batch_size): + count += 1 + c.add(b) + assert count == -(-args.test_size // batch_size) + assert len(c.s.history()) == args.test_size + assert c.s.history()[-1] == args.test_size - 1 + assert all(c.s.history() == ds.tensors[0]) + assert all(c.b.bincount() == torch.ones(args.test_size)) + assert c.m.mean() == float(args.test_size - 1) / 2.0 + c2 = CombinedStat(s=History(), m=Mean(), b=Bincount()) + batches = tally(c2, ds, cache=fn) + assert len(c2.s.history()) == args.test_size + assert all(c2.s.history() == c.s.history()) + assert all(c2.b.bincount() == torch.ones(args.test_size)) + assert c2.m.mean() == c.m.mean() + count = 0 + for b in batches: + count += 1 + assert count == 0 # Shouldn't do anything when it's cached + + # An adverarial case: we keep finding more numbers in the middle + # as the stream goes on. 
+ amount = args.test_size + quantiles = 1000 + data = numpy.arange(float(amount)) + data[1::2] = data[-1::-2] + (len(data) - 1) + data /= 2 + depth = 50 + alldata = data[:, None] + (numpy.arange(depth) * amount)[None, :] + actual_sum = torch.FloatTensor(numpy.sum(alldata * alldata, axis=0)) + amt = amount // depth + for r in range(depth): + numpy.random.shuffle(alldata[r * amt : r * amt + amt, r]) + if args.mode == "cuda": + alldata = torch.cuda.FloatTensor(alldata) + device = torch.device("cuda") + else: + alldata = torch.FloatTensor(alldata) + device = None + starttime = time.time() + cs = CombinedStat( + qc=Quantile(), + m=Mean(), + v=Variance(), + c=Covariance(), + s=SecondMoment(), + t=TopK(), + i=IoU(), + ) + # Feed data in little batches + i = 0 + while i < len(alldata): + batch_size = numpy.random.randint(1000) + cs.add(alldata[i : i + batch_size]) + i += batch_size + # Test state dict + saved = cs.state_dict() + # numpy.savez(f'{testdir}/saved.npz', **box_numpy_null(saved)) + # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz')) + cs.save(f"{testdir}/saved.npz") + loaded = unbox_numpy_null(numpy.load(f"{testdir}/saved.npz")) + assert set(loaded.keys()) == set(saved.keys()) + + # Restore using state=saved in constructor. 
+ cs2 = CombinedStat( + qc=Quantile(), + m=Mean(), + v=Variance(), + c=Covariance(), + s=SecondMoment(), + t=TopK(), + i=IoU(), + state=saved, + ) + # saved = unbox_numpy_null(numpy.load(f'{testdir}/saved.npz')) + assert not cs2.qc.device.type == "cuda" + cs2.to_(device) + # alldata = alldata.cpu() + cs2.add(alldata) + actual_sum *= 2 + # print(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean()) + assert all(abs(alldata.mean(0) - cs2.m.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs2.v.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs2.c.mean()) / alldata.mean() < 1e-5) + # print(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0)) + assert all(abs(alldata.var(0) - cs2.v.variance()) / alldata.var(0) < 1e-3) + assert all(abs(alldata.var(0) - cs2.c.variance()) / alldata.var(0) < 1e-2) + # print(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0)) + assert all(abs(alldata.std(0) - cs2.v.stdev()) / alldata.std(0) < 1e-4) + # print(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0)) + assert all(abs(alldata.std(0) - cs2.c.stdev()) / alldata.std(0) < 2e-3) + moment = (alldata.t() @ alldata) / len(alldata) + # print(abs(moment - cs2.s.moment()) / moment.abs()) + assert all((abs(moment - cs2.s.moment()) / moment.abs()).view(-1) < 1e-2) + assert all(alldata.max(dim=0)[0] == cs2.t.topk()[0][:, 0]) + assert cs2.i.iou()[0, 0] == 1 + assert all((cs2.i.iou()[1:, 1:] == 1).view(-1)) + assert all(cs2.i.iou()[1:, 0] < 1) + assert all(cs2.i.iou()[1:, 0] == cs2.i.iou()[0, 1:]) + + # Restore using cs.load() method. 
+ cs = CombinedStat( + qc=Quantile(), + m=Mean(), + v=Variance(), + c=Covariance(), + s=SecondMoment(), + t=TopK(), + i=IoU(), + ) + cs.load(f"{testdir}/saved.npz") + assert not cs.qc.device.type == "cuda" + cs.to_(device) + cs.add(alldata) + # actual_sum *= 2 + # print(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean()) + assert all(abs(alldata.mean(0) - cs.m.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs.v.mean()) / alldata.mean() < 1e-5) + assert all(abs(alldata.mean(0) - cs.c.mean()) / alldata.mean() < 1e-5) + # print(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0)) + assert all(abs(alldata.var(0) - cs.v.variance()) / alldata.var(0) < 1e-3) + assert all(abs(alldata.var(0) - cs.c.variance()) / alldata.var(0) < 1e-2) + # print(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0)) + assert all(abs(alldata.std(0) - cs.v.stdev()) / alldata.std(0) < 1e-4) + # print(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0)) + assert all(abs(alldata.std(0) - cs.c.stdev()) / alldata.std(0) < 2e-3) + moment = (alldata.t() @ alldata) / len(alldata) + # print(abs(moment - cs.s.moment()) / moment.abs()) + assert all((abs(moment - cs.s.moment()) / moment.abs()).view(-1) < 1e-2) + assert all(alldata.max(dim=0)[0] == cs.t.topk()[0][:, 0]) + assert cs.i.iou()[0, 0] == 1 + assert all((cs.i.iou()[1:, 1:] == 1).view(-1)) + assert all(cs.i.iou()[1:, 0] < 1) + assert all(cs.i.iou()[1:, 0] == cs.i.iou()[0, 1:]) + + # Randomized quantile test + qc = cs.qc + ro = qc.readout(1001).cpu() + endtime = time.time() + gt = ( + torch.linspace(0, amount, quantiles + 1)[None, :] + + (torch.arange(qc.depth, dtype=torch.float) * amount)[:, None] + ) + maxreldev = torch.max(torch.abs(ro - gt) / amount) * quantiles + print("Randomized quantile test results:") + print("Maximum relative deviation among %d perentiles: %f" % (quantiles, maxreldev)) + minerr = torch.max( + torch.abs( + qc.minmax().cpu()[:, 0] - torch.arange(qc.depth, dtype=torch.float) * amount + ) + ) + maxerr = 
torch.max( + torch.abs( + (qc.minmax().cpu()[:, -1] + 1) + - (torch.arange(qc.depth, dtype=torch.float) + 1) * amount + ) + ) + print("Minmax error %f, %f" % (minerr, maxerr)) + interr = torch.max( + torch.abs(qc.integrate(lambda x: x * x).cpu() - actual_sum) / actual_sum + ) + print("Integral error: %f" % interr) + medianerr = torch.max( + torch.abs(qc.median() - alldata.median(0)[0]) / alldata.median(0)[0] + ).cpu() + print("Median error: %f" % medianerr) + meanerr = torch.max(torch.abs(qc.mean() - alldata.mean(0)) / alldata.mean(0)).cpu() + print("Mean error: %f" % meanerr) + varerr = torch.max(torch.abs(qc.variance() - alldata.var(0)) / alldata.var(0)).cpu() + print("Variance error: %f" % varerr) + counterr = ( + (qc.integrate(lambda x: torch.ones(x.shape[-1]).cpu()) - qc.size()) + / (0.0 + qc.size()) + ).item() + print("Count error: %f" % counterr) + print("Time %f" % (endtime - starttime)) + # Algorithm is randomized, so some of these will fail with low probability. + assert maxreldev < 1.0 + assert minerr == 0.0 + assert maxerr == 0.0 + assert interr < 0.01 + assert abs(counterr) < 0.001 + shutil.rmtree(testdir, ignore_errors=True) + print("OK") + + +if __name__ == "__main__": + _unit_test() diff --git a/torchkit/show.py b/torchkit/show.py new file mode 100644 index 0000000..7129d99 --- /dev/null +++ b/torchkit/show.py @@ -0,0 +1,138 @@ +# show.py +# +# An abbreviated way to output simple HTML layout of text and images +# into a python notebook. +# +# - show a PIL image to show an inline HTML <img>. +# - show an array of items to vertically stack them, centered in a block. +# - show an array of arrays to horizontally lay them out as inline blocks. +# - show an array of tuples to create a table. 
import base64
import html as html_module
import io
import sys
import types

import IPython
import PIL.Image
from IPython.display import display

# Items queued by a() that have not been displayed yet; None when empty.
g_buffer = None

# NOTE(review): the HTML markup inside this module's string literals was
# missing from the text this code was recovered from.  The tags below are
# reconstructed from the variable names (tstart/tend, rstart/rend,
# cstart/cend, blockstart/blockend) and from the module header, which says
# arrays are "centered in a block" and nested arrays are laid out "as inline
# blocks".  Confirm the exact styling against the upstream file.
_BLOCK_START = (
    '<div style="display:inline-block;text-align:center;vertical-align:top">'
)
_BLOCK_END = '</div>'
_TABLE_START = '<table>'
_TABLE_END = '</table>'
_ROW_START = '<tr>'
_ROW_END = '</tr>'
_CELL_START = '<td>'
_CELL_END = '</td>'
_TEXT_START = '<div>'
_TEXT_END = '</div>'


def blocks(obj, space=''):
    """Return an IPython HTML object laying out `obj` as nested blocks.

    The fragments produced by blocks_tags(obj) are joined with `space`.
    """
    return IPython.display.HTML(space.join(blocks_tags(obj)))


def rows(obj, space=''):
    """Return an IPython HTML object laying out `obj` as a table of rows.

    The fragments produced by rows_tags(obj) are joined with `space`.
    """
    return IPython.display.HTML(space.join(rows_tags(obj)))


def rows_tags(obj):
    """Return the list of HTML fragments for a table built from `obj`.

    `obj` is an iterable of rows and each row is an iterable of items; a
    dict is treated as its (key, value) pairs.  Every item becomes one
    cell whose content is rendered by blocks_tags.
    """
    if isinstance(obj, dict):
        obj = obj.items()
    results = [_TABLE_START]
    for row in obj:
        results.append(_ROW_START)
        for item in row:
            results.append(_CELL_START)
            results.extend(blocks_tags(item))
            results.append(_CELL_END)
        results.append(_ROW_END)
    results.append(_TABLE_END)
    return results


def blocks_tags(obj):
    """Return the list of HTML fragments that render `obj`.

    Dispatch, in order:
      * objects with `_repr_html_` contribute that HTML unescaped;
      * PIL images become an inline image tag (see pil_to_html);
      * str, int and float values are HTML-escaped text;
      * a dict is rendered as the list of its (key, value) tuples;
      * any other iterable becomes a table when its first element is a
        tuple, and a block containing each rendered element otherwise.
    Anything else yields an empty list.
    """
    results = []
    if hasattr(obj, '_repr_html_'):
        results.append(obj._repr_html_())
    elif isinstance(obj, PIL.Image.Image):
        results.append(pil_to_html(obj))
    elif isinstance(obj, (str, int, float)):
        results.append(_TEXT_START)
        results.append(html_module.escape(str(obj)))
        results.append(_TEXT_END)
    elif isinstance(obj, dict):
        results.extend(blocks_tags(list(obj.items())))
    elif hasattr(obj, '__iter__'):
        # The opening tag is only emitted once a first element is seen, so
        # an empty iterable produces no markup at all.
        needs_end = False
        table_mode = False
        for i, line in enumerate(obj):
            if i == 0:
                needs_end = True
                # The first element alone decides table versus block.
                if isinstance(line, tuple):
                    table_mode = True
                    results.append(_TABLE_START)
                else:
                    results.append(_BLOCK_START)
            if table_mode:
                results.append(_ROW_START)
                if not isinstance(line, str) and hasattr(line, '__iter__'):
                    for cell in line:
                        results.append(_CELL_START)
                        results.extend(blocks_tags(cell))
                        results.append(_CELL_END)
                else:
                    # A non-iterable (or string) row is a single cell.
                    results.append(_CELL_START)
                    results.extend(blocks_tags(line))
                    results.append(_CELL_END)
                results.append(_ROW_END)
            else:
                results.extend(blocks_tags(line))
        if needs_end:
            results.append(_TABLE_END if table_mode else _BLOCK_END)
    return results


def pil_to_b64(img, format='png'):
    """Return `img` encoded in `format` as a base64 text string."""
    buffered = io.BytesIO()
    img.save(buffered, format=format)
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def pil_to_url(img, format='png'):
    """Return a `data:image/<format>;base64,...` URL holding `img`."""
    return 'data:image/%s;base64,%s' % (format, pil_to_b64(img, format))


def pil_to_html(img, margin=1):
    """Return an inline image tag for `img` with `margin` pixels of margin."""
    mattr = ' style="margin:%dpx"' % margin
    return '<img src="%s"%s>' % (pil_to_url(img), mattr)


def a(x, cols=None):
    """Queue `x` for display; flush once `cols` items are queued.

    When `cols` is None the buffer only grows until flush() or show()
    is called.
    """
    global g_buffer
    if g_buffer is None:
        g_buffer = []
    g_buffer.append(x)
    if cols is not None and len(g_buffer) >= cols:
        flush()


def reset():
    """Discard any queued items without displaying them."""
    global g_buffer
    g_buffer = None


def flush(*args, **kwargs):
    """Display the queued items as one block and empty the queue.

    Extra arguments are forwarded to blocks().  Does nothing when the
    queue is empty.
    """
    global g_buffer
    if g_buffer is not None:
        x = g_buffer
        # Clear before displaying so the queue is empty even if display fails.
        g_buffer = None
        display(blocks(x, *args, **kwargs))


def show(x=None, *args, **kwargs):
    """Flush the queue, then display `x` unless it is None.

    Extra arguments are forwarded to blocks() for both displays.
    """
    flush(*args, **kwargs)
    if x is not None:
        display(blocks(x, *args, **kwargs))


def html(obj, space=''):
    """Return the HTML source string that blocks(obj, space) would display."""
    return blocks(obj, space)._repr_html_()


class CallableModule(types.ModuleType):
    """Module object that can itself be called as a shorthand for show().

    It copies the namespace of the module being imported, so every
    function defined above stays available as an attribute.
    """

    def __init__(self):
        # or super().__init__(__name__) for Python 3
        types.ModuleType.__init__(self, __name__)
        self.__dict__.update(sys.modules[__name__].__dict__)

    def __call__(self, x=None, *args, **kwargs):
        show(x, *args, **kwargs)


# Replace this module's entry in sys.modules with the callable wrapper.
sys.modules[__name__] = CallableModule()