Initial files.

This commit is contained in:
David Bau
2022-02-15 08:55:19 -05:00
parent 4f1aa481dc
commit 2547c71964
8 changed files with 3627 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
Some utilities useful for prototyping with pytorch.
+678
View File
@@ -0,0 +1,678 @@
"""
labwidget by David Bau.
Base class for a lightweight javascript notebook widget framework
that is portable across Google colab and Jupyter notebooks.
No use of requirejs: the design uses all inline javascript.
Defines Model, Widget, Trigger, and Property, which set up data binding
using the communication channels available in either google colab
environment or jupyter notebook.
This module also defines Label, Textbox, Range, Choice, and Div
widgets; the code for these are good examples of usage of Widget,
Trigger, and Property objects.
User interaction should update the javascript model using
model.set('propname', value); this will propagate to the python
model and notify any registered python listeners.
TODO: Support jupyterlab also.
"""
import json, html, re
from inspect import signature
class Model(object):
'''
Abstract base class that supports data binding. Within __init__,
a model subclass defines databound events and properties using:
self.evtname = Trigger()
self.propname = Property(initval)
Any Trigger or Property member can be watched by registering a
listener with `model.on('propname', callback)`.
An event can be triggered by `model.evtname.trigger(value)`.
A property can be read with `model.propname`, and can be set by
`model.propname = value`; this also triggers notifications.
In both these cases, any registered listeners will be called
with the given value.
'''
def on(self, name, cb):
'''
Registers a listener for named events and properties.
A space-separated list of names can be provided as `name`.
'''
for n in name.split():
self.prop(n).on(cb)
return self
def off(self, name, cb):
'''
Unregisters a listener for named events and properties.
A space-separated list of names can be provided as `name`.
'''
for n in name.split():
self.prop(n).off(cb)
return self
def prop(self, name):
'''
Returns the underlying Trigger or Property object for a
property, rather than its held value.
'''
curvalue = super().__getattribute__(name)
if not isinstance(curvalue, Trigger):
raise AttributeError('%s not a property or trigger but %s'
% (name, str(type(curvalue))))
return curvalue
def _initprop_(self, name, value):
'''
To be overridden in base classes. Handles initialization of
a new Trigger or Property member.
'''
return
def __setattr__(self, name, value):
'''
When a member is an Trigger or Property, then assignment notation
is delegated to the Trigger or Property so that notifications
and reparenting can be handled. That is, `model.name = value`
turns into `prop(name).set(value)`.
'''
if hasattr(self, name):
curvalue = super().__getattribute__(name)
if isinstance(curvalue, Trigger):
# Delegte "set" to the underlying Property.
curvalue.set(value)
else:
super().__setattr__(name, value)
else:
super().__setattr__(name, value)
if isinstance(value, Trigger):
self._initprop_(name, value)
def __getattribute__(self, name):
'''
When a member is a Property, then property getter
notation is delegated to the peoperty object.
'''
curvalue = super().__getattribute__(name)
if isinstance(curvalue, Property):
return curvalue.value
return curvalue
class Widget(Model):
'''
Base class for an HTML widget that uses a Javascript model object
to syncrhonize HTML view state with the backend Python model state.
Each widget subclass overrides widget_js to provide Javascript code
that defines the widget's behavior. This javascript will be wrapped
in an immediately-invoked function and included in the widget's HTML
representation (_repr_html_) when the widget is viewed.
A widget's javascript is provided with two local variables:
element - the widget's root HTML element. By default this is
a <div> but can be overridden in widget_html.
model - the object representing the data model for the widget.
within javascript.
The model object provides the following javascript API:
model.get('propname') obtains a current property value.
model.set('propname', 'value') requests a change in value.
model.on('propname', callback) listens for property changes.
model.trigger('evtname', value) triggers an event.
Note that model.set just requests a change but does not change the
value immediately: model.get will not reflect the change until the
python backend has handled it and notified the javascript of the new
value, which will trigger any callbacks previously registered using
.on('propname', callback). Thus Widget impelements a V-shaped
notification protocol:
User entry -> | -> User-visible feedback
js model.set -> | -> js.model.on callback
python prop.trigger -> | -> python prop.notify
python prop.handle
'''
def __init__(self):
# In the jupyter case, there can be some delay between js injection
# and comm creation, so we need to queue some initial messages.
if WIDGET_ENV == 'jupyter':
self._comms = []
self._queue = []
# Each call to _repr_html_ creates a unique view instance.
self._viewcount = 0
# Python notification is handled by Property objects.
def handle_remote_set(name, value):
self.prop(name).trigger(value)
self._recv_from_js_(handle_remote_set)
def widget_js(self):
'''
Override to define the javascript logic for the widget. Should
render the initial view based on the current model state (if not
already rendered using widget_html) and set up listeners to keep
the model and the view synchornized.
'''
return ''
def widget_html(self):
'''
Override to define the initial HTML view of the widget. Should
define an element with id given by view_id().
'''
return f'<div id="{self.view_id()}"></div>'
def view_id(self):
'''
Returns an HTML element id for the view currently being rendered.
Note that each time _repr_html_ is called, this id will change.
'''
return f"_{id(self)}_{self._viewcount}"
def _repr_html_(self):
'''
Returns the HTML code for the widget.
'''
self._viewcount += 1
json_data = json.dumps({
k: v.value for k, v in vars(self).items()
if isinstance(v, Property)})
json_data = re.sub('</', '<\\/', json_data)
return f"""
{self.widget_html()}
<script>
(function() {{
{WIDGET_MODEL_JS}
var model = new Model("{id(self)}", {json_data});
var element = document.getElementById("{self.view_id()}");
{self.widget_js()}
}})();
</script>
"""
def _initprop_(self, name, value):
if not hasattr(self, '_viewcount'):
raise ValueError('base Model __init__ must be called')
def notify_js(value):
self._send_to_js_(id(self), name, value)
if isinstance(value, Trigger):
value.on(notify_js)
def _send_to_js_(self, *args):
if self._viewcount > 0:
if WIDGET_ENV == 'colab':
colab_output.eval_js(f"""
(window.send_{id(self)} = window.send_{id(self)} ||
new BroadcastChannel("channel_{id(self)}")
).postMessage({json.dumps(args)});
""", ignore_result=True)
elif WIDGET_ENV == 'jupyter':
if not self._comms:
self._queue.append(args)
return
for comm in self._comms:
comm.send(args)
def _recv_from_js_(self, fn):
if WIDGET_ENV == 'colab':
colab_output.register_callback(f"invoke_{id(self)}", fn)
elif WIDGET_ENV == 'jupyter':
def handle_comm(msg):
fn(*(msg['content']['data']))
# TODO: handle closing also.
def handle_close(close_msg):
comm_id = close_msg['content']['comm_id']
self._comms = [c for c in self._comms if c.comm_id != comm_id]
def open_comm(comm, open_msg):
self._comms.append(comm)
comm.on_msg(handle_comm)
comm.on_close(handle_close)
comm.send('ok')
if self._queue:
for args in self._queue:
comm.send(args)
self._queue.clear()
if open_msg['content']['data']:
handle_comm(open_msg)
cname = "comm_" + str(id(self))
COMM_MANAGER.register_target(cname, open_comm)
class Trigger(object):
"""
Trigger is the base class for Property and other data-bound
field objects. Trigger holds a list of listeners that need to
be notified about the event.
Multple Trigger objects can be tied (typically a parent Model can
have Triggers that are triggered by children models). To support
this, each Trigger can have a parent.
Trigger objects provide a notification protocol where view
interactions trigger events at a leaf that are sent up to the
root Trigger to be handled. By default, the root handler accepts
events by notifying all listeners and children in the tree.
"""
def __init__(self):
self._listeners = []
self.parent = None
def handle(self, value):
'''
Method to override; called at the root when an event has been
triggered, and on a child when the parent has notified. By
default notifies all listeners.
'''
self.notify(value)
def trigger(self, value=None):
'''
Triggers an event to be handled by the root. By default, the root
handler will accept the event so all the listeners will be notified.
'''
if self.parent is not None:
self.parent.trigger(value)
else:
self.handle(value)
def set(self, value):
'''
Sets the parent Trigger. Child Triggers trigger events by
triggering parents, and in turn they handle notifications
that come from parents.
'''
if self.parent is not None:
self.parent.off(self.handle)
self.parent = None
if isinstance(value, Trigger):
ancestor = value.parent
while ancestor is not None:
if ancestor == self:
raise ValueError('bound properties should not form a loop')
ancestor = ancestor.parent
self.parent = value
self.parent.on(self.handle)
elif not isinstance(self, Property):
raise ValueError('only properties can be set to a value')
def notify(self, value=None):
'''
Notifies listeners and children. If a listener accepts an argument,
the value will be passed as a single argument.
'''
for cb in self._listeners:
if len(signature(cb).parameters) == 0:
cb() # no-parameter callback.
else:
cb(value)
# TODO: consider adding a two-parameter callback form
# where a detailed event object can be added.
def on(self, cb):
'''
Registers a listener. Calling multiple times registers
multiple listeners.
'''
self._listeners.append(cb)
def off(self, cb):
'''
Unregisters a listener.
'''
self._listeners = [c for c in self._listeners if c != cb]
class Property(Trigger):
"""
A Property is just an Trigger that remembers its last value.
"""
def __init__(self, value=None):
'''
Can be initialized with a starting value.
'''
super().__init__()
self.set(value)
def handle(self, value):
'''
The default handling for a Property is to store the value,
then notify listeners. This method can be overridden,
for example to validate values.
'''
self.value = value
self.notify(value)
def set(self, value):
'''
When a Property value is set to an ordinary value, it
triggers an event which causes a notification to be
sent to update all linked Properties. A Property set
to another Property becomes a child of the value.
'''
# Handle setting a parent Property
if isinstance(value, Property):
super().set(value)
self.handle(value.value)
elif isinstance(value, Trigger):
raise ValueError('Cannot set a Property to an Trigger')
else:
self.trigger(value)
##########################################################################
## Specific widgets
##########################################################################
class Button(Widget):
def __init__(self, label='button'):
super().__init__()
self.click = Trigger()
self.label = Property(label)
def widget_js(self):
return '''
element.addEventListener('click', (e) => {
model.trigger('click');
})
model.on('label', (v) => {
element.value = v;
})
'''
def widget_html(self):
return f'''
<input id="{self.view_id()}" type="button" style="display:block"
value="{html.escape(str(self.label))}">
'''
class Label(Widget):
def __init__(self, value=''):
super().__init__()
# databinding is defined using Property objects.
self.value = Property(value)
def widget_js(self):
# Both "model" and "element" objects are defined within the scope
# where the js is run. "element" looks for the element with id
# self.view_id(); if widget_html is overridden, this id should be used.
return '''
model.on('value', (value) => {
element.innerText = model.get('value');
});
'''
def widget_html(self):
return f'''
<label id="{self.view_id()}">{html.escape(str(self.value))}</label>
'''
class Textbox(Widget):
def __init__(self, value='', size=20):
super().__init__()
# databinding is defined using Property objects.
self.value = Property(value)
self.size = Property(size)
def widget_js(self):
# Both "model" and "element" objects are defined within the scope
# where the js is run. "element" looks for the element with id
# self.view_id(); if widget_html is overridden, this id should be used.
return '''
element.value = model.get('value');
element.size = model.get('size');
element.addEventListener('keydown', (e) => {
if (e.code == 'Enter') {
model.set('value', element.value);
}
});
model.on('value', (value) => {
element.value = model.get('value');
});
model.on('size', (value) => {
element.size = model.get('size');
});
'''
def widget_html(self):
return f'''
<input id="{self.view_id()}" style="display:block"
value="{html.escape(str(self.value))}" size="{self.size}">
'''
class Range(Widget):
def __init__(self, value=50, min=0, max=100):
super().__init__()
# databinding is defined using Property objects.
self.value = Property(value)
self.min = Property(min)
self.max = Property(max)
def widget_js(self):
# Note that the 'input' event would enable during-drag feedback,
# but this is pretty slow on google colab.
return '''
element.addEventListener('change', (e) => {
model.set('value', element.value);
});
model.on('value', (value) => {
if (!element.matches(':active')) {
element.value = value;
}
})
'''
def widget_html(self):
return f'''
<input id="{self.view_id()}" type="range"
value="{self.value}" min="{self.min}" max="{self.max}">
'''
class Choice(Widget):
def __init__(self, choices=None, selection=None, horizontal=False):
super().__init__()
if choices is None:
choices = []
self.choices = Property(choices)
self.horizontal = Property(horizontal)
self.selection = Property(selection)
def widget_js(self):
# Note that the 'input' event would enable during-drag feedback,
# but this is pretty slow on google colab.
return '''
function esc(unsafe) {
return unsafe.replace(/&/g, "&amp;").replace(/</g, "&lt;")
.replace(/>/g, "&gt;").replace(/"/g, "&quot;");
}
function render() {
console.log('rendering');
var lines = model.get('choices').map((c) => {
return '<label><input type="radio" name="choice" value="' +
esc(c) + '">' + esc(c) + '</label>'
});
element.innerHTML = lines.join(model.get('horizontal')?' ':'<br>');
}
model.on('choices horizontal', render);
model.on('selection', (selection) => {
[...element.querySelectorAll('input')].forEach((e) => {
e.checked = (e.value == selection);
})
});
element.addEventListener('change', (e) => {
model.set('selection', element.choice.value);
});
'''
def widget_html(self):
radios = [
f"""<label><input name="choice" type="radio" {
'checked' if value == self.selection else ''
} value="{html.escape(value)}">{html.escape(value)}</label>"""
for value in self.choices ]
sep = " " if self.horizontal else "<br>"
return f'<form id="{self.view_id()}">{sep.join(radios)}</form>'
class Div(Widget):
def __init__(self, innerHTML='', style=None, data=None):
super().__init__()
style = {} if style is None else style
data = {} if data is None else data
# TODO: convert non-string innerHTML objects to html; unify
# with the show() library.
self.innerHTML = Property(innerHTML)
self.style = Property(style)
self.click = Trigger()
def print(self, text, replace=False):
newHTML = '<pre>%s</pre>' % html.escape(str(text));
if replace:
self.innerHTML = newHTML
else:
self.innerHTML += newHTML
def widget_js(self):
# Note that if we want innerHTML to support script execution,
# we need to do it explicitly, like this.
return '''
function updater(attr) {
return (val) => { for (k in val) { element[attr][k] = val[k]; } }
}
model.on('style', updater('style'));
updater('style')();
model.on('data', updater('dataset'));
updater('dataset')();
model.on('innerHTML', (innerHTML) => {
element.innerHTML = innerHTML;
Array.from(element.querySelectorAll("script")).forEach(old=>{
const newScript = document.createElement("script");
Array.from(old.attributes).forEach(attr =>
newScript.setAttribute(attr.name, attr.value));
newScript.appendChild(document.createTextNode(old.innerHTML));
old.parentNode.replaceChild(newScript, old);
});
});
'''
def widget_html(self):
return f'''
<div id="{self.view_id()}">{self.innerHTML}</div>
'''
class ClickDiv(Div):
'''
A Div that triggers click events when anything inside them is clicked.
If a clicked element contains a data-click value, then that value is
sent as the click event value.
'''
def __init__(self, innerHTML='', style=None, data=None):
super().__init__(innertHTML, style, data)
self.click = Trigger()
def widget_js(self):
return super().widget_js() + '''
element.addEventListener('click', (ev) => {
var target = ev.target;
while (target && target != element && !target.dataset.click) {
target = target.parentElement;
}
var value = target.dataset.click;
model.trigger('click', value);
});
'''
##########################################################################
## Implementation Details
##########################################################################
WIDGET_ENV = None
if WIDGET_ENV is None:
try:
from google.colab import output as colab_output
WIDGET_ENV = 'colab'
except:
pass
if WIDGET_ENV is None:
try:
from ipykernel.comm import Comm as jupyter_comm
COMM_MANAGER = get_ipython().kernel.comm_manager
WIDGET_ENV = 'jupyter'
except:
pass
SEND_RECV_JS = """
function recvFromPython(obj_id, fn) {
var recvname = "recv_" + obj_id;
if (window[recvname] === undefined) {
window[recvname] = new BroadcastChannel("channel_" + obj_id);
}
window[recvname].addEventListener("message", (ev) => {
if (ev.data == 'ok') {
window[recvname].ok = true;
return;
}
fn.apply(null, ev.data.slice(1));
});
}
function sendToPython(obj_id, ...args) {
google.colab.kernel.invokeFunction('invoke_' + obj_id, args, {})
}
""" if WIDGET_ENV == 'colab' else """
function getChan(obj_id) {
var cname = "comm_" + obj_id;
if (!window[cname]) { window[cname] = []; }
var chan = window[cname];
if (!chan.comm && Jupyter.notebook.kernel) {
chan.comm = Jupyter.notebook.kernel.comm_manager.new_comm(cname, {});
chan.comm.on_msg((ev) => {
if (chan.retry) { clearInterval(chan.retry); chan.retry = null; }
if (ev.content.data == 'ok') { return; }
var args = ev.content.data.slice(1);
for (fn of chan) { fn.apply(null, args); }
});
chan.retry = setInterval(() => { chan.comm.open(); }, 2000);
}
return chan;
}
function recvFromPython(obj_id, fn) {
getChan(obj_id).push(fn);
}
function sendToPython(obj_id, ...args) {
var comm = getChan(obj_id).comm;
if (comm) { comm.send(args); }
}
"""
WIDGET_MODEL_JS = SEND_RECV_JS + """
class Model {
constructor(obj_id, init) {
this._id = obj_id;
this._listeners = {};
this._data = Object.assign({}, init)
recvFromPython(this._id, (name, value) => {
this._data[name] = value;
if (this._listeners.hasOwnProperty(name)) {
this._listeners[name].forEach((fn) => { fn(value); });
}
})
}
trigger(name, value) {
sendToPython(this._id, name, value);
}
get(name) {
return this._data[name];
}
set(name, value) {
this.trigger(name, value);
}
on(name, fn) {
name.split(/\s+/).forEach((n) => {
if (!this._listeners.hasOwnProperty(n)) {
this._listeners[n] = [];
}
this._listeners[n].push(fn);
});
}
off(name, fn) {
name.split(/\s+/).forEach((n) => {
if (!fn) {
delete this._listeners[n];
} else if (this._listeners.hasOwnProperty(n)) {
this._listeners[n] = this._listeners[n].filter(
(e) => { return e !== fn; });
}
});
}
}
"""
+451
View File
@@ -0,0 +1,451 @@
"""
Utilities for instrumenting a torch model.
Trace will hook one layer at a time.
TraceDict will hook multiple layers at once.
subsequence slices intervals from Sequential modules.
get_module, replace_module, get_parameter resolve dotted names.
set_requires_grad recursively sets requires_grad in module parameters.
"""
import contextlib
import copy
import inspect
from collections import OrderedDict
import torch
class Trace(contextlib.AbstractContextManager):
"""
To retain the output of the named layer during the computation of
the given network:
with Trace(net, 'layer.name') as ret:
_ = net(inp)
representation = ret.output
A layer module can be passed directly without a layer name, and
its output will be retained. By default, a direct reference to
the output object is returned, but options can control this:
clone=True - retains a copy of the output, which can be
useful if you want to see the output before it might
be modified by the network in-place later.
detach=True - retains a detached reference or copy. (By
default the value would be left attached to the graph.)
retain_grad=True - request gradient to be retained on the
output. After backward(), ret.output.grad is populated.
retain_input=True - also retains the input.
retain_output=False - can disable retaining the output.
edit_output=fn - calls the function to modify the output
of the layer before passing it the rest of the model.
fn can optionally accept (output, layer) arguments
for the original output and the layer name.
stop=True - throws a StopForward exception after the layer
is run, which allows running just a portion of a model.
"""
def __init__(
self,
module,
layer=None,
retain_output=True,
retain_input=False,
clone=False,
detach=False,
retain_grad=False,
edit_output=None,
stop=False,
):
"""
Method to replace a forward method with a closure that
intercepts the call, and tracks the hook so that it can be reverted.
"""
retainer = self
self.layer = layer
if layer is not None:
module = get_module(module, layer)
def retain_hook(m, inputs, output):
if retain_input:
retainer.input = recursive_copy(
inputs[0] if len(inputs) == 1 else inputs,
clone=clone,
detach=detach,
retain_grad=False,
) # retain_grad applies to output only.
if edit_output:
output = invoke_with_optional_args(
edit_output, output=output, layer=self.layer
)
if retain_output:
retainer.output = recursive_copy(
output, clone=clone, detach=detach, retain_grad=retain_grad
)
# When retain_grad is set, also insert a trivial
# copy operation. That allows in-place operations
# to follow without error.
if retain_grad:
output = recursive_copy(retainer.output, clone=True, detach=False)
if stop:
raise StopForward()
return output
self.registered_hook = module.register_forward_hook(retain_hook)
self.stop = stop
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
if self.stop and issubclass(type, StopForward):
return True
def close(self):
self.registered_hook.remove()
class TraceDict(OrderedDict, contextlib.AbstractContextManager):
"""
To retain the output of multiple named layers during the computation
of the given network:
with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
_ = net(inp)
representation = ret['layer1.name1'].output
If edit_output is provided, it should be a function that takes
two arguments: output, and the layer name; and then it returns the
modified output.
Other arguments are the same as Trace. If stop is True, then the
execution of the network will be stopped after the last layer
listed (even if it would not have been the last to be executed).
"""
def __init__(
self,
module,
layers=None,
retain_output=True,
retain_input=False,
clone=False,
detach=False,
retain_grad=False,
edit_output=None,
stop=False,
):
self.stop = stop
def flag_last_unseen(it):
try:
it = iter(it)
prev = next(it)
seen = set([prev])
except StopIteration:
return
for item in it:
if item not in seen:
yield False, prev
seen.add(item)
prev = item
yield True, prev
for is_last, layer in flag_last_unseen(layers):
self[layer] = Trace(
module=module,
layer=layer,
retain_output=retain_output,
retain_input=retain_input,
clone=clone,
detach=detach,
retain_grad=retain_grad,
edit_output=edit_output,
stop=stop and is_last,
)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
if self.stop and issubclass(type, StopForward):
return True
def close(self):
for layer, trace in reversed(self.items()):
trace.close()
class StopForward(Exception):
"""
If the only output needed from running a network is the retained
submodule then Trace(submodule, stop=True) will stop execution
immediately after the retained submodule by raising the StopForward()
exception. When Trace is used as context manager, it catches that
exception and can be used as follows:
with Trace(net, layername, stop=True) as tr:
net(inp) # Only runs the network up to layername
print(tr.output)
"""
pass
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
"""
Copies a reference to a tensor, or an object that contains tensors,
optionally detaching and cloning the tensor(s). If retain_grad is
true, the original tensors are marked to have grads retained.
"""
if not clone and not detach and not retain_grad:
return x
if isinstance(x, torch.Tensor):
if retain_grad:
if not x.requires_grad:
x.requires_grad = True
x.retain_grad()
elif detach:
x = x.detach()
if clone:
x = x.clone()
return x
# Only dicts, lists, and tuples (and subclasses) can be copied.
if isinstance(x, dict):
return type(x)({k: recursive_copy(v) for k, v in x.items()})
elif isinstance(x, (list, tuple)):
return type(x)([recursive_copy(v) for v in x])
else:
assert False, f"Unknown type {type(x)} cannot be broken into tensors."
def subsequence(
sequential,
first_layer=None,
last_layer=None,
after_layer=None,
upto_layer=None,
single_layer=None,
share_weights=False,
):
"""
Creates a subsequence of a pytorch Sequential model, copying over
modules together with parameters for the subsequence. Only
modules from first_layer to last_layer (inclusive) are included,
or modules between after_layer and upto_layer (exclusive).
Handles descent into dotted layer names as long as all references
are within nested Sequential models.
If share_weights is True, then references the original modules
and their parameters without copying them. Otherwise, by default,
makes a separate brand-new copy.
"""
assert (single_layer is None) or (
first_layer is last_layer is after_layer is upto_layer is None
)
if single_layer is not None:
first_layer = single_layer
last_layer = single_layer
first, last, after, upto = [
None if d is None else d.split(".")
for d in [first_layer, last_layer, after_layer, upto_layer]
]
return hierarchical_subsequence(
sequential,
first=first,
last=last,
after=after,
upto=upto,
share_weights=share_weights,
)
def hierarchical_subsequence(
sequential, first, last, after, upto, share_weights=False, depth=0
):
"""
Recursive helper for subsequence() to support descent into dotted
layer names. In this helper, first, last, after, and upto are
arrays of names resulting from splitting on dots. Can only
descend into nested Sequentials.
"""
assert (last is None) or (upto is None)
assert (first is None) or (after is None)
if first is last is after is upto is None:
return sequential if share_weights else copy.deepcopy(sequential)
assert isinstance(sequential, torch.nn.Sequential), (
".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
)
including_children = (first is None) and (after is None)
included_children = OrderedDict()
# A = current level short name of A.
# AN = full name for recursive descent if not innermost.
(F, FN), (L, LN), (A, AN), (U, UN) = [
(d[depth], (None if len(d) == depth + 1 else d))
if d is not None
else (None, None)
for d in [first, last, after, upto]
]
for name, layer in sequential._modules.items():
if name == F:
first = None
including_children = True
if name == A and AN is not None: # just like F if not a leaf.
after = None
including_children = True
if name == U and UN is None:
upto = None
including_children = False
if including_children:
# AR = full name for recursive descent if name matches.
FR, LR, AR, UR = [
n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
]
chosen = hierarchical_subsequence(
layer,
first=FR,
last=LR,
after=AR,
upto=UR,
share_weights=share_weights,
depth=depth + 1,
)
if chosen is not None:
included_children[name] = chosen
if name == L:
last = None
including_children = False
if name == U and UN is not None: # just like L if not a leaf.
upto = None
including_children = False
if name == A and AN is None:
after = None
including_children = True
for name in [first, last, after, upto]:
if name is not None:
raise ValueError("Layer %s not found" % ".".join(name))
# Omit empty subsequences except at the outermost level,
# where we should not return None.
if not len(included_children) and depth > 0:
return None
result = torch.nn.Sequential(included_children)
result.training = sequential.training
return result
def set_requires_grad(requires_grad, *models):
"""
Sets requires_grad true or false for all parameters within the
models passed.
"""
for model in models:
if isinstance(model, torch.nn.Module):
for param in model.parameters():
param.requires_grad = requires_grad
elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
model.requires_grad = requires_grad
else:
assert False, "unknown type %r" % type(model)
def get_module(model, name):
"""
Finds the named module within the given model.
"""
for n, m in model.named_modules():
if n == name:
return m
raise LookupError(name)
def get_parameter(model, name):
"""
Finds the named parameter within the given model.
"""
for n, p in model.named_parameters():
if n == name:
return p
raise LookupError(name)
def replace_module(model, name, new_module):
"""
Replaces the named module within the given model.
"""
if "." in name:
parent_name, attr_name = name.rsplit(".", 1)
model = get_module(model, parent_name)
# original_module = getattr(model, attr_name)
setattr(model, attr_name, new_module)
def invoke_with_optional_args(fn, *args, **kwargs):
"""
Invokes a function with only the arguments that it
is written to accept, giving priority to arguments
that match by-name, using the following rules.
(1) arguments with matching names are passed by name.
(2) remaining non-name-matched args are passed by order.
(3) extra caller arguments that the function cannot
accept are not passed.
(4) extra required function arguments that the caller
cannot provide cause a TypeError to be raised.
Ordinary python calling conventions are helpful for
supporting a function that might be revised to accept
extra arguments in a newer version, without requiring the
caller to pass those new arguments. This function helps
support function callers that might be revised to supply
extra arguments, without requiring the callee to accept
those new arguments.
"""
argspec = inspect.getfullargspec(fn)
pass_args = []
used_kw = set()
unmatched_pos = []
used_pos = 0
defaulted_pos = len(argspec.args) - (
0 if not argspec.defaults else len(argspec.defaults)
)
# Pass positional args that match name first, then by position.
for i, n in enumerate(argspec.args):
if n in kwargs:
pass_args.append(kwargs[n])
used_kw.add(n)
elif used_pos < len(args):
pass_args.append(args[used_pos])
used_pos += 1
else:
unmatched_pos.append(len(pass_args))
pass_args.append(
None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
)
# Fill unmatched positional args with unmatched keyword args in order.
if len(unmatched_pos):
for k, v in kwargs.items():
if k in used_kw or k in argspec.kwonlyargs:
continue
pass_args[unmatched_pos[0]] = v
used_kw.add(k)
unmatched_pos = unmatched_pos[1:]
if len(unmatched_pos) == 0:
break
else:
if unmatched_pos[0] < defaulted_pos:
unpassed = ", ".join(
argspec.args[u] for u in unmatched_pos if u < defaulted_pos
)
raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
# Pass remaining kw args if they can be accepted.
pass_kw = {
k: v
for k, v in kwargs.items()
if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
}
# Pass remaining positional args if they can be accepted.
if argspec.varargs is not None:
pass_args += list(args[used_pos:])
return fn(*pass_args, **pass_kw)
+147
View File
@@ -0,0 +1,147 @@
from .labwidget import Widget, Property
import html
class PaintWidget(Widget):
def __init__(self,
width=256, height=256,
image='', mask='', brushsize=10.0, oneshot=False, disabled=False):
super().__init__()
self.mask = Property(mask)
self.image = Property(image)
self.brushsize = Property(brushsize)
self.erase = Property(False)
self.oneshot = Property(oneshot)
self.disabled = Property(disabled)
self.width = Property(width)
self.height = Property(height)
def widget_js(self):
return f'''
{PAINT_WIDGET_JS}
var pw = new PaintWidget(element, model);
'''
def widget_html(self):
v = self.view_id()
return f'''
<style>
#{v} {{ position: relative; display: inline-block; }}
#{v} .paintmask {{
position: absolute; top:0; left: 0; z-index: 1;
opacity: 0; transition: opacity .1s ease-in-out; }}
#{v} .paintmask:hover {{ opacity: 0.7; }}
</style>
<div id="{v}"></div>
'''
PAINT_WIDGET_JS = """
class PaintWidget {
constructor(el, model) {
this.el = el;
this.model = model;
this.size_changed();
this.model.on('mask', this.mask_changed.bind(this));
this.model.on('image', this.image_changed.bind(this));
this.model.on('width', this.size_changed.bind(this));
this.model.on('height', this.size_changed.bind(this));
}
mouse_stroke(first_event) {
var self = this;
if (self.model.get('disabled')) { return; }
if (self.model.get('oneshot')) {
var canvas = self.mask_canvas;
var ctx = canvas.getContext('2d');
ctx.clearRect(0, 0, canvas.width, canvas.height);
}
function track_mouse(evt) {
if (evt.type == 'keydown' || self.model.get('disabled')) {
if (self.model.get('disabled') || evt.key === "Escape") {
window.removeEventListener('mousemove', track_mouse);
window.removeEventListener('mouseup', track_mouse);
window.removeEventListener('keydown', track_mouse, true);
self.mask_changed();
}
return;
}
if (evt.type == 'mouseup' ||
(typeof evt.buttons != 'undefined' && evt.buttons == 0)) {
window.removeEventListener('mousemove', track_mouse);
window.removeEventListener('mouseup', track_mouse);
window.removeEventListener('keydown', track_mouse, true);
self.model.set('mask', self.mask_canvas.toDataURL());
return;
}
var p = self.cursor_position();
self.fill_circle(p.x, p.y,
self.model.get('brushsize'),
self.model.get('erase'));
}
this.mask_canvas.focus();
window.addEventListener('mousemove', track_mouse);
window.addEventListener('mouseup', track_mouse);
window.addEventListener('keydown', track_mouse, true);
track_mouse(first_event);
}
mask_changed(val) {
this.draw_data_url(this.mask_canvas, this.model.get('mask'));
}
image_changed() {
this.draw_data_url(this.image_canvas, this.model.get('image'));
}
size_changed() {
this.mask_canvas = document.createElement('canvas');
this.image_canvas = document.createElement('canvas');
this.mask_canvas.className = "paintmask";
this.image_canvas.className = "paintimage";
for (var attr of ['width', 'height']) {
this.mask_canvas[attr] = this.model.get(attr);
this.image_canvas[attr] = this.model.get(attr);
}
this.el.innerHTML = '';
this.el.appendChild(this.image_canvas);
this.el.appendChild(this.mask_canvas);
this.mask_canvas.addEventListener('mousedown',
this.mouse_stroke.bind(this));
this.mask_changed();
this.image_changed();
}
cursor_position(evt) {
const rect = this.mask_canvas.getBoundingClientRect();
const x = event.clientX - rect.left;
const y = event.clientY - rect.top;
return {x: x, y: y};
}
fill_circle(x, y, r, erase, blur) {
var ctx = this.mask_canvas.getContext('2d');
ctx.save();
if (blur) {
ctx.filter = 'blur(' + blur + 'px)';
}
ctx.globalCompositeOperation = (
erase ? "destination-out" : 'source-over');
ctx.fillStyle = '#fff';
ctx.beginPath();
ctx.arc(x, y, r, 0, 2 * Math.PI);
ctx.fill();
ctx.restore()
}
draw_data_url(canvas, durl) {
var ctx = canvas.getContext('2d');
var img = new Image;
canvas.pendingImg = img;
function imgdone() {
if (canvas.pendingImg == img) {
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.drawImage(img, 0, 0);
canvas.pendingImg = null;
}
}
img.addEventListener('load', imgdone);
img.addEventListener('error', imgdone);
img.src = durl;
}
}
"""
+212
View File
@@ -0,0 +1,212 @@
'''
Utilities for showing progress bars, controlling default verbosity, etc.
'''
# If the tqdm package is not available, then do not show progress bars;
# just connect print_progress to print.
import sys
import types
import builtins
try:
from tqdm import tqdm
try:
from tqdm.notebook import tqdm as tqdm_nb
except:
from tqdm import tqdm_notebook as tqdm_nb
except:
tqdm = None
default_verbosity = True
next_description = None
python_print = builtins.print
def post(**kwargs):
'''
When within a progress loop, pbar.post(k=str) will display
the given k=str status on the right-hand-side of the progress
status bar. If not within a visible progress bar, does nothing.
'''
innermost = innermost_tqdm()
if innermost is not None:
innermost.set_postfix(**kwargs)
def desc(desc):
'''
When within a progress loop, pbar.desc(str) changes the
left-hand-side description of the loop toe the given description.
'''
innermost = innermost_tqdm()
if innermost is not None:
innermost.set_description(str(desc))
def descnext(desc):
'''
Called before starting a progress loop, pbar.descnext(str)
sets the description text that will be used in the following loop.
'''
global next_description
if not default_verbosity or tqdm is None:
return
next_description = desc
def print(*args):
'''
When within a progress loop, will print above the progress loop.
'''
global next_description
next_description = None
if default_verbosity:
msg = ' '.join(str(s) for s in args)
if tqdm is None:
python_print(msg)
else:
tqdm.write(msg)
def tqdm_terminal(it, *args, **kwargs):
'''
Some settings for tqdm that make it run better in resizable terminals.
'''
return tqdm(it, *args, dynamic_ncols=True, ascii=True,
leave=(innermost_tqdm() is not None), **kwargs)
def in_notebook():
'''
True if running inside a Jupyter notebook.
'''
# From https://stackoverflow.com/a/39662359/265298
try:
shell = get_ipython().__class__.__name__
if shell == 'ZMQInteractiveShell':
return True # Jupyter notebook or qtconsole
elif shell == 'TerminalInteractiveShell':
return False # Terminal running IPython
else:
return False # Other type (?)
except NameError:
return False # Probably standard Python interpreter
def innermost_tqdm():
'''
Returns the innermost active tqdm progress loop on the stack.
'''
if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0:
return max(tqdm._instances, key=lambda x: x.pos)
else:
return None
def reporthook(*args, **kwargs):
'''
For use with urllib.request.urlretrieve.
with pbar.reporthook() as hook:
urllib.request.urlretrieve(url, filename, reporthook=hook)
'''
kwargs2 = dict(unit_scale=True, miniters=1)
kwargs2.update(kwargs)
bar = __call__(None, *args, **kwargs2)
class ReportHook(object):
def __init__(self, t):
self.t = t
def __call__(self, b=1, bsize=1, tsize=None):
if hasattr(self.t, 'total'):
if tsize is not None:
self.t.total = tsize
if hasattr(self.t, 'update'):
self.t.update(b * bsize - self.t.n)
def __enter__(self):
return self
def __exit__(self, *exc):
if hasattr(self.t, '__exit__'):
self.t.__exit__(*exc)
return ReportHook(bar)
def __call__(x, *args, **kwargs):
'''
Invokes a progress function that can wrap iterators to print
progress messages, if verbose is True.
If verbose is False or tqdm is unavailable, then a quiet
non-printing identity function is used.
verbose can also be set to a spefific progress function rather
than True, and that function will be used.
'''
global default_verbosity, next_description
if not default_verbosity or tqdm is None:
return x
if default_verbosity == True:
fn = tqdm_nb if in_notebook() else tqdm_terminal
else:
fn = default_verbosity
if next_description is not None:
kwargs = dict(kwargs)
kwargs['desc'] = next_description
next_description = None
return fn(x, *args, **kwargs)
class VerboseContextManager():
def __init__(self, v, entered=False):
self.v, self.entered, self.saved = v, False, []
if entered:
self.__enter__()
self.entered = True
def __enter__(self):
global default_verbosity
if self.entered:
self.entered = False
else:
self.saved.append(default_verbosity)
default_verbosity = self.v
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
global default_verbosity
default_verbosity = self.saved.pop()
def __call__(self, v=True):
'''
Calling the context manager makes a new context that is
pre-entered, so it works as both a plain function and as a
factory for a context manager.
'''
new_v = v if self.v else not v
cm = VerboseContextManager(new_v, entered=True)
default_verbosity = new_v
return cm
# Use as either "with pbar.verbose:" or "pbar.verbose(False)", or also
# "with pbar.verbose(False):"
verbose = VerboseContextManager(True)
# Use as either "with @pbar.quiet" or "pbar.quiet(True)". or also
# "with pbar.quiet(True):"
quiet = VerboseContextManager(False)
class CallableModule(types.ModuleType):
def __init__(self):
# or super().__init__(__name__) for Python 3
types.ModuleType.__init__(self, __name__)
self.__dict__.update(sys.modules[__name__].__dict__)
def __call__(self, x, *args, **kwargs):
return __call__(x, *args, **kwargs)
sys.modules[__name__] = CallableModule()
+129
View File
@@ -0,0 +1,129 @@
'''
Utility for simple distribution of work on multiple processes, by
making sure only one process is working on a job at once.
'''
import os
import errno
import socket
import atexit
import time
import sys
def reserve_dir(*args):
'''
Convenience function to get exclusive access to an unfinished
experiment directory. Exits the program if the directory is
already done or busy (using exit_of_job_done). Otherwise,
returns a function creates filenames within that directory.
'''
directory = os.path.join(*args)
exit_if_job_done(directory)
def dirfn(*fn):
return os.path.join(directory, *fn)
dirfn.dir = directory
def done():
mark_job_done(directory)
dirfn.done = done
print('Working in %s' % directory)
return dirfn
# Old function name.
exclusive_dirfn = reserve_dir
def exit_if_job_done(directory, redo=False, force=False, verbose=True):
if pidfile_taken(os.path.join(directory, 'lockfile.pid'),
force=force, verbose=verbose):
sys.exit(0)
donefile = os.path.join(directory, 'done.txt')
if os.path.isfile(donefile):
with open(donefile) as f:
msg = f.read()
if redo or force:
if verbose:
print('Removing %s %s' % (donefile, msg))
os.remove(donefile)
else:
if verbose:
print('%s %s' % (donefile, msg))
sys.exit(0)
def mark_job_done(directory):
with open(os.path.join(directory, 'done.txt'), 'w') as f:
f.write('done by %d@%s %s at %s' %
(os.getpid(), socket.gethostname(),
os.getenv('STY', ''),
time.strftime('%c')))
def pidfile_taken(path, verbose=False, force=False):
'''
Usage. To grab an exclusive lock for the remaining duration of the
current process (and exit if another process already has the lock),
do this:
if pidfile_taken('job_423/lockfile.pid', verbose=True):
sys.exit(0)
To do a batch of jobs, just run a script that does them all on
each available machine, sharing a network filesystem. When each
job grabs a lock, then this will automatically distribute the
jobs so that each one is done just once on one machine.
'''
# Try to create the file exclusively and write my pid into it.
try:
os.makedirs(os.path.dirname(path), exist_ok=True)
fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
except OSError as e:
if e.errno == errno.EEXIST:
# If we cannot because there was a race, yield the conflicter.
conflicter = 'race'
try:
with open(path, 'r') as lockfile:
conflicter = lockfile.read().strip() or 'empty'
except:
pass
# Force is for manual one-time use, for deleting stale lockfiles.
if force:
if verbose:
print('Removing %s from %s' % (path, conflicter))
os.remove(path)
return pidfile_taken(path, verbose=verbose, force=False)
if verbose:
print('%s held by %s' % (path, conflicter))
return conflicter
else:
# Other problems get an exception.
raise
# Register to delete this file on exit.
lockfile = os.fdopen(fd, 'r+')
atexit.register(delete_pidfile, lockfile, path)
# Write my pid into the open file.
lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
os.getenv('STY', '')))
lockfile.flush()
os.fsync(lockfile)
# Return 'None' to say there was not a conflict.
return None
def delete_pidfile(lockfile, path):
'''
Runs at exit after pidfile_taken succeeds.
'''
if lockfile is not None:
try:
lockfile.close()
except:
pass
try:
os.unlink(path)
except:
pass
File diff suppressed because it is too large Load Diff
+138
View File
@@ -0,0 +1,138 @@
# show.py
#
# An abbreviated way to output simple HTML layout of text and images
# into a python notebook.
#
# - show a PIL image to show an inline HTML <img>.
# - show an array of items to vertically stack them, centered in a block.
# - show an array of arrays to horizontally lay them out as inline blocks.
# - show an array of tuples to create a table.
import PIL.Image, base64, io, IPython, types, sys
import html as html_module
from IPython.display import display
g_buffer = None
def blocks(obj, space=''):
return IPython.display.HTML(space.join(blocks_tags(obj)))
def rows(obj, space=''):
return IPython.display.HTML(space.join(rows_tags(obj)))
def rows_tags(obj):
if isinstance(obj, dict):
obj = obj.items()
results = []
results.append('<table style="display:inline-table">')
for row in obj:
results.append('<tr style="padding:0">')
for item in row:
results.append('<td style="text-align:left; vertical-align:top;' +
'padding:1px">')
results.extend(blocks_tags(item))
results.append('</td>')
results.append('</tr>')
results.append('</table>')
return results
def blocks_tags(obj):
results = []
if hasattr(obj, '_repr_html_'):
results.append(obj._repr_html_())
elif isinstance(obj, PIL.Image.Image):
results.append(pil_to_html(obj))
elif isinstance(obj, (str, int, float)):
results.append('<div>')
results.append(html_module.escape(str(obj)))
results.append('</div>')
elif isinstance(obj, dict):
results.extend(blocks_tags([(k, v) for k, v in obj.items()]))
elif hasattr(obj, '__iter__'):
blockstart, blockend, tstart, tend, rstart, rend, cstart, cend = [
'<div style="display:inline-block;text-align:center;line-height:1;' +
'vertical-align:top;padding:1px">',
'</div>',
'<table style="display:inline-table">',
'</table>',
'<tr style="padding:0">',
'</tr>',
'<td style="text-align:left; vertical-align:top; padding:1px">',
'</td>',
]
needs_end = False
table_mode = False
for i, line in enumerate(obj):
if i == 0:
needs_end = True
if isinstance(line, tuple):
table_mode = True
results.append(tstart)
else:
results.append(blockstart)
if table_mode:
results.append(rstart)
if not isinstance(line, str) and hasattr(line, '__iter__'):
for cell in line:
results.append(cstart)
results.extend(blocks_tags(cell))
results.append(cend)
else:
results.append(cstart)
results.extend(blocks_tags(line))
results.append(cend)
results.append(rend)
else:
results.extend(blocks_tags(line))
if needs_end:
results.append(table_mode and tend or blockend)
return results
def pil_to_b64(img, format='png'):
buffered = io.BytesIO()
img.save(buffered, format=format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def pil_to_url(img, format='png'):
return 'data:image/%s;base64,%s' % (format, pil_to_b64(img, format))
def pil_to_html(img, margin=1):
mattr = ' style="margin:%dpx"' % margin
return '<img src="%s"%s>' % (pil_to_url(img), mattr)
def a(x, cols=None):
global g_buffer
if g_buffer is None:
g_buffer = []
g_buffer.append(x)
if cols is not None and len(g_buffer) >= cols:
flush()
def reset():
global g_buffer
g_buffer = None
def flush(*args, **kwargs):
global g_buffer
if g_buffer is not None:
x = g_buffer
g_buffer = None
display(blocks(x, *args, **kwargs))
def show(x=None, *args, **kwargs):
flush(*args, **kwargs)
if x is not None:
display(blocks(x, *args, **kwargs))
def html(obj, space=''):
return blocks(obj, space)._repr_html_()
class CallableModule(types.ModuleType):
def __init__(self):
# or super().__init__(__name__) for Python 3
types.ModuleType.__init__(self, __name__)
self.__dict__.update(sys.modules[__name__].__dict__)
def __call__(self, x=None, *args, **kwargs):
show(x, *args, **kwargs)
sys.modules[__name__] = CallableModule()