diff --git a/torchkit/README.md b/torchkit/README.md new file mode 100644 index 0000000..7ac7575 --- /dev/null +++ b/torchkit/README.md @@ -0,0 +1,3 @@ +# David's torchkit + +Some utilities useful for prototyping with pytorch. diff --git a/torchkit/plotwidget.py b/torchkit/plotwidget.py new file mode 100644 index 0000000..62a0701 --- /dev/null +++ b/torchkit/plotwidget.py @@ -0,0 +1,60 @@ +from .labwidget import ImageWidget, Property +import matplotlib, matplotlib.pyplot +import inspect + +class PlotWidget(ImageWidget): + """ + A widget to create interactive matplotlib plots by defining a simple function. + Example of usage: + + ``` + import numpy + def simple_redraw_rule(fig, amp=1.0, freq=1.0): + [ax] = fig.axes + ax.clear() + x = numpy.linspace(0, 5, 100) + ax.plot(x, amp * numpy.sin(freq * x)) + + plot = PlotWidget(simple_redraw_rule) + display(plot) + ``` + + The keyword arguments in the provided function will become properties + of the plot widget; updating those properties will redraw the plot in-place. + For example, in the above, assigning `plot.freq = 3` will redraw the + plot with freq set to 3. + """ + def __init__(self, redraw_rule, **kwargs): + super().__init__() + init_args = dict(kwargs) + + all_names = [] + for i, (name, p) in enumerate(inspect.signature(redraw_rule).parameters.items()): + if i == 0: + assert p.default == inspect._empty, 'First arg of redraw rule should be the figure' + else: + if name in kwargs: + default = init_args.pop(name) + else: + assert p.default != inspect._empty, 'Arguments must have default values' + default = p.default + setattr(self, name, Property(default)) + all_names.append(name) + + old_backend = matplotlib.get_backend() + matplotlib.use('agg') + if 'mosaic' in init_args: + self.fig, _ = matplotlib.pyplot.subplot_mosaic(**init_args) + else: + self.fig, _ = matplotlib.pyplot.subplots(**init_args) + matplotlib.use(old_backend) + + def invoke_redraw(): + args = [self.fig] + for name in all_names: + args.append(getattr(self, name)) + redraw_rule(*args) + self.render(self.fig) + self.on(' '.join(all_names), invoke_redraw) + invoke_redraw() +