Support numpy.ndarray and torch.tensor properties.

This commit is contained in:
David Bau
2022-03-28 20:46:28 -04:00
parent b7140c4d96
commit c12f12499e
+21 -3
View File
@@ -48,7 +48,7 @@ import io
import json
import html
import re
from inspect import signature
from inspect import signature, getmro
from . import show
@@ -237,7 +237,7 @@ class Widget(Model):
Returns the HTML code for the widget.
'''
self._viewcount += 1
json_data = json.dumps({
json_data = jsondump({
k: v.value for k, v in vars(self).items()
if isinstance(v, Property)})
json_data = re.sub('</', '<\\/', json_data)
@@ -285,7 +285,7 @@ class Widget(Model):
colab_output.eval_js(minify(f"""
(window.send_{id(self)} = window.send_{id(self)} ||
new BroadcastChannel("channel_{id(self)}")
).postMessage({json.dumps(args)});
).postMessage({jsondump(args)});
"""), ignore_result=True)
elif WIDGET_ENV == 'jupyter':
if not self._comms:
@@ -1020,6 +1020,24 @@ class Image(Widget):
# Utils
##########################################################################
def baseclass_named(obj, *class_names):
'''
Detects if obj is a subclass of a class named clsname, without requiring import
of the class.
'''
for x in getmro(type(obj)):
if (x.__module__ + '.' + x.__name__) in class_names:
return True
return False
class PermissiveEncoder(json.JSONEncoder):
def default(self, obj):
if baseclass_named(obj, 'numpy.ndarray', 'torch.Tensor'):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def jsondump(d):
return json.dumps(d, cls=PermissiveEncoder)
def minify(t):
# TODO: plug in some more real minification.