mirror of
https://github.com/wassname/baukit.git
synced 2026-06-27 19:46:31 +08:00
Support numpy.ndarray and torch.tensor properties.
This commit is contained in:
+21
-3
@@ -48,7 +48,7 @@ import io
|
||||
import json
|
||||
import html
|
||||
import re
|
||||
from inspect import signature
|
||||
from inspect import signature, getmro
|
||||
from . import show
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ class Widget(Model):
|
||||
Returns the HTML code for the widget.
|
||||
'''
|
||||
self._viewcount += 1
|
||||
json_data = json.dumps({
|
||||
json_data = jsondump({
|
||||
k: v.value for k, v in vars(self).items()
|
||||
if isinstance(v, Property)})
|
||||
json_data = re.sub('</', '<\\/', json_data)
|
||||
@@ -285,7 +285,7 @@ class Widget(Model):
|
||||
colab_output.eval_js(minify(f"""
|
||||
(window.send_{id(self)} = window.send_{id(self)} ||
|
||||
new BroadcastChannel("channel_{id(self)}")
|
||||
).postMessage({json.dumps(args)});
|
||||
).postMessage({jsondump(args)});
|
||||
"""), ignore_result=True)
|
||||
elif WIDGET_ENV == 'jupyter':
|
||||
if not self._comms:
|
||||
@@ -1020,6 +1020,24 @@ class Image(Widget):
|
||||
# Utils
|
||||
##########################################################################
|
||||
|
||||
def baseclass_named(obj, *class_names):
|
||||
'''
|
||||
Detects if obj is a subclass of a class named clsname, without requiring import
|
||||
of the class.
|
||||
'''
|
||||
for x in getmro(type(obj)):
|
||||
if (x.__module__ + '.' + x.__name__) in class_names:
|
||||
return True
|
||||
return False
|
||||
|
||||
class PermissiveEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if baseclass_named(obj, 'numpy.ndarray', 'torch.Tensor'):
|
||||
return obj.tolist()
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
def jsondump(d):
|
||||
return json.dumps(d, cls=PermissiveEncoder)
|
||||
|
||||
def minify(t):
|
||||
# TODO: plug in some more real minification.
|
||||
|
||||
Reference in New Issue
Block a user