Files
sloth/sloth/annotations/model.py
T
Martin Baeuml 2ca1f84156 pep8 fixes
2014-03-03 10:24:17 +01:00

814 lines
27 KiB
Python

"""
The annotationmodel module contains the classes for the AnnotationModel.
"""
import os.path
import time
import logging
import copy
from collections import MutableMapping
from PyQt4.QtGui import QTreeView, QItemSelection, QItemSelectionModel, QSortFilterProxyModel, QBrush
from PyQt4.QtCore import QModelIndex, QAbstractItemModel, Qt, pyqtSignal, QVariant
LOG = logging.getLogger(__name__)
ItemRole, DataRole, ImageRole = [Qt.UserRole + i + 1 for i in range(3)]
class ModelItem:
def __init__(self):
self._loaded = True
self._model = None
self._parent = None
self._row = -1
if not hasattr(self, "_children"):
self._children = []
def _load(self, index):
pass
def _ensureLoaded(self, index):
if not self._loaded:
if not isinstance(self._children[index], ModelItem):
# Need to set the before actually loading to avoid
# endless recursion...
self._load(index)
return True
return False
def _ensureAllLoaded(self):
if not self._loaded:
for i in range(len(self._children)):
self._ensureLoaded(i)
return True
return False
def hasChildren(self):
return len(self._children) > 0
def childHasChildren(self, row):
return self.childAt(row).hasChildren()
def row(self):
return self._row
def rowCount(self):
return len(self._children)
def childRowCount(self, pos):
return self.childAt(pos).rowCount()
def children(self):
self._ensureAllLoaded()
return self._children
def model(self):
return self._model
def parent(self):
return self._parent
def data(self, role=Qt.DisplayRole, column=0):
if role == Qt.DisplayRole:
return ""
if role == Qt.BackgroundRole:
c = self.getColor()
if c is not None:
return QBrush(c)
if role == ItemRole:
return self
else:
return None
def childData(self, role=Qt.DisplayRole, row=0, column=0):
return self.childAt(row).data(role, column)
def flags(self, column):
return Qt.ItemIsEnabled | Qt.ItemIsSelectable
def childFlags(self, row, column):
return self.childAt(row).flags(column)
def setData(self, value, role=Qt.DisplayRole, column=0):
return False
def childAt(self, pos):
self._ensureLoaded(pos)
return self._children[pos]
def getPreviousSibling(self, step=1):
return self.getSibling(self._row-step)
def getNextSibling(self, step=1):
return self.getSibling(self._row+step)
def getSibling(self, row):
if self._parent is not None:
try:
return self._parent.childAt(row)
except:
pass
return None
def _attachToModel(self, model):
#assert self.model() is None
#assert self.parent() is not None
#assert self.parent().model() is not None
self._model = model
for item in self._children:
if isinstance(item, ModelItem):
item._attachToModel(model)
def index(self, column=0):
if self._parent is None:
return QModelIndex()
if column >= self._model.columnCount():
return QModelIndex()
return self._model.createIndex(self._row, column, self._parent)
def addChildSorted(self, item, signalModel=True):
self.insertChild(-1, item, signalModel=signalModel)
def appendChild(self, item, signalModel=True):
self.insertChild(-1, item, signalModel=signalModel)
def replaceChild(self, pos, item):
item._parent = self
item._row = pos
self._children[pos] = item
if self._model is not None:
self._children[pos]._attachToModel(self._model)
def insertChild(self, pos, item, signalModel=True):
if pos >= 0:
next_row = pos
else:
next_row = len(self._children)
if self._model is not None and signalModel:
self._model.beginInsertRows(self.index(), next_row, next_row)
item._parent = self
item._row = next_row
self._children.insert(next_row, item)
if pos >= 0:
for i, c in enumerate(self._children):
c._row = i
if self._model is not None:
item._attachToModel(self._model)
if signalModel:
self._model.endInsertRows()
def appendChildren(self, items, signalModel=True):
#for item in items:
#assert isinstance(item, ModelItem)
#assert item.model() is None
#assert item.parent() is None
next_row = len(self._children)
if self._model is not None and signalModel:
self._model.beginInsertRows(self.index(), next_row, next_row + len(items) - 1)
for i, item in enumerate(items):
item._parent = self
item._row = next_row + i
self._children.append(item)
if self._model is not None:
for item in items:
item._attachToModel(self._model)
if signalModel:
self._model.endInsertRows()
def delete(self):
if self._parent is None:
raise RuntimeError("Trying to delete orphan")
else:
self._parent.deleteChild(self)
def deleteChild(self, arg):
# Grandchildren are considered deleted automatically
if isinstance(arg, ModelItem):
return self.deleteChild(self._children.index(arg))
else:
if arg < 0 or arg >= len(self._children):
raise IndexError("child index out of range")
self._ensureLoaded(arg)
if self._model is not None:
self._model.beginRemoveRows(self.index(), arg, arg)
del self._children[arg]
# Update cached row numbers
for i, c in enumerate(self._children):
c._row = i
if self._model is not None:
self._model.endRemoveRows()
def deleteAllChildren(self):
self._ensureAllLoaded()
if self._model is not None:
self._model.beginRemoveRows(self.index(), 0, len(self._children) - 1)
self._children = []
if self._model is not None:
self._model.endRemoveRows()
def getColor(self):
return None
class RootModelItem(ModelItem):
def __init__(self, model, files):
ModelItem.__init__(self)
self._model = model
self._toload = []
for f in files:
self._toload.append(f)
self._children.append(f)
self._loaded = False
def _load(self, index):
self._toload.remove(self._children[index])
fi = FileModelItem.create(self._children[index])
self.replaceChild(index, fi)
if len(self._toload) == 0:
self._loaded = True
def childHasChildren(self, pos):
if isinstance(self._children[pos], ModelItem):
return self._children[pos].hasChildren()
else:
# Hack to speed things up...
return True
def childFlags(self, row, column):
if isinstance(self._children[row], ModelItem):
return self._children[row].flags(column)
else:
return Qt.ItemIsEnabled | Qt.ItemIsSelectable
def appendChild(self, item):
if isinstance(item, FileModelItem):
ModelItem.appendChild(self, item)
else:
raise TypeError("Only FileModelItems can be attached to RootModelItem")
def appendFileItem(self, fileinfo):
item = FileModelItem.create(fileinfo)
LOG.debug("created FileModelItem of type %s" % item.__class__.__name__)
self.appendChild(item)
return item
def appendFileItems(self, fileinfos):
start1 = time.time()
items = [FileModelItem.create(fi) for fi in fileinfos]
diff1 = time.time() - start1
start2 = time.time()
self.appendChildren(items)
diff2 = time.time() - start2
LOG.debug("Creation of ModelItems: %.2fs, addition to model: %.2fs" % (diff1, diff2))
def numFiles(self):
return len(self.children())
def numAnnotations(self):
count = 0
for ann in self._model.iterator(AnnotationModelItem):
count += 1
return count
def getAnnotations(self):
return [child.getAnnotations() for child in self.children()
if hasattr(child, 'getAnnotations')]
class KeyValueModelItem(ModelItem, MutableMapping):
def __init__(self, hidden=[], properties=None):
ModelItem.__init__(self)
self._dict = {}
self._items = {}
self._hidden = hidden + [None, 'class', 'unlabeled', 'unconfirmed']
# dummy key/value so that pyqt does not convert the dict
# into a QVariantMap while communicating with the Views
self._dict[None] = None
if properties is not None:
self._dict.update(properties)
items_to_add = []
for key in self._dict.keys():
if key not in self._hidden:
item = KeyValueRowModelItem(key)
self._items[key] = item
items_to_add.append(item)
items_to_add.sort(key=lambda x: x.key())
self.appendChildren(items_to_add, False)
def addChildSorted(self, item, signalModel=True):
if isinstance(item, KeyValueRowModelItem):
next_row = 0
for child in self._children:
if not isinstance(child, KeyValueRowModelItem) or child.key() > item.key():
break
next_row += 1
self.insertChild(next_row, item, signalModel)
else:
self.appendChild(item, signalModel)
# Methods for MutableMapping
def __len__(self):
return len(self._dict)
def __iter__(self):
return iter(self._dict.keys())
def __getitem__(self, key):
return self._dict[key]
def _emitDataChanged(self, key=None):
if self.model() is not None:
if key is not None and key in self._items:
index_tl = self._items[key].index()
index_br = self._items[key].index(1)
else:
index_tl = self.index()
index_br = self.index(1)
self.model().dataChanged.emit(index_tl, index_br)
def __setitem__(self, key, value, signalModel=True):
if key not in self._dict:
self._dict[key] = value
if key not in self._hidden:
self._items[key] = KeyValueRowModelItem(key)
self.addChildSorted(self._items[key], signalModel=signalModel)
if signalModel:
self._emitDataChanged(key)
elif self._dict[key] != value:
self._dict[key] = value
# TODO: Emit for hidden key/values?
if signalModel:
self._emitDataChanged(key)
def __delitem__(self, key):
del self._dict[key]
if key in self._items:
self.deleteChild(self._items[key])
def update(self, kvs):
for key, value in kvs.iteritems():
self.__setitem__(key, value, False)
self._emitDataChanged()
def has_key(self, key):
return key in self._dict
def clear(self):
if len(self._dict) > 0:
MutableMapping.clear(self)
def getAnnotations(self):
res = copy.deepcopy(self._dict)
if None in res: del res[None]
return res
def isUnlabeled(self):
return 'unlabeled' in self._dict and self._dict['unlabeled']
def setUnlabeled(self, val):
if val:
self._dict['unlabeled'] = val
else:
if 'unlabeled' in self._dict:
del self['unlabeled']
self._emitDataChanged('unlabeled')
def isUnconfirmed(self):
return 'unconfirmed' in self._dict and self._dict['unconfirmed']
def setUnconfirmed(self, val):
if val:
self._dict['unconfirmed'] = val
else:
if 'unconfirmed' in self._dict:
del self['unconfirmed']
self._emitDataChanged('unconfirmed')
class FileModelItem(KeyValueModelItem):
def __init__(self, fileinfo, hidden=['filename']):
KeyValueModelItem.__init__(self, hidden=hidden, properties=fileinfo)
def data(self, role=Qt.DisplayRole, column=0):
if role == Qt.DisplayRole:
if column == 0:
return os.path.basename(self['filename'])
elif column == 1 and self.isUnlabeled():
return '[unlabeled]'
return ModelItem.data(self, role, column)
def getColor(self):
if self.isUnlabeled():
return Qt.red
return None
@staticmethod
def create(fileinfo):
if fileinfo['class'] == 'image':
return ImageFileModelItem(fileinfo)
elif fileinfo['class'] == 'video':
return VideoFileModelItem(fileinfo)
class ImageModelItem(ModelItem):
def __init__(self, annotations):
ModelItem.__init__(self)
items_to_add = [AnnotationModelItem(ann) for ann in annotations]
self.appendChildren(items_to_add, False)
def addAnnotation(self, ann, signalModel=True):
self.addChildSorted(AnnotationModelItem(ann), signalModel=signalModel)
def annotations(self):
for child in self._children:
if isinstance(child, AnnotationModelItem):
yield child
def confirmAll(self):
for ann in self.annotations():
ann.setUnconfirmed(False)
class ImageFileModelItem(FileModelItem, ImageModelItem):
def __init__(self, fileinfo):
self._annotation_data = fileinfo.get("annotations", [])
if "annotations" in fileinfo:
del fileinfo["annotations"]
FileModelItem.__init__(self, fileinfo)
ImageModelItem.__init__(self, [])
self._toload = []
for ann in self._annotation_data:
self._children.append(ann)
self._toload.append(ann)
self._loaded = False
def _load(self, index):
self._toload.remove(self._children[index])
ann = AnnotationModelItem(self._children[index])
self.replaceChild(index, ann)
if len(self._toload) == 0:
self._loaded = True
def data(self, role=Qt.DisplayRole, column=0):
if role == DataRole:
return self._dict
return FileModelItem.data(self, role, column)
def getAnnotations(self):
self._ensureAllLoaded()
fi = KeyValueModelItem.getAnnotations(self)
fi['annotations'] = [child.getAnnotations() for child in self.children()
if hasattr(child, 'getAnnotations')]
return fi
class VideoFileModelItem(FileModelItem):
def __init__(self, fileinfo):
frameinfos = fileinfo.get("frames", [])
if "frames" in fileinfo:
del fileinfo["frames"]
FileModelItem.__init__(self, fileinfo)
items_to_add = [FrameModelItem(frameinfo) for frameinfo in frameinfos]
self.appendChildren(items_to_add, False)
def getAnnotations(self):
self._ensureAllLoaded()
fi = KeyValueModelItem.getAnnotations(self)
fi['frames'] = [child.getAnnotations() for child in self.children()]
return fi
class FrameModelItem(ImageModelItem, KeyValueModelItem):
def __init__(self, frameinfo):
annotations = frameinfo.get("annotations", [])
if "annotations" in frameinfo:
del frameinfo["annotations"]
KeyValueModelItem.__init__(self, properties=frameinfo)
ImageModelItem.__init__(self, annotations)
def framenum(self):
return int(self.get('num', -1))
def timestamp(self):
return float(self.get('timestamp', -1))
def data(self, role=Qt.DisplayRole, column=0):
if role == Qt.DisplayRole:
if column == 0:
return "%d / %.3f" % (self.framenum(), self.timestamp())
elif column == 1 and self.isUnlabeled():
return '[unlabeled]'
return ImageModelItem.data(self, role, column)
def getColor(self):
if self.isUnlabeled():
return Qt.red
return None
def getAnnotations(self):
fi = KeyValueModelItem.getAnnotations(self)
fi['annotations'] = [child.getAnnotations() for child in self.children()
if hasattr(child, 'getAnnotations')]
return fi
class AnnotationModelItem(KeyValueModelItem):
def __init__(self, annotation):
KeyValueModelItem.__init__(self, properties=annotation)
# Delegated from QAbstractItemModel
def data(self, role=Qt.DisplayRole, column=0):
if role == Qt.DisplayRole:
if column == 0:
try:
return self['class']
except KeyError:
LOG.error('Could not find key class in annotation item. Please check your label file.')
return '<error - no class set>'
elif column == 1 and self.isUnconfirmed():
return '[unconfirmed]'
else:
return ""
elif role == DataRole:
return self._annotation
return ModelItem.data(self, role, column)
def getColor(self):
if self.isUnconfirmed():
return Qt.red
return None
class KeyValueRowModelItem(ModelItem):
def __init__(self, key, read_only=True):
ModelItem.__init__(self)
self._key = key
self._read_only = read_only
def key(self):
return self._key
def data(self, role=Qt.DisplayRole, column=0):
if role == Qt.DisplayRole:
if column == 0:
return self._key
elif column == 1:
return self.parent()[self._key]
else:
return None
else:
return ModelItem.data(self, role, column)
def flags(self, column):
if self._read_only:
return Qt.NoItemFlags
else:
if column == 1:
return Qt.ItemIsEnabled | Qt.ItemIsEditable
else:
return Qt.ItemIsEnabled
def setData(self, value, role=Qt.DisplayRole, column=0):
if column == 1:
if isinstance(value, QVariant):
value = value.toPyObject()
self._parent[self._key] = value
return True
return False
class AnnotationModel(QAbstractItemModel):
# signals
dirtyChanged = pyqtSignal(bool, name='dirtyChanged')
def __init__(self, annotations, parent=None):
QAbstractItemModel.__init__(self, parent)
start = time.time()
self._annotations = annotations
self._dirty = False
self._root = RootModelItem(self, annotations)
diff = time.time() - start
LOG.info("Created AnnotationModel in %.2fs" % (diff, ))
self.dataChanged.connect(self.onDataChanged)
self.rowsInserted.connect(self.onDataChanged)
self.rowsRemoved.connect(self.onDataChanged)
# QAbstractItemModel overloads
def hasChildren(self, index=QModelIndex()):
if index.column() > 0:
return 0
if not index.isValid():
return self._root.hasChildren()
parent = self.parentFromIndex(index)
return parent.childHasChildren(index.row())
def columnCount(self, index=QModelIndex()):
return 2
def rowCount(self, index=QModelIndex()):
# Only items with column==1 can have children
if index.column() > 0:
return 0
if not index.isValid():
return self._root.rowCount()
parent = self.parentFromIndex(index)
return parent.childRowCount(index.row())
def parent(self, index):
if index is None:
return QModelIndex()
return self.parentFromIndex(index).index()
def index(self, row, column, parent_idx=QModelIndex()):
# Handle invalid rows/columns
if row < 0 or column < 0:
return QModelIndex()
# Only items with column == 0 can have children
if parent_idx.isValid() and parent_idx.column() > 0:
return QModelIndex()
# Get parent. This returns the root if parent_idx is invalid.
parent = self.itemFromIndex(parent_idx)
if row < 0 or row >= parent.rowCount():
return QModelIndex()
if column < 0 or column >= self.columnCount():
return QModelIndex()
return self.createIndex(row, column, parent)
def data(self, index, role=Qt.DisplayRole):
if not index.isValid():
return None
parent = self.parentFromIndex(index)
return parent.childData(role, index.row(), index.column())
def setData(self, index, value, role=Qt.EditRole):
if not index.isValid():
return False
item = self.itemFromIndex(index)
return item.setData(value, role, index.column())
def flags(self, index):
if not index.isValid():
return self._root.flags(index.column())
parent = self.parentFromIndex(index)
return parent.childFlags(index.row(), index.column())
def headerData(self, section, orientation, role):
if orientation == Qt.Horizontal and role == Qt.DisplayRole:
if section == 0: return "File/Type/Key"
elif section == 1: return "Value"
return None
# Own methods
def root(self):
return self._root
def dirty(self):
return self._dirty
def setDirty(self, dirty=True):
if dirty != self._dirty:
LOG.debug("Setting model state to dirty")
self._dirty = dirty
self.dirtyChanged.emit(self._dirty)
def onDataChanged(self, *args):
self.setDirty()
def itemFromIndex(self, index):
index = QModelIndex(index) # explicitly convert from QPersistentModelIndex
if index.isValid():
return index.internalPointer().childAt(index.row())
return self._root
def parentFromIndex(self, index):
index = QModelIndex(index) # explicitly convert from QPersistentModelIndex
if index.isValid():
return index.internalPointer()
return self._root
def iterator(self, _class=None, predicate=None, start=None, maxlevels=10000):
# Visit all nodes
level = 0
item = start if start is not None else self.root()
while item is not None:
# Return item
if _class is None or isinstance(item, _class):
if predicate is None or predicate(item):
yield item
# Get next item
if item.rowCount() > 0 and level < maxlevels:
level += 1
item = item.childAt(0)
else:
next_sibling = item.getNextSibling()
if next_sibling is not None:
item = next_sibling
else:
level -= 1
ancestor = item.parent()
item = None
while ancestor is not None:
ancestor_sibling = ancestor.getNextSibling()
if ancestor_sibling is not None:
item = ancestor_sibling
break
level -= 1
ancestor = ancestor.parent()
#######################################################################################
# proxy model
#######################################################################################
class AnnotationSortFilterProxyModel(QSortFilterProxyModel):
"""Adds sorting and filtering support to the AnnotationModel without basically
any implementation effort. Special functions such as ``insertPoint()`` just
call the source models respective functions."""
def __init__(self, parent=None):
super(AnnotationSortFilterProxyModel, self).__init__(parent)
def fileIndex(self, index):
fi = self.sourceModel().fileIndex(self.mapToSource(index))
return self.mapFromSource(fi)
def itemFromIndex(self, index):
return self.sourceModel().itemFromIndex(self.mapToSource(index))
def baseDir(self):
return self.sourceModel().baseDir()
def insertPoint(self, pos, parent, **kwargs):
return self.sourceModel().insertPoint(pos, self.mapToSource(parent), **kwargs)
def insertRect(self, rect, parent, **kwargs):
return self.sourceModel().insertRect(rect, self.mapToSource(parent), **kwargs)
def insertMask(self, fname, parent, **kwargs):
return self.sourceModel().insertMask(fname, self.mapToSource(parent), **kwargs)
def insertFile(self, filename):
return self.sourceModel().insertFile(filename)
#######################################################################################
# view
#######################################################################################
class AnnotationTreeView(QTreeView):
selectedItemsChanged = pyqtSignal(object)
def __init__(self, parent=None):
super(AnnotationTreeView, self).__init__(parent)
self.setUniformRowHeights(True)
self.setSelectionMode(QTreeView.ExtendedSelection)
self.setSelectionBehavior(QTreeView.SelectRows)
self.setAllColumnsShowFocus(True)
self.setAlternatingRowColors(True)
#self.setEditTriggers(QAbstractItemView.SelectedClicked)
self.setSortingEnabled(True)
self.setAnimated(True)
self.expanded.connect(self.onExpanded)
def resizeColumns(self):
for column in range(self.model().columnCount(QModelIndex())):
self.resizeColumnToContents(column)
def onExpanded(self):
self.resizeColumns()
def setModel(self, model):
QTreeView.setModel(self, model)
self.resizeColumns()
def rowsInserted(self, index, start, end):
QTreeView.rowsInserted(self, index, start, end)
self.resizeColumns()
def setSelectedItems(self, items):
#block = self.blockSignals(True)
sel = QItemSelection()
for item in items:
sel.merge(QItemSelection(item.index(), item.index(1)), QItemSelectionModel.SelectCurrent)
if set(sel) != set(self.selectionModel().selection()):
self.selectionModel().clear()
self.selectionModel().select(sel, QItemSelectionModel.Select)
#self.blockSignals(block)
def selectionChanged(self, selected, deselected):
items = [ self.model().itemFromIndex(index) for index in self.selectionModel().selectedIndexes()]
self.selectedItemsChanged.emit(items)
QTreeView.selectionChanged(self, selected, deselected)