mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 17:14:53 +08:00
Initial files.
This commit is contained in:
@@ -0,0 +1,678 @@
|
||||
"""
|
||||
labwidget by David Bau.
|
||||
|
||||
Base class for a lightweight javascript notebook widget framework
|
||||
that is portable across Google colab and Jupyter notebooks.
|
||||
No use of requirejs: the design uses all inline javascript.
|
||||
|
||||
Defines Model, Widget, Trigger, and Property, which set up data binding
|
||||
using the communication channels available in either google colab
|
||||
environment or jupyter notebook.
|
||||
|
||||
This module also defines Label, Textbox, Range, Choice, and Div
|
||||
widgets; the code for these are good examples of usage of Widget,
|
||||
Trigger, and Property objects.
|
||||
|
||||
User interaction should update the javascript model using
|
||||
model.set('propname', value); this will propagate to the python
|
||||
model and notify any registered python listeners.
|
||||
|
||||
TODO: Support jupyterlab also.
|
||||
"""
|
||||
|
||||
import json, html, re
|
||||
from inspect import signature
|
||||
|
||||
class Model(object):
|
||||
'''
|
||||
Abstract base class that supports data binding. Within __init__,
|
||||
a model subclass defines databound events and properties using:
|
||||
|
||||
self.evtname = Trigger()
|
||||
self.propname = Property(initval)
|
||||
|
||||
Any Trigger or Property member can be watched by registering a
|
||||
listener with `model.on('propname', callback)`.
|
||||
|
||||
An event can be triggered by `model.evtname.trigger(value)`.
|
||||
A property can be read with `model.propname`, and can be set by
|
||||
`model.propname = value`; this also triggers notifications.
|
||||
In both these cases, any registered listeners will be called
|
||||
with the given value.
|
||||
'''
|
||||
def on(self, name, cb):
|
||||
'''
|
||||
Registers a listener for named events and properties.
|
||||
A space-separated list of names can be provided as `name`.
|
||||
'''
|
||||
for n in name.split():
|
||||
self.prop(n).on(cb)
|
||||
return self
|
||||
|
||||
def off(self, name, cb):
|
||||
'''
|
||||
Unregisters a listener for named events and properties.
|
||||
A space-separated list of names can be provided as `name`.
|
||||
'''
|
||||
for n in name.split():
|
||||
self.prop(n).off(cb)
|
||||
return self
|
||||
|
||||
def prop(self, name):
|
||||
'''
|
||||
Returns the underlying Trigger or Property object for a
|
||||
property, rather than its held value.
|
||||
'''
|
||||
curvalue = super().__getattribute__(name)
|
||||
if not isinstance(curvalue, Trigger):
|
||||
raise AttributeError('%s not a property or trigger but %s'
|
||||
% (name, str(type(curvalue))))
|
||||
return curvalue
|
||||
|
||||
def _initprop_(self, name, value):
|
||||
'''
|
||||
To be overridden in base classes. Handles initialization of
|
||||
a new Trigger or Property member.
|
||||
'''
|
||||
return
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
'''
|
||||
When a member is an Trigger or Property, then assignment notation
|
||||
is delegated to the Trigger or Property so that notifications
|
||||
and reparenting can be handled. That is, `model.name = value`
|
||||
turns into `prop(name).set(value)`.
|
||||
'''
|
||||
if hasattr(self, name):
|
||||
curvalue = super().__getattribute__(name)
|
||||
if isinstance(curvalue, Trigger):
|
||||
# Delegte "set" to the underlying Property.
|
||||
curvalue.set(value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
if isinstance(value, Trigger):
|
||||
self._initprop_(name, value)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
'''
|
||||
When a member is a Property, then property getter
|
||||
notation is delegated to the peoperty object.
|
||||
'''
|
||||
curvalue = super().__getattribute__(name)
|
||||
if isinstance(curvalue, Property):
|
||||
return curvalue.value
|
||||
return curvalue
|
||||
|
||||
class Widget(Model):
|
||||
'''
|
||||
Base class for an HTML widget that uses a Javascript model object
|
||||
to syncrhonize HTML view state with the backend Python model state.
|
||||
Each widget subclass overrides widget_js to provide Javascript code
|
||||
that defines the widget's behavior. This javascript will be wrapped
|
||||
in an immediately-invoked function and included in the widget's HTML
|
||||
representation (_repr_html_) when the widget is viewed.
|
||||
|
||||
A widget's javascript is provided with two local variables:
|
||||
|
||||
element - the widget's root HTML element. By default this is
|
||||
a <div> but can be overridden in widget_html.
|
||||
model - the object representing the data model for the widget.
|
||||
within javascript.
|
||||
|
||||
The model object provides the following javascript API:
|
||||
|
||||
model.get('propname') obtains a current property value.
|
||||
model.set('propname', 'value') requests a change in value.
|
||||
model.on('propname', callback) listens for property changes.
|
||||
model.trigger('evtname', value) triggers an event.
|
||||
|
||||
Note that model.set just requests a change but does not change the
|
||||
value immediately: model.get will not reflect the change until the
|
||||
python backend has handled it and notified the javascript of the new
|
||||
value, which will trigger any callbacks previously registered using
|
||||
.on('propname', callback). Thus Widget impelements a V-shaped
|
||||
notification protocol:
|
||||
|
||||
User entry -> | -> User-visible feedback
|
||||
js model.set -> | -> js.model.on callback
|
||||
python prop.trigger -> | -> python prop.notify
|
||||
python prop.handle
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
# In the jupyter case, there can be some delay between js injection
|
||||
# and comm creation, so we need to queue some initial messages.
|
||||
if WIDGET_ENV == 'jupyter':
|
||||
self._comms = []
|
||||
self._queue = []
|
||||
# Each call to _repr_html_ creates a unique view instance.
|
||||
self._viewcount = 0
|
||||
# Python notification is handled by Property objects.
|
||||
def handle_remote_set(name, value):
|
||||
self.prop(name).trigger(value)
|
||||
self._recv_from_js_(handle_remote_set)
|
||||
|
||||
def widget_js(self):
|
||||
'''
|
||||
Override to define the javascript logic for the widget. Should
|
||||
render the initial view based on the current model state (if not
|
||||
already rendered using widget_html) and set up listeners to keep
|
||||
the model and the view synchornized.
|
||||
'''
|
||||
return ''
|
||||
|
||||
def widget_html(self):
|
||||
'''
|
||||
Override to define the initial HTML view of the widget. Should
|
||||
define an element with id given by view_id().
|
||||
'''
|
||||
return f'<div id="{self.view_id()}"></div>'
|
||||
|
||||
def view_id(self):
|
||||
'''
|
||||
Returns an HTML element id for the view currently being rendered.
|
||||
Note that each time _repr_html_ is called, this id will change.
|
||||
'''
|
||||
return f"_{id(self)}_{self._viewcount}"
|
||||
|
||||
def _repr_html_(self):
|
||||
'''
|
||||
Returns the HTML code for the widget.
|
||||
'''
|
||||
self._viewcount += 1
|
||||
json_data = json.dumps({
|
||||
k: v.value for k, v in vars(self).items()
|
||||
if isinstance(v, Property)})
|
||||
json_data = re.sub('</', '<\\/', json_data)
|
||||
return f"""
|
||||
{self.widget_html()}
|
||||
<script>
|
||||
(function() {{
|
||||
{WIDGET_MODEL_JS}
|
||||
var model = new Model("{id(self)}", {json_data});
|
||||
var element = document.getElementById("{self.view_id()}");
|
||||
{self.widget_js()}
|
||||
}})();
|
||||
</script>
|
||||
"""
|
||||
|
||||
def _initprop_(self, name, value):
|
||||
if not hasattr(self, '_viewcount'):
|
||||
raise ValueError('base Model __init__ must be called')
|
||||
def notify_js(value):
|
||||
self._send_to_js_(id(self), name, value)
|
||||
if isinstance(value, Trigger):
|
||||
value.on(notify_js)
|
||||
|
||||
def _send_to_js_(self, *args):
|
||||
if self._viewcount > 0:
|
||||
if WIDGET_ENV == 'colab':
|
||||
colab_output.eval_js(f"""
|
||||
(window.send_{id(self)} = window.send_{id(self)} ||
|
||||
new BroadcastChannel("channel_{id(self)}")
|
||||
).postMessage({json.dumps(args)});
|
||||
""", ignore_result=True)
|
||||
elif WIDGET_ENV == 'jupyter':
|
||||
if not self._comms:
|
||||
self._queue.append(args)
|
||||
return
|
||||
for comm in self._comms:
|
||||
comm.send(args)
|
||||
|
||||
def _recv_from_js_(self, fn):
|
||||
if WIDGET_ENV == 'colab':
|
||||
colab_output.register_callback(f"invoke_{id(self)}", fn)
|
||||
elif WIDGET_ENV == 'jupyter':
|
||||
def handle_comm(msg):
|
||||
fn(*(msg['content']['data']))
|
||||
# TODO: handle closing also.
|
||||
def handle_close(close_msg):
|
||||
comm_id = close_msg['content']['comm_id']
|
||||
self._comms = [c for c in self._comms if c.comm_id != comm_id]
|
||||
def open_comm(comm, open_msg):
|
||||
self._comms.append(comm)
|
||||
comm.on_msg(handle_comm)
|
||||
comm.on_close(handle_close)
|
||||
comm.send('ok')
|
||||
if self._queue:
|
||||
for args in self._queue:
|
||||
comm.send(args)
|
||||
self._queue.clear()
|
||||
if open_msg['content']['data']:
|
||||
handle_comm(open_msg)
|
||||
cname = "comm_" + str(id(self))
|
||||
COMM_MANAGER.register_target(cname, open_comm)
|
||||
|
||||
class Trigger(object):
|
||||
"""
|
||||
Trigger is the base class for Property and other data-bound
|
||||
field objects. Trigger holds a list of listeners that need to
|
||||
be notified about the event.
|
||||
|
||||
Multple Trigger objects can be tied (typically a parent Model can
|
||||
have Triggers that are triggered by children models). To support
|
||||
this, each Trigger can have a parent.
|
||||
|
||||
Trigger objects provide a notification protocol where view
|
||||
interactions trigger events at a leaf that are sent up to the
|
||||
root Trigger to be handled. By default, the root handler accepts
|
||||
events by notifying all listeners and children in the tree.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._listeners = []
|
||||
self.parent = None
|
||||
def handle(self, value):
|
||||
'''
|
||||
Method to override; called at the root when an event has been
|
||||
triggered, and on a child when the parent has notified. By
|
||||
default notifies all listeners.
|
||||
'''
|
||||
self.notify(value)
|
||||
def trigger(self, value=None):
|
||||
'''
|
||||
Triggers an event to be handled by the root. By default, the root
|
||||
handler will accept the event so all the listeners will be notified.
|
||||
'''
|
||||
if self.parent is not None:
|
||||
self.parent.trigger(value)
|
||||
else:
|
||||
self.handle(value)
|
||||
def set(self, value):
|
||||
'''
|
||||
Sets the parent Trigger. Child Triggers trigger events by
|
||||
triggering parents, and in turn they handle notifications
|
||||
that come from parents.
|
||||
'''
|
||||
if self.parent is not None:
|
||||
self.parent.off(self.handle)
|
||||
self.parent = None
|
||||
if isinstance(value, Trigger):
|
||||
ancestor = value.parent
|
||||
while ancestor is not None:
|
||||
if ancestor == self:
|
||||
raise ValueError('bound properties should not form a loop')
|
||||
ancestor = ancestor.parent
|
||||
self.parent = value
|
||||
self.parent.on(self.handle)
|
||||
elif not isinstance(self, Property):
|
||||
raise ValueError('only properties can be set to a value')
|
||||
def notify(self, value=None):
|
||||
'''
|
||||
Notifies listeners and children. If a listener accepts an argument,
|
||||
the value will be passed as a single argument.
|
||||
'''
|
||||
for cb in self._listeners:
|
||||
if len(signature(cb).parameters) == 0:
|
||||
cb() # no-parameter callback.
|
||||
else:
|
||||
cb(value)
|
||||
# TODO: consider adding a two-parameter callback form
|
||||
# where a detailed event object can be added.
|
||||
def on(self, cb):
|
||||
'''
|
||||
Registers a listener. Calling multiple times registers
|
||||
multiple listeners.
|
||||
'''
|
||||
self._listeners.append(cb)
|
||||
def off(self, cb):
|
||||
'''
|
||||
Unregisters a listener.
|
||||
'''
|
||||
self._listeners = [c for c in self._listeners if c != cb]
|
||||
|
||||
class Property(Trigger):
|
||||
"""
|
||||
A Property is just an Trigger that remembers its last value.
|
||||
"""
|
||||
def __init__(self, value=None):
|
||||
'''
|
||||
Can be initialized with a starting value.
|
||||
'''
|
||||
super().__init__()
|
||||
self.set(value)
|
||||
def handle(self, value):
|
||||
'''
|
||||
The default handling for a Property is to store the value,
|
||||
then notify listeners. This method can be overridden,
|
||||
for example to validate values.
|
||||
'''
|
||||
self.value = value
|
||||
self.notify(value)
|
||||
def set(self, value):
|
||||
'''
|
||||
When a Property value is set to an ordinary value, it
|
||||
triggers an event which causes a notification to be
|
||||
sent to update all linked Properties. A Property set
|
||||
to another Property becomes a child of the value.
|
||||
'''
|
||||
# Handle setting a parent Property
|
||||
if isinstance(value, Property):
|
||||
super().set(value)
|
||||
self.handle(value.value)
|
||||
elif isinstance(value, Trigger):
|
||||
raise ValueError('Cannot set a Property to an Trigger')
|
||||
else:
|
||||
self.trigger(value)
|
||||
|
||||
|
||||
##########################################################################
|
||||
## Specific widgets
|
||||
##########################################################################
|
||||
|
||||
class Button(Widget):
|
||||
def __init__(self, label='button'):
|
||||
super().__init__()
|
||||
self.click = Trigger()
|
||||
self.label = Property(label)
|
||||
def widget_js(self):
|
||||
return '''
|
||||
element.addEventListener('click', (e) => {
|
||||
model.trigger('click');
|
||||
})
|
||||
model.on('label', (v) => {
|
||||
element.value = v;
|
||||
})
|
||||
'''
|
||||
def widget_html(self):
|
||||
return f'''
|
||||
<input id="{self.view_id()}" type="button" style="display:block"
|
||||
value="{html.escape(str(self.label))}">
|
||||
'''
|
||||
|
||||
class Label(Widget):
|
||||
def __init__(self, value=''):
|
||||
super().__init__()
|
||||
# databinding is defined using Property objects.
|
||||
self.value = Property(value)
|
||||
|
||||
def widget_js(self):
|
||||
# Both "model" and "element" objects are defined within the scope
|
||||
# where the js is run. "element" looks for the element with id
|
||||
# self.view_id(); if widget_html is overridden, this id should be used.
|
||||
return '''
|
||||
model.on('value', (value) => {
|
||||
element.innerText = model.get('value');
|
||||
});
|
||||
'''
|
||||
def widget_html(self):
|
||||
return f'''
|
||||
<label id="{self.view_id()}">{html.escape(str(self.value))}</label>
|
||||
'''
|
||||
|
||||
class Textbox(Widget):
|
||||
def __init__(self, value='', size=20):
|
||||
super().__init__()
|
||||
# databinding is defined using Property objects.
|
||||
self.value = Property(value)
|
||||
self.size = Property(size)
|
||||
|
||||
def widget_js(self):
|
||||
# Both "model" and "element" objects are defined within the scope
|
||||
# where the js is run. "element" looks for the element with id
|
||||
# self.view_id(); if widget_html is overridden, this id should be used.
|
||||
return '''
|
||||
element.value = model.get('value');
|
||||
element.size = model.get('size');
|
||||
element.addEventListener('keydown', (e) => {
|
||||
if (e.code == 'Enter') {
|
||||
model.set('value', element.value);
|
||||
}
|
||||
});
|
||||
model.on('value', (value) => {
|
||||
element.value = model.get('value');
|
||||
});
|
||||
model.on('size', (value) => {
|
||||
element.size = model.get('size');
|
||||
});
|
||||
'''
|
||||
def widget_html(self):
|
||||
return f'''
|
||||
<input id="{self.view_id()}" style="display:block"
|
||||
value="{html.escape(str(self.value))}" size="{self.size}">
|
||||
'''
|
||||
|
||||
class Range(Widget):
|
||||
def __init__(self, value=50, min=0, max=100):
|
||||
super().__init__()
|
||||
# databinding is defined using Property objects.
|
||||
self.value = Property(value)
|
||||
self.min = Property(min)
|
||||
self.max = Property(max)
|
||||
|
||||
def widget_js(self):
|
||||
# Note that the 'input' event would enable during-drag feedback,
|
||||
# but this is pretty slow on google colab.
|
||||
return '''
|
||||
element.addEventListener('change', (e) => {
|
||||
model.set('value', element.value);
|
||||
});
|
||||
model.on('value', (value) => {
|
||||
if (!element.matches(':active')) {
|
||||
element.value = value;
|
||||
}
|
||||
})
|
||||
'''
|
||||
def widget_html(self):
|
||||
return f'''
|
||||
<input id="{self.view_id()}" type="range"
|
||||
value="{self.value}" min="{self.min}" max="{self.max}">
|
||||
'''
|
||||
|
||||
class Choice(Widget):
|
||||
def __init__(self, choices=None, selection=None, horizontal=False):
|
||||
super().__init__()
|
||||
if choices is None:
|
||||
choices = []
|
||||
self.choices = Property(choices)
|
||||
self.horizontal = Property(horizontal)
|
||||
self.selection = Property(selection)
|
||||
def widget_js(self):
|
||||
# Note that the 'input' event would enable during-drag feedback,
|
||||
# but this is pretty slow on google colab.
|
||||
return '''
|
||||
function esc(unsafe) {
|
||||
return unsafe.replace(/&/g, "&").replace(/</g, "<")
|
||||
.replace(/>/g, ">").replace(/"/g, """);
|
||||
}
|
||||
function render() {
|
||||
console.log('rendering');
|
||||
var lines = model.get('choices').map((c) => {
|
||||
return '<label><input type="radio" name="choice" value="' +
|
||||
esc(c) + '">' + esc(c) + '</label>'
|
||||
});
|
||||
element.innerHTML = lines.join(model.get('horizontal')?' ':'<br>');
|
||||
}
|
||||
model.on('choices horizontal', render);
|
||||
model.on('selection', (selection) => {
|
||||
[...element.querySelectorAll('input')].forEach((e) => {
|
||||
e.checked = (e.value == selection);
|
||||
})
|
||||
});
|
||||
element.addEventListener('change', (e) => {
|
||||
model.set('selection', element.choice.value);
|
||||
});
|
||||
'''
|
||||
def widget_html(self):
|
||||
radios = [
|
||||
f"""<label><input name="choice" type="radio" {
|
||||
'checked' if value == self.selection else ''
|
||||
} value="{html.escape(value)}">{html.escape(value)}</label>"""
|
||||
for value in self.choices ]
|
||||
sep = " " if self.horizontal else "<br>"
|
||||
return f'<form id="{self.view_id()}">{sep.join(radios)}</form>'
|
||||
|
||||
|
||||
class Div(Widget):
|
||||
def __init__(self, innerHTML='', style=None, data=None):
|
||||
super().__init__()
|
||||
style = {} if style is None else style
|
||||
data = {} if data is None else data
|
||||
# TODO: convert non-string innerHTML objects to html; unify
|
||||
# with the show() library.
|
||||
self.innerHTML = Property(innerHTML)
|
||||
self.style = Property(style)
|
||||
self.click = Trigger()
|
||||
|
||||
def print(self, text, replace=False):
|
||||
newHTML = '<pre>%s</pre>' % html.escape(str(text));
|
||||
if replace:
|
||||
self.innerHTML = newHTML
|
||||
else:
|
||||
self.innerHTML += newHTML
|
||||
|
||||
def widget_js(self):
|
||||
# Note that if we want innerHTML to support script execution,
|
||||
# we need to do it explicitly, like this.
|
||||
return '''
|
||||
function updater(attr) {
|
||||
return (val) => { for (k in val) { element[attr][k] = val[k]; } }
|
||||
}
|
||||
model.on('style', updater('style'));
|
||||
updater('style')();
|
||||
model.on('data', updater('dataset'));
|
||||
updater('dataset')();
|
||||
model.on('innerHTML', (innerHTML) => {
|
||||
element.innerHTML = innerHTML;
|
||||
Array.from(element.querySelectorAll("script")).forEach(old=>{
|
||||
const newScript = document.createElement("script");
|
||||
Array.from(old.attributes).forEach(attr =>
|
||||
newScript.setAttribute(attr.name, attr.value));
|
||||
newScript.appendChild(document.createTextNode(old.innerHTML));
|
||||
old.parentNode.replaceChild(newScript, old);
|
||||
});
|
||||
});
|
||||
'''
|
||||
def widget_html(self):
|
||||
return f'''
|
||||
<div id="{self.view_id()}">{self.innerHTML}</div>
|
||||
'''
|
||||
|
||||
class ClickDiv(Div):
|
||||
'''
|
||||
A Div that triggers click events when anything inside them is clicked.
|
||||
If a clicked element contains a data-click value, then that value is
|
||||
sent as the click event value.
|
||||
'''
|
||||
def __init__(self, innerHTML='', style=None, data=None):
|
||||
super().__init__(innertHTML, style, data)
|
||||
self.click = Trigger()
|
||||
|
||||
def widget_js(self):
|
||||
return super().widget_js() + '''
|
||||
element.addEventListener('click', (ev) => {
|
||||
var target = ev.target;
|
||||
while (target && target != element && !target.dataset.click) {
|
||||
target = target.parentElement;
|
||||
}
|
||||
var value = target.dataset.click;
|
||||
model.trigger('click', value);
|
||||
});
|
||||
'''
|
||||
|
||||
##########################################################################
|
||||
## Implementation Details
|
||||
##########################################################################
|
||||
|
||||
WIDGET_ENV = None
|
||||
if WIDGET_ENV is None:
|
||||
try:
|
||||
from google.colab import output as colab_output
|
||||
WIDGET_ENV = 'colab'
|
||||
except:
|
||||
pass
|
||||
if WIDGET_ENV is None:
|
||||
try:
|
||||
from ipykernel.comm import Comm as jupyter_comm
|
||||
COMM_MANAGER = get_ipython().kernel.comm_manager
|
||||
WIDGET_ENV = 'jupyter'
|
||||
except:
|
||||
pass
|
||||
|
||||
SEND_RECV_JS = """
|
||||
function recvFromPython(obj_id, fn) {
|
||||
var recvname = "recv_" + obj_id;
|
||||
if (window[recvname] === undefined) {
|
||||
window[recvname] = new BroadcastChannel("channel_" + obj_id);
|
||||
}
|
||||
window[recvname].addEventListener("message", (ev) => {
|
||||
if (ev.data == 'ok') {
|
||||
window[recvname].ok = true;
|
||||
return;
|
||||
}
|
||||
fn.apply(null, ev.data.slice(1));
|
||||
});
|
||||
}
|
||||
function sendToPython(obj_id, ...args) {
|
||||
google.colab.kernel.invokeFunction('invoke_' + obj_id, args, {})
|
||||
}
|
||||
""" if WIDGET_ENV == 'colab' else """
|
||||
function getChan(obj_id) {
|
||||
var cname = "comm_" + obj_id;
|
||||
if (!window[cname]) { window[cname] = []; }
|
||||
var chan = window[cname];
|
||||
if (!chan.comm && Jupyter.notebook.kernel) {
|
||||
chan.comm = Jupyter.notebook.kernel.comm_manager.new_comm(cname, {});
|
||||
chan.comm.on_msg((ev) => {
|
||||
if (chan.retry) { clearInterval(chan.retry); chan.retry = null; }
|
||||
if (ev.content.data == 'ok') { return; }
|
||||
var args = ev.content.data.slice(1);
|
||||
for (fn of chan) { fn.apply(null, args); }
|
||||
});
|
||||
chan.retry = setInterval(() => { chan.comm.open(); }, 2000);
|
||||
}
|
||||
return chan;
|
||||
}
|
||||
function recvFromPython(obj_id, fn) {
|
||||
getChan(obj_id).push(fn);
|
||||
}
|
||||
function sendToPython(obj_id, ...args) {
|
||||
var comm = getChan(obj_id).comm;
|
||||
if (comm) { comm.send(args); }
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
WIDGET_MODEL_JS = SEND_RECV_JS + """
|
||||
class Model {
|
||||
constructor(obj_id, init) {
|
||||
this._id = obj_id;
|
||||
this._listeners = {};
|
||||
this._data = Object.assign({}, init)
|
||||
recvFromPython(this._id, (name, value) => {
|
||||
this._data[name] = value;
|
||||
if (this._listeners.hasOwnProperty(name)) {
|
||||
this._listeners[name].forEach((fn) => { fn(value); });
|
||||
}
|
||||
})
|
||||
}
|
||||
trigger(name, value) {
|
||||
sendToPython(this._id, name, value);
|
||||
}
|
||||
get(name) {
|
||||
return this._data[name];
|
||||
}
|
||||
set(name, value) {
|
||||
this.trigger(name, value);
|
||||
}
|
||||
on(name, fn) {
|
||||
name.split(/\s+/).forEach((n) => {
|
||||
if (!this._listeners.hasOwnProperty(n)) {
|
||||
this._listeners[n] = [];
|
||||
}
|
||||
this._listeners[n].push(fn);
|
||||
});
|
||||
}
|
||||
off(name, fn) {
|
||||
name.split(/\s+/).forEach((n) => {
|
||||
if (!fn) {
|
||||
delete this._listeners[n];
|
||||
} else if (this._listeners.hasOwnProperty(n)) {
|
||||
this._listeners[n] = this._listeners[n].filter(
|
||||
(e) => { return e !== fn; });
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
"""
|
||||
@@ -0,0 +1,451 @@
|
||||
"""
|
||||
Utilities for instrumenting a torch model.
|
||||
|
||||
Trace will hook one layer at a time.
|
||||
TraceDict will hook multiple layers at once.
|
||||
subsequence slices intervals from Sequential modules.
|
||||
get_module, replace_module, get_parameter resolve dotted names.
|
||||
set_requires_grad recursively sets requires_grad in module parameters.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Trace(contextlib.AbstractContextManager):
|
||||
"""
|
||||
To retain the output of the named layer during the computation of
|
||||
the given network:
|
||||
|
||||
with Trace(net, 'layer.name') as ret:
|
||||
_ = net(inp)
|
||||
representation = ret.output
|
||||
|
||||
A layer module can be passed directly without a layer name, and
|
||||
its output will be retained. By default, a direct reference to
|
||||
the output object is returned, but options can control this:
|
||||
|
||||
clone=True - retains a copy of the output, which can be
|
||||
useful if you want to see the output before it might
|
||||
be modified by the network in-place later.
|
||||
detach=True - retains a detached reference or copy. (By
|
||||
default the value would be left attached to the graph.)
|
||||
retain_grad=True - request gradient to be retained on the
|
||||
output. After backward(), ret.output.grad is populated.
|
||||
|
||||
retain_input=True - also retains the input.
|
||||
retain_output=False - can disable retaining the output.
|
||||
edit_output=fn - calls the function to modify the output
|
||||
of the layer before passing it the rest of the model.
|
||||
fn can optionally accept (output, layer) arguments
|
||||
for the original output and the layer name.
|
||||
stop=True - throws a StopForward exception after the layer
|
||||
is run, which allows running just a portion of a model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
layer=None,
|
||||
retain_output=True,
|
||||
retain_input=False,
|
||||
clone=False,
|
||||
detach=False,
|
||||
retain_grad=False,
|
||||
edit_output=None,
|
||||
stop=False,
|
||||
):
|
||||
"""
|
||||
Method to replace a forward method with a closure that
|
||||
intercepts the call, and tracks the hook so that it can be reverted.
|
||||
"""
|
||||
retainer = self
|
||||
self.layer = layer
|
||||
if layer is not None:
|
||||
module = get_module(module, layer)
|
||||
|
||||
def retain_hook(m, inputs, output):
|
||||
if retain_input:
|
||||
retainer.input = recursive_copy(
|
||||
inputs[0] if len(inputs) == 1 else inputs,
|
||||
clone=clone,
|
||||
detach=detach,
|
||||
retain_grad=False,
|
||||
) # retain_grad applies to output only.
|
||||
if edit_output:
|
||||
output = invoke_with_optional_args(
|
||||
edit_output, output=output, layer=self.layer
|
||||
)
|
||||
if retain_output:
|
||||
retainer.output = recursive_copy(
|
||||
output, clone=clone, detach=detach, retain_grad=retain_grad
|
||||
)
|
||||
# When retain_grad is set, also insert a trivial
|
||||
# copy operation. That allows in-place operations
|
||||
# to follow without error.
|
||||
if retain_grad:
|
||||
output = recursive_copy(retainer.output, clone=True, detach=False)
|
||||
if stop:
|
||||
raise StopForward()
|
||||
return output
|
||||
|
||||
self.registered_hook = module.register_forward_hook(retain_hook)
|
||||
self.stop = stop
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
if self.stop and issubclass(type, StopForward):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
self.registered_hook.remove()
|
||||
|
||||
|
||||
class TraceDict(OrderedDict, contextlib.AbstractContextManager):
|
||||
"""
|
||||
To retain the output of multiple named layers during the computation
|
||||
of the given network:
|
||||
|
||||
with TraceDict(net, ['layer1.name1', 'layer2.name2']) as ret:
|
||||
_ = net(inp)
|
||||
representation = ret['layer1.name1'].output
|
||||
|
||||
If edit_output is provided, it should be a function that takes
|
||||
two arguments: output, and the layer name; and then it returns the
|
||||
modified output.
|
||||
|
||||
Other arguments are the same as Trace. If stop is True, then the
|
||||
execution of the network will be stopped after the last layer
|
||||
listed (even if it would not have been the last to be executed).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module,
|
||||
layers=None,
|
||||
retain_output=True,
|
||||
retain_input=False,
|
||||
clone=False,
|
||||
detach=False,
|
||||
retain_grad=False,
|
||||
edit_output=None,
|
||||
stop=False,
|
||||
):
|
||||
self.stop = stop
|
||||
|
||||
def flag_last_unseen(it):
|
||||
try:
|
||||
it = iter(it)
|
||||
prev = next(it)
|
||||
seen = set([prev])
|
||||
except StopIteration:
|
||||
return
|
||||
for item in it:
|
||||
if item not in seen:
|
||||
yield False, prev
|
||||
seen.add(item)
|
||||
prev = item
|
||||
yield True, prev
|
||||
|
||||
for is_last, layer in flag_last_unseen(layers):
|
||||
self[layer] = Trace(
|
||||
module=module,
|
||||
layer=layer,
|
||||
retain_output=retain_output,
|
||||
retain_input=retain_input,
|
||||
clone=clone,
|
||||
detach=detach,
|
||||
retain_grad=retain_grad,
|
||||
edit_output=edit_output,
|
||||
stop=stop and is_last,
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.close()
|
||||
if self.stop and issubclass(type, StopForward):
|
||||
return True
|
||||
|
||||
def close(self):
|
||||
for layer, trace in reversed(self.items()):
|
||||
trace.close()
|
||||
|
||||
|
||||
class StopForward(Exception):
|
||||
"""
|
||||
If the only output needed from running a network is the retained
|
||||
submodule then Trace(submodule, stop=True) will stop execution
|
||||
immediately after the retained submodule by raising the StopForward()
|
||||
exception. When Trace is used as context manager, it catches that
|
||||
exception and can be used as follows:
|
||||
|
||||
with Trace(net, layername, stop=True) as tr:
|
||||
net(inp) # Only runs the network up to layername
|
||||
print(tr.output)
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def recursive_copy(x, clone=None, detach=None, retain_grad=None):
|
||||
"""
|
||||
Copies a reference to a tensor, or an object that contains tensors,
|
||||
optionally detaching and cloning the tensor(s). If retain_grad is
|
||||
true, the original tensors are marked to have grads retained.
|
||||
"""
|
||||
if not clone and not detach and not retain_grad:
|
||||
return x
|
||||
if isinstance(x, torch.Tensor):
|
||||
if retain_grad:
|
||||
if not x.requires_grad:
|
||||
x.requires_grad = True
|
||||
x.retain_grad()
|
||||
elif detach:
|
||||
x = x.detach()
|
||||
if clone:
|
||||
x = x.clone()
|
||||
return x
|
||||
# Only dicts, lists, and tuples (and subclasses) can be copied.
|
||||
if isinstance(x, dict):
|
||||
return type(x)({k: recursive_copy(v) for k, v in x.items()})
|
||||
elif isinstance(x, (list, tuple)):
|
||||
return type(x)([recursive_copy(v) for v in x])
|
||||
else:
|
||||
assert False, f"Unknown type {type(x)} cannot be broken into tensors."
|
||||
|
||||
|
||||
def subsequence(
|
||||
sequential,
|
||||
first_layer=None,
|
||||
last_layer=None,
|
||||
after_layer=None,
|
||||
upto_layer=None,
|
||||
single_layer=None,
|
||||
share_weights=False,
|
||||
):
|
||||
"""
|
||||
Creates a subsequence of a pytorch Sequential model, copying over
|
||||
modules together with parameters for the subsequence. Only
|
||||
modules from first_layer to last_layer (inclusive) are included,
|
||||
or modules between after_layer and upto_layer (exclusive).
|
||||
Handles descent into dotted layer names as long as all references
|
||||
are within nested Sequential models.
|
||||
|
||||
If share_weights is True, then references the original modules
|
||||
and their parameters without copying them. Otherwise, by default,
|
||||
makes a separate brand-new copy.
|
||||
"""
|
||||
assert (single_layer is None) or (
|
||||
first_layer is last_layer is after_layer is upto_layer is None
|
||||
)
|
||||
if single_layer is not None:
|
||||
first_layer = single_layer
|
||||
last_layer = single_layer
|
||||
first, last, after, upto = [
|
||||
None if d is None else d.split(".")
|
||||
for d in [first_layer, last_layer, after_layer, upto_layer]
|
||||
]
|
||||
return hierarchical_subsequence(
|
||||
sequential,
|
||||
first=first,
|
||||
last=last,
|
||||
after=after,
|
||||
upto=upto,
|
||||
share_weights=share_weights,
|
||||
)
|
||||
|
||||
|
||||
def hierarchical_subsequence(
|
||||
sequential, first, last, after, upto, share_weights=False, depth=0
|
||||
):
|
||||
"""
|
||||
Recursive helper for subsequence() to support descent into dotted
|
||||
layer names. In this helper, first, last, after, and upto are
|
||||
arrays of names resulting from splitting on dots. Can only
|
||||
descend into nested Sequentials.
|
||||
"""
|
||||
assert (last is None) or (upto is None)
|
||||
assert (first is None) or (after is None)
|
||||
if first is last is after is upto is None:
|
||||
return sequential if share_weights else copy.deepcopy(sequential)
|
||||
assert isinstance(sequential, torch.nn.Sequential), (
|
||||
".".join((first or last or after or upto)[:depth] or "arg") + " not Sequential"
|
||||
)
|
||||
including_children = (first is None) and (after is None)
|
||||
included_children = OrderedDict()
|
||||
# A = current level short name of A.
|
||||
# AN = full name for recursive descent if not innermost.
|
||||
(F, FN), (L, LN), (A, AN), (U, UN) = [
|
||||
(d[depth], (None if len(d) == depth + 1 else d))
|
||||
if d is not None
|
||||
else (None, None)
|
||||
for d in [first, last, after, upto]
|
||||
]
|
||||
for name, layer in sequential._modules.items():
|
||||
if name == F:
|
||||
first = None
|
||||
including_children = True
|
||||
if name == A and AN is not None: # just like F if not a leaf.
|
||||
after = None
|
||||
including_children = True
|
||||
if name == U and UN is None:
|
||||
upto = None
|
||||
including_children = False
|
||||
if including_children:
|
||||
# AR = full name for recursive descent if name matches.
|
||||
FR, LR, AR, UR = [
|
||||
n if n is None or n[depth] == name else None for n in [FN, LN, AN, UN]
|
||||
]
|
||||
chosen = hierarchical_subsequence(
|
||||
layer,
|
||||
first=FR,
|
||||
last=LR,
|
||||
after=AR,
|
||||
upto=UR,
|
||||
share_weights=share_weights,
|
||||
depth=depth + 1,
|
||||
)
|
||||
if chosen is not None:
|
||||
included_children[name] = chosen
|
||||
if name == L:
|
||||
last = None
|
||||
including_children = False
|
||||
if name == U and UN is not None: # just like L if not a leaf.
|
||||
upto = None
|
||||
including_children = False
|
||||
if name == A and AN is None:
|
||||
after = None
|
||||
including_children = True
|
||||
for name in [first, last, after, upto]:
|
||||
if name is not None:
|
||||
raise ValueError("Layer %s not found" % ".".join(name))
|
||||
# Omit empty subsequences except at the outermost level,
|
||||
# where we should not return None.
|
||||
if not len(included_children) and depth > 0:
|
||||
return None
|
||||
result = torch.nn.Sequential(included_children)
|
||||
result.training = sequential.training
|
||||
return result
|
||||
|
||||
|
||||
def set_requires_grad(requires_grad, *models):
|
||||
"""
|
||||
Sets requires_grad true or false for all parameters within the
|
||||
models passed.
|
||||
"""
|
||||
for model in models:
|
||||
if isinstance(model, torch.nn.Module):
|
||||
for param in model.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
|
||||
model.requires_grad = requires_grad
|
||||
else:
|
||||
assert False, "unknown type %r" % type(model)
|
||||
|
||||
|
||||
def get_module(model, name):
|
||||
"""
|
||||
Finds the named module within the given model.
|
||||
"""
|
||||
for n, m in model.named_modules():
|
||||
if n == name:
|
||||
return m
|
||||
raise LookupError(name)
|
||||
|
||||
|
||||
def get_parameter(model, name):
|
||||
"""
|
||||
Finds the named parameter within the given model.
|
||||
"""
|
||||
for n, p in model.named_parameters():
|
||||
if n == name:
|
||||
return p
|
||||
raise LookupError(name)
|
||||
|
||||
|
||||
def replace_module(model, name, new_module):
|
||||
"""
|
||||
Replaces the named module within the given model.
|
||||
"""
|
||||
if "." in name:
|
||||
parent_name, attr_name = name.rsplit(".", 1)
|
||||
model = get_module(model, parent_name)
|
||||
# original_module = getattr(model, attr_name)
|
||||
setattr(model, attr_name, new_module)
|
||||
|
||||
|
||||
def invoke_with_optional_args(fn, *args, **kwargs):
|
||||
"""
|
||||
Invokes a function with only the arguments that it
|
||||
is written to accept, giving priority to arguments
|
||||
that match by-name, using the following rules.
|
||||
(1) arguments with matching names are passed by name.
|
||||
(2) remaining non-name-matched args are passed by order.
|
||||
(3) extra caller arguments that the function cannot
|
||||
accept are not passed.
|
||||
(4) extra required function arguments that the caller
|
||||
cannot provide cause a TypeError to be raised.
|
||||
Ordinary python calling conventions are helpful for
|
||||
supporting a function that might be revised to accept
|
||||
extra arguments in a newer version, without requiring the
|
||||
caller to pass those new arguments. This function helps
|
||||
support function callers that might be revised to supply
|
||||
extra arguments, without requiring the callee to accept
|
||||
those new arguments.
|
||||
"""
|
||||
argspec = inspect.getfullargspec(fn)
|
||||
pass_args = []
|
||||
used_kw = set()
|
||||
unmatched_pos = []
|
||||
used_pos = 0
|
||||
defaulted_pos = len(argspec.args) - (
|
||||
0 if not argspec.defaults else len(argspec.defaults)
|
||||
)
|
||||
# Pass positional args that match name first, then by position.
|
||||
for i, n in enumerate(argspec.args):
|
||||
if n in kwargs:
|
||||
pass_args.append(kwargs[n])
|
||||
used_kw.add(n)
|
||||
elif used_pos < len(args):
|
||||
pass_args.append(args[used_pos])
|
||||
used_pos += 1
|
||||
else:
|
||||
unmatched_pos.append(len(pass_args))
|
||||
pass_args.append(
|
||||
None if i < defaulted_pos else argspec.defaults[i - defaulted_pos]
|
||||
)
|
||||
# Fill unmatched positional args with unmatched keyword args in order.
|
||||
if len(unmatched_pos):
|
||||
for k, v in kwargs.items():
|
||||
if k in used_kw or k in argspec.kwonlyargs:
|
||||
continue
|
||||
pass_args[unmatched_pos[0]] = v
|
||||
used_kw.add(k)
|
||||
unmatched_pos = unmatched_pos[1:]
|
||||
if len(unmatched_pos) == 0:
|
||||
break
|
||||
else:
|
||||
if unmatched_pos[0] < defaulted_pos:
|
||||
unpassed = ", ".join(
|
||||
argspec.args[u] for u in unmatched_pos if u < defaulted_pos
|
||||
)
|
||||
raise TypeError(f"{fn.__name__}() cannot be passed {unpassed}.")
|
||||
# Pass remaining kw args if they can be accepted.
|
||||
pass_kw = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in used_kw and (k in argspec.kwonlyargs or argspec.varargs is not None)
|
||||
}
|
||||
# Pass remaining positional args if they can be accepted.
|
||||
if argspec.varargs is not None:
|
||||
pass_args += list(args[used_pos:])
|
||||
return fn(*pass_args, **pass_kw)
|
||||
@@ -0,0 +1,147 @@
|
||||
from .labwidget import Widget, Property
|
||||
import html
|
||||
|
||||
class PaintWidget(Widget):
|
||||
def __init__(self,
|
||||
width=256, height=256,
|
||||
image='', mask='', brushsize=10.0, oneshot=False, disabled=False):
|
||||
super().__init__()
|
||||
self.mask = Property(mask)
|
||||
self.image = Property(image)
|
||||
self.brushsize = Property(brushsize)
|
||||
self.erase = Property(False)
|
||||
self.oneshot = Property(oneshot)
|
||||
self.disabled = Property(disabled)
|
||||
self.width = Property(width)
|
||||
self.height = Property(height)
|
||||
|
||||
def widget_js(self):
|
||||
return f'''
|
||||
{PAINT_WIDGET_JS}
|
||||
var pw = new PaintWidget(element, model);
|
||||
'''
|
||||
def widget_html(self):
|
||||
v = self.view_id()
|
||||
return f'''
|
||||
<style>
|
||||
#{v} {{ position: relative; display: inline-block; }}
|
||||
#{v} .paintmask {{
|
||||
position: absolute; top:0; left: 0; z-index: 1;
|
||||
opacity: 0; transition: opacity .1s ease-in-out; }}
|
||||
#{v} .paintmask:hover {{ opacity: 0.7; }}
|
||||
</style>
|
||||
<div id="{v}"></div>
|
||||
'''
|
||||
|
||||
PAINT_WIDGET_JS = """
|
||||
class PaintWidget {
|
||||
constructor(el, model) {
|
||||
this.el = el;
|
||||
this.model = model;
|
||||
this.size_changed();
|
||||
this.model.on('mask', this.mask_changed.bind(this));
|
||||
this.model.on('image', this.image_changed.bind(this));
|
||||
this.model.on('width', this.size_changed.bind(this));
|
||||
this.model.on('height', this.size_changed.bind(this));
|
||||
}
|
||||
mouse_stroke(first_event) {
|
||||
var self = this;
|
||||
if (self.model.get('disabled')) { return; }
|
||||
if (self.model.get('oneshot')) {
|
||||
var canvas = self.mask_canvas;
|
||||
var ctx = canvas.getContext('2d');
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
}
|
||||
function track_mouse(evt) {
|
||||
if (evt.type == 'keydown' || self.model.get('disabled')) {
|
||||
if (self.model.get('disabled') || evt.key === "Escape") {
|
||||
window.removeEventListener('mousemove', track_mouse);
|
||||
window.removeEventListener('mouseup', track_mouse);
|
||||
window.removeEventListener('keydown', track_mouse, true);
|
||||
self.mask_changed();
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (evt.type == 'mouseup' ||
|
||||
(typeof evt.buttons != 'undefined' && evt.buttons == 0)) {
|
||||
window.removeEventListener('mousemove', track_mouse);
|
||||
window.removeEventListener('mouseup', track_mouse);
|
||||
window.removeEventListener('keydown', track_mouse, true);
|
||||
self.model.set('mask', self.mask_canvas.toDataURL());
|
||||
return;
|
||||
}
|
||||
var p = self.cursor_position();
|
||||
self.fill_circle(p.x, p.y,
|
||||
self.model.get('brushsize'),
|
||||
self.model.get('erase'));
|
||||
}
|
||||
this.mask_canvas.focus();
|
||||
window.addEventListener('mousemove', track_mouse);
|
||||
window.addEventListener('mouseup', track_mouse);
|
||||
window.addEventListener('keydown', track_mouse, true);
|
||||
track_mouse(first_event);
|
||||
}
|
||||
mask_changed(val) {
|
||||
this.draw_data_url(this.mask_canvas, this.model.get('mask'));
|
||||
}
|
||||
image_changed() {
|
||||
this.draw_data_url(this.image_canvas, this.model.get('image'));
|
||||
}
|
||||
size_changed() {
|
||||
this.mask_canvas = document.createElement('canvas');
|
||||
this.image_canvas = document.createElement('canvas');
|
||||
this.mask_canvas.className = "paintmask";
|
||||
this.image_canvas.className = "paintimage";
|
||||
for (var attr of ['width', 'height']) {
|
||||
this.mask_canvas[attr] = this.model.get(attr);
|
||||
this.image_canvas[attr] = this.model.get(attr);
|
||||
}
|
||||
|
||||
this.el.innerHTML = '';
|
||||
this.el.appendChild(this.image_canvas);
|
||||
this.el.appendChild(this.mask_canvas);
|
||||
this.mask_canvas.addEventListener('mousedown',
|
||||
this.mouse_stroke.bind(this));
|
||||
this.mask_changed();
|
||||
this.image_changed();
|
||||
}
|
||||
|
||||
cursor_position(evt) {
|
||||
const rect = this.mask_canvas.getBoundingClientRect();
|
||||
const x = event.clientX - rect.left;
|
||||
const y = event.clientY - rect.top;
|
||||
return {x: x, y: y};
|
||||
}
|
||||
|
||||
fill_circle(x, y, r, erase, blur) {
|
||||
var ctx = this.mask_canvas.getContext('2d');
|
||||
ctx.save();
|
||||
if (blur) {
|
||||
ctx.filter = 'blur(' + blur + 'px)';
|
||||
}
|
||||
ctx.globalCompositeOperation = (
|
||||
erase ? "destination-out" : 'source-over');
|
||||
ctx.fillStyle = '#fff';
|
||||
ctx.beginPath();
|
||||
ctx.arc(x, y, r, 0, 2 * Math.PI);
|
||||
ctx.fill();
|
||||
ctx.restore()
|
||||
}
|
||||
|
||||
draw_data_url(canvas, durl) {
|
||||
var ctx = canvas.getContext('2d');
|
||||
var img = new Image;
|
||||
canvas.pendingImg = img;
|
||||
function imgdone() {
|
||||
if (canvas.pendingImg == img) {
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
ctx.drawImage(img, 0, 0);
|
||||
canvas.pendingImg = null;
|
||||
}
|
||||
}
|
||||
img.addEventListener('load', imgdone);
|
||||
img.addEventListener('error', imgdone);
|
||||
img.src = durl;
|
||||
}
|
||||
}
|
||||
"""
|
||||
@@ -0,0 +1,212 @@
|
||||
'''
|
||||
Utilities for showing progress bars, controlling default verbosity, etc.
|
||||
'''
|
||||
|
||||
# If the tqdm package is not available, then do not show progress bars;
|
||||
# just connect print_progress to print.
|
||||
import sys
|
||||
import types
|
||||
import builtins
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
try:
|
||||
from tqdm.notebook import tqdm as tqdm_nb
|
||||
except:
|
||||
from tqdm import tqdm_notebook as tqdm_nb
|
||||
except:
|
||||
tqdm = None
|
||||
|
||||
default_verbosity = True
|
||||
next_description = None
|
||||
python_print = builtins.print
|
||||
|
||||
|
||||
def post(**kwargs):
|
||||
'''
|
||||
When within a progress loop, pbar.post(k=str) will display
|
||||
the given k=str status on the right-hand-side of the progress
|
||||
status bar. If not within a visible progress bar, does nothing.
|
||||
'''
|
||||
innermost = innermost_tqdm()
|
||||
if innermost is not None:
|
||||
innermost.set_postfix(**kwargs)
|
||||
|
||||
|
||||
def desc(desc):
|
||||
'''
|
||||
When within a progress loop, pbar.desc(str) changes the
|
||||
left-hand-side description of the loop toe the given description.
|
||||
'''
|
||||
innermost = innermost_tqdm()
|
||||
if innermost is not None:
|
||||
innermost.set_description(str(desc))
|
||||
|
||||
|
||||
def descnext(desc):
|
||||
'''
|
||||
Called before starting a progress loop, pbar.descnext(str)
|
||||
sets the description text that will be used in the following loop.
|
||||
'''
|
||||
global next_description
|
||||
if not default_verbosity or tqdm is None:
|
||||
return
|
||||
next_description = desc
|
||||
|
||||
|
||||
def print(*args):
|
||||
'''
|
||||
When within a progress loop, will print above the progress loop.
|
||||
'''
|
||||
global next_description
|
||||
next_description = None
|
||||
if default_verbosity:
|
||||
msg = ' '.join(str(s) for s in args)
|
||||
if tqdm is None:
|
||||
python_print(msg)
|
||||
else:
|
||||
tqdm.write(msg)
|
||||
|
||||
|
||||
def tqdm_terminal(it, *args, **kwargs):
|
||||
'''
|
||||
Some settings for tqdm that make it run better in resizable terminals.
|
||||
'''
|
||||
return tqdm(it, *args, dynamic_ncols=True, ascii=True,
|
||||
leave=(innermost_tqdm() is not None), **kwargs)
|
||||
|
||||
|
||||
def in_notebook():
|
||||
'''
|
||||
True if running inside a Jupyter notebook.
|
||||
'''
|
||||
# From https://stackoverflow.com/a/39662359/265298
|
||||
try:
|
||||
shell = get_ipython().__class__.__name__
|
||||
if shell == 'ZMQInteractiveShell':
|
||||
return True # Jupyter notebook or qtconsole
|
||||
elif shell == 'TerminalInteractiveShell':
|
||||
return False # Terminal running IPython
|
||||
else:
|
||||
return False # Other type (?)
|
||||
except NameError:
|
||||
return False # Probably standard Python interpreter
|
||||
|
||||
|
||||
def innermost_tqdm():
|
||||
'''
|
||||
Returns the innermost active tqdm progress loop on the stack.
|
||||
'''
|
||||
if hasattr(tqdm, '_instances') and len(tqdm._instances) > 0:
|
||||
return max(tqdm._instances, key=lambda x: x.pos)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def reporthook(*args, **kwargs):
|
||||
'''
|
||||
For use with urllib.request.urlretrieve.
|
||||
|
||||
with pbar.reporthook() as hook:
|
||||
urllib.request.urlretrieve(url, filename, reporthook=hook)
|
||||
'''
|
||||
kwargs2 = dict(unit_scale=True, miniters=1)
|
||||
kwargs2.update(kwargs)
|
||||
bar = __call__(None, *args, **kwargs2)
|
||||
|
||||
class ReportHook(object):
|
||||
def __init__(self, t):
|
||||
self.t = t
|
||||
|
||||
def __call__(self, b=1, bsize=1, tsize=None):
|
||||
if hasattr(self.t, 'total'):
|
||||
if tsize is not None:
|
||||
self.t.total = tsize
|
||||
if hasattr(self.t, 'update'):
|
||||
self.t.update(b * bsize - self.t.n)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
if hasattr(self.t, '__exit__'):
|
||||
self.t.__exit__(*exc)
|
||||
return ReportHook(bar)
|
||||
|
||||
|
||||
def __call__(x, *args, **kwargs):
|
||||
'''
|
||||
Invokes a progress function that can wrap iterators to print
|
||||
progress messages, if verbose is True.
|
||||
|
||||
If verbose is False or tqdm is unavailable, then a quiet
|
||||
non-printing identity function is used.
|
||||
|
||||
verbose can also be set to a spefific progress function rather
|
||||
than True, and that function will be used.
|
||||
'''
|
||||
global default_verbosity, next_description
|
||||
if not default_verbosity or tqdm is None:
|
||||
return x
|
||||
if default_verbosity == True:
|
||||
fn = tqdm_nb if in_notebook() else tqdm_terminal
|
||||
else:
|
||||
fn = default_verbosity
|
||||
if next_description is not None:
|
||||
kwargs = dict(kwargs)
|
||||
kwargs['desc'] = next_description
|
||||
next_description = None
|
||||
return fn(x, *args, **kwargs)
|
||||
|
||||
|
||||
class VerboseContextManager():
|
||||
def __init__(self, v, entered=False):
|
||||
self.v, self.entered, self.saved = v, False, []
|
||||
if entered:
|
||||
self.__enter__()
|
||||
self.entered = True
|
||||
|
||||
def __enter__(self):
|
||||
global default_verbosity
|
||||
if self.entered:
|
||||
self.entered = False
|
||||
else:
|
||||
self.saved.append(default_verbosity)
|
||||
default_verbosity = self.v
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
global default_verbosity
|
||||
default_verbosity = self.saved.pop()
|
||||
|
||||
def __call__(self, v=True):
|
||||
'''
|
||||
Calling the context manager makes a new context that is
|
||||
pre-entered, so it works as both a plain function and as a
|
||||
factory for a context manager.
|
||||
'''
|
||||
new_v = v if self.v else not v
|
||||
cm = VerboseContextManager(new_v, entered=True)
|
||||
default_verbosity = new_v
|
||||
return cm
|
||||
|
||||
|
||||
# Use as either "with pbar.verbose:" or "pbar.verbose(False)", or also
|
||||
# "with pbar.verbose(False):"
|
||||
verbose = VerboseContextManager(True)
|
||||
|
||||
# Use as either "with @pbar.quiet" or "pbar.quiet(True)". or also
|
||||
# "with pbar.quiet(True):"
|
||||
quiet = VerboseContextManager(False)
|
||||
|
||||
|
||||
class CallableModule(types.ModuleType):
|
||||
def __init__(self):
|
||||
# or super().__init__(__name__) for Python 3
|
||||
types.ModuleType.__init__(self, __name__)
|
||||
self.__dict__.update(sys.modules[__name__].__dict__)
|
||||
|
||||
def __call__(self, x, *args, **kwargs):
|
||||
return __call__(x, *args, **kwargs)
|
||||
|
||||
|
||||
sys.modules[__name__] = CallableModule()
|
||||
@@ -0,0 +1,129 @@
|
||||
'''
|
||||
Utility for simple distribution of work on multiple processes, by
|
||||
making sure only one process is working on a job at once.
|
||||
'''
|
||||
|
||||
import os
|
||||
import errno
|
||||
import socket
|
||||
import atexit
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
def reserve_dir(*args):
|
||||
'''
|
||||
Convenience function to get exclusive access to an unfinished
|
||||
experiment directory. Exits the program if the directory is
|
||||
already done or busy (using exit_of_job_done). Otherwise,
|
||||
returns a function creates filenames within that directory.
|
||||
'''
|
||||
directory = os.path.join(*args)
|
||||
exit_if_job_done(directory)
|
||||
|
||||
def dirfn(*fn):
|
||||
return os.path.join(directory, *fn)
|
||||
dirfn.dir = directory
|
||||
|
||||
def done():
|
||||
mark_job_done(directory)
|
||||
dirfn.done = done
|
||||
print('Working in %s' % directory)
|
||||
return dirfn
|
||||
|
||||
|
||||
# Old function name.
|
||||
exclusive_dirfn = reserve_dir
|
||||
|
||||
|
||||
def exit_if_job_done(directory, redo=False, force=False, verbose=True):
|
||||
if pidfile_taken(os.path.join(directory, 'lockfile.pid'),
|
||||
force=force, verbose=verbose):
|
||||
sys.exit(0)
|
||||
donefile = os.path.join(directory, 'done.txt')
|
||||
if os.path.isfile(donefile):
|
||||
with open(donefile) as f:
|
||||
msg = f.read()
|
||||
if redo or force:
|
||||
if verbose:
|
||||
print('Removing %s %s' % (donefile, msg))
|
||||
os.remove(donefile)
|
||||
else:
|
||||
if verbose:
|
||||
print('%s %s' % (donefile, msg))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def mark_job_done(directory):
|
||||
with open(os.path.join(directory, 'done.txt'), 'w') as f:
|
||||
f.write('done by %d@%s %s at %s' %
|
||||
(os.getpid(), socket.gethostname(),
|
||||
os.getenv('STY', ''),
|
||||
time.strftime('%c')))
|
||||
|
||||
|
||||
def pidfile_taken(path, verbose=False, force=False):
|
||||
'''
|
||||
Usage. To grab an exclusive lock for the remaining duration of the
|
||||
current process (and exit if another process already has the lock),
|
||||
do this:
|
||||
|
||||
if pidfile_taken('job_423/lockfile.pid', verbose=True):
|
||||
sys.exit(0)
|
||||
|
||||
To do a batch of jobs, just run a script that does them all on
|
||||
each available machine, sharing a network filesystem. When each
|
||||
job grabs a lock, then this will automatically distribute the
|
||||
jobs so that each one is done just once on one machine.
|
||||
'''
|
||||
|
||||
# Try to create the file exclusively and write my pid into it.
|
||||
try:
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
fd = os.open(path, os.O_CREAT | os.O_EXCL | os.O_RDWR)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EEXIST:
|
||||
# If we cannot because there was a race, yield the conflicter.
|
||||
conflicter = 'race'
|
||||
try:
|
||||
with open(path, 'r') as lockfile:
|
||||
conflicter = lockfile.read().strip() or 'empty'
|
||||
except:
|
||||
pass
|
||||
# Force is for manual one-time use, for deleting stale lockfiles.
|
||||
if force:
|
||||
if verbose:
|
||||
print('Removing %s from %s' % (path, conflicter))
|
||||
os.remove(path)
|
||||
return pidfile_taken(path, verbose=verbose, force=False)
|
||||
if verbose:
|
||||
print('%s held by %s' % (path, conflicter))
|
||||
return conflicter
|
||||
else:
|
||||
# Other problems get an exception.
|
||||
raise
|
||||
# Register to delete this file on exit.
|
||||
lockfile = os.fdopen(fd, 'r+')
|
||||
atexit.register(delete_pidfile, lockfile, path)
|
||||
# Write my pid into the open file.
|
||||
lockfile.write('%d@%s %s\n' % (os.getpid(), socket.gethostname(),
|
||||
os.getenv('STY', '')))
|
||||
lockfile.flush()
|
||||
os.fsync(lockfile)
|
||||
# Return 'None' to say there was not a conflict.
|
||||
return None
|
||||
|
||||
|
||||
def delete_pidfile(lockfile, path):
|
||||
'''
|
||||
Runs at exit after pidfile_taken succeeds.
|
||||
'''
|
||||
if lockfile is not None:
|
||||
try:
|
||||
lockfile.close()
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
os.unlink(path)
|
||||
except:
|
||||
pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,138 @@
|
||||
# show.py
|
||||
#
|
||||
# An abbreviated way to output simple HTML layout of text and images
|
||||
# into a python notebook.
|
||||
#
|
||||
# - show a PIL image to show an inline HTML <img>.
|
||||
# - show an array of items to vertically stack them, centered in a block.
|
||||
# - show an array of arrays to horizontally lay them out as inline blocks.
|
||||
# - show an array of tuples to create a table.
|
||||
|
||||
import PIL.Image, base64, io, IPython, types, sys
|
||||
import html as html_module
|
||||
from IPython.display import display
|
||||
|
||||
g_buffer = None
|
||||
|
||||
def blocks(obj, space=''):
|
||||
return IPython.display.HTML(space.join(blocks_tags(obj)))
|
||||
|
||||
def rows(obj, space=''):
|
||||
return IPython.display.HTML(space.join(rows_tags(obj)))
|
||||
|
||||
def rows_tags(obj):
|
||||
if isinstance(obj, dict):
|
||||
obj = obj.items()
|
||||
results = []
|
||||
results.append('<table style="display:inline-table">')
|
||||
for row in obj:
|
||||
results.append('<tr style="padding:0">')
|
||||
for item in row:
|
||||
results.append('<td style="text-align:left; vertical-align:top;' +
|
||||
'padding:1px">')
|
||||
results.extend(blocks_tags(item))
|
||||
results.append('</td>')
|
||||
results.append('</tr>')
|
||||
results.append('</table>')
|
||||
return results
|
||||
|
||||
def blocks_tags(obj):
|
||||
results = []
|
||||
if hasattr(obj, '_repr_html_'):
|
||||
results.append(obj._repr_html_())
|
||||
elif isinstance(obj, PIL.Image.Image):
|
||||
results.append(pil_to_html(obj))
|
||||
elif isinstance(obj, (str, int, float)):
|
||||
results.append('<div>')
|
||||
results.append(html_module.escape(str(obj)))
|
||||
results.append('</div>')
|
||||
elif isinstance(obj, dict):
|
||||
results.extend(blocks_tags([(k, v) for k, v in obj.items()]))
|
||||
elif hasattr(obj, '__iter__'):
|
||||
blockstart, blockend, tstart, tend, rstart, rend, cstart, cend = [
|
||||
'<div style="display:inline-block;text-align:center;line-height:1;' +
|
||||
'vertical-align:top;padding:1px">',
|
||||
'</div>',
|
||||
'<table style="display:inline-table">',
|
||||
'</table>',
|
||||
'<tr style="padding:0">',
|
||||
'</tr>',
|
||||
'<td style="text-align:left; vertical-align:top; padding:1px">',
|
||||
'</td>',
|
||||
]
|
||||
needs_end = False
|
||||
table_mode = False
|
||||
for i, line in enumerate(obj):
|
||||
if i == 0:
|
||||
needs_end = True
|
||||
if isinstance(line, tuple):
|
||||
table_mode = True
|
||||
results.append(tstart)
|
||||
else:
|
||||
results.append(blockstart)
|
||||
if table_mode:
|
||||
results.append(rstart)
|
||||
if not isinstance(line, str) and hasattr(line, '__iter__'):
|
||||
for cell in line:
|
||||
results.append(cstart)
|
||||
results.extend(blocks_tags(cell))
|
||||
results.append(cend)
|
||||
else:
|
||||
results.append(cstart)
|
||||
results.extend(blocks_tags(line))
|
||||
results.append(cend)
|
||||
results.append(rend)
|
||||
else:
|
||||
results.extend(blocks_tags(line))
|
||||
if needs_end:
|
||||
results.append(table_mode and tend or blockend)
|
||||
return results
|
||||
|
||||
def pil_to_b64(img, format='png'):
|
||||
buffered = io.BytesIO()
|
||||
img.save(buffered, format=format)
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
def pil_to_url(img, format='png'):
|
||||
return 'data:image/%s;base64,%s' % (format, pil_to_b64(img, format))
|
||||
|
||||
def pil_to_html(img, margin=1):
|
||||
mattr = ' style="margin:%dpx"' % margin
|
||||
return '<img src="%s"%s>' % (pil_to_url(img), mattr)
|
||||
|
||||
def a(x, cols=None):
|
||||
global g_buffer
|
||||
if g_buffer is None:
|
||||
g_buffer = []
|
||||
g_buffer.append(x)
|
||||
if cols is not None and len(g_buffer) >= cols:
|
||||
flush()
|
||||
|
||||
def reset():
|
||||
global g_buffer
|
||||
g_buffer = None
|
||||
|
||||
def flush(*args, **kwargs):
|
||||
global g_buffer
|
||||
if g_buffer is not None:
|
||||
x = g_buffer
|
||||
g_buffer = None
|
||||
display(blocks(x, *args, **kwargs))
|
||||
|
||||
def show(x=None, *args, **kwargs):
|
||||
flush(*args, **kwargs)
|
||||
if x is not None:
|
||||
display(blocks(x, *args, **kwargs))
|
||||
|
||||
def html(obj, space=''):
|
||||
return blocks(obj, space)._repr_html_()
|
||||
|
||||
class CallableModule(types.ModuleType):
|
||||
def __init__(self):
|
||||
# or super().__init__(__name__) for Python 3
|
||||
types.ModuleType.__init__(self, __name__)
|
||||
self.__dict__.update(sys.modules[__name__].__dict__)
|
||||
def __call__(self, x=None, *args, **kwargs):
|
||||
show(x, *args, **kwargs)
|
||||
|
||||
sys.modules[__name__] = CallableModule()
|
||||
Reference in New Issue
Block a user