Implement a more general version of the merge command

Also, add tests for this command.

This fixes #71.
This commit is contained in:
Martin Baeuml
2014-04-15 23:21:05 +02:00
parent d6a66eb0e6
commit ad05875f0b
3 changed files with 164 additions and 56 deletions
+45 -51
View File
@@ -5,6 +5,7 @@ from pprint import pprint
from sloth.core.cli import BaseCommand, CommandError
from sloth.annotations.container import *
from optparse import make_option
from operator import itemgetter
import logging
@@ -124,10 +125,9 @@ class AppendFilesCommand(BaseCommand):
class MergeFilesCommand(BaseCommand):
"""
Merge annotations of two label files and create a new one from it.
Currently, only video annotation files are supported.
If both input files have annotations for the same frame number, the result
will contain the union of both annotations.
Output format will be determined by the file suffix of output.
"""
args = '<labelfile 1> <labelfile 2> <output>'
@@ -151,57 +151,51 @@ class MergeFilesCommand(BaseCommand):
an3 = self.merge_annotations(an1, an2)
logger.debug("saving annotations to %s" % output)
containerOut = self.labeltool._container_factory.create(output)
containerOut.save(an3, output)
out_container = self.labeltool._container_factory.create(output)
out_container.save(an3, output)
def merge_annotations(self, an1, an2, match_key='filename'):
    """Merge all items of ``an2`` into ``an1`` and return ``an1``.

    An item of ``an2`` matches an item of ``an1`` when both have the same
    ``'class'`` and the same value stored under ``match_key``.

    * Items without a match are appended to ``an1`` (the very same dict
      object, not a copy).
    * For a matched item, keys that are missing on the ``an1`` side are
      copied over.  If both sides have a key with different values, the
      value in ``an1`` is kept and a warning is logged.
    * ``'frames'`` of ``'video'`` items are merged recursively, matching
      frames by ``'num'``, and the result is sorted by ``'num'``.
    * ``'annotations'`` lists are concatenated (``an1`` first).

    Args:
        an1: list of annotation dicts; modified in place.
        an2: list of annotation dicts to merge into ``an1``.
        match_key: key whose value (together with ``'class'``) identifies
            an item; ``'filename'`` for top-level items, ``'num'`` for
            frames.

    Returns:
        ``an1``, after merging.

    Raises:
        KeyError: if an item lacks ``'class'`` or ``match_key``.
    """
    for item in an2:
        matching_items = [candidate for candidate in an1
                          if candidate['class'] == item['class'] and
                          candidate[match_key] == item[match_key]]

        # Nothing to merge with: the item is simply new.
        if not matching_items:
            an1.append(item)
            continue

        # Merge into the first match, but tell the user when the choice
        # was ambiguous.  Report the value of match_key: frames have no
        # 'filename', so hard-coding that key would raise KeyError here.
        if len(matching_items) > 1:
            logger.warning('Found %d possible matches for %s',
                           len(matching_items), item[match_key])
        match_item = matching_items[0]
        is_video = match_item['class'] == 'video'

        # Plain keys first; 'annotations' and video 'frames' are merged
        # element-wise further down instead of being overwritten.
        for key, value in item.items():
            if key == 'annotations' or (is_video and key == 'frames'):
                continue
            if key in match_item and match_item[key] != value:
                # Keep the existing value and show both sides of the conflict.
                logger.warning(
                    'found matching key %s, but values differ: %s <-> %s',
                    key, match_item[key], value)
                continue
            match_item[key] = value

        # Merge frames; a missing or None 'frames' entry counts as empty.
        if is_video:
            frames = self.merge_annotations(match_item.get('frames') or [],
                                            item.get('frames') or [],
                                            'num')
            frames.sort(key=itemgetter('num'))
            match_item['frames'] = frames

        # Merge annotations, also when only the an2 side has any.
        extra = item.get('annotations')
        if extra:
            match_item.setdefault('annotations', []).extend(extra)

    return an1
+114
View File
@@ -0,0 +1,114 @@
from sloth.core.commands import *
def test_merge_command_same_images():
    """Merging two label lists that describe the same images adds no items.

    Both inputs list ``abc.jpg`` and ``def.jpg``; only the second input
    carries the extra ``custom`` key on ``def.jpg``, and the test expects
    that key on the merged ``def.jpg`` item.
    """
    def point():
        # A fresh dict per call so the inputs never share annotation objects.
        return {'class': 'point', 'x': 10, 'y': 100}

    first = [
        {'class': 'image', 'filename': 'abc.jpg', 'annotations': []},
        {'class': 'image', 'filename': 'def.jpg', 'annotations': [point()]},
    ]
    second = [
        {'class': 'image', 'filename': 'abc.jpg', 'annotations': []},
        {'class': 'image', 'filename': 'def.jpg', 'custom': 1,
         'annotations': [point()]},
    ]

    command = MergeFilesCommand()
    merged = command.merge_annotations(first, second)

    # Every item of `second` has a counterpart in `first`.
    assert len(merged) == 2
    assert merged[1].get('custom') == 1
def test_merge_command_different_images():
    """Merging label lists with disjoint filenames keeps every item."""
    def point():
        return {'class': 'point', 'x': 10, 'y': 100}

    def image(filename, annotations):
        return {'class': 'image', 'filename': filename,
                'annotations': annotations}

    first = [image('abc.jpg', []), image('def.jpg', [point()])]
    second = [image('abc1.jpg', []), image('def2.jpg', [point()])]

    command = MergeFilesCommand()
    merged = command.merge_annotations(first, second)

    # No filename occurs in both inputs, so nothing is combined.
    assert len(merged) == 4
def test_merge_command_empty():
    """Merging into an empty label list yields the items of the second list."""
    target = []
    incoming = [
        {'class': 'image', 'filename': 'abc1.jpg', 'annotations': []},
        {'class': 'image', 'filename': 'def2.jpg',
         'annotations': [{'class': 'point', 'x': 10, 'y': 100}]},
    ]

    command = MergeFilesCommand()
    merged = command.merge_annotations(target, incoming)

    assert len(merged) == 2
def test_merge_command_different_same_videos():
    """Merge lists that share one video but otherwise hold different images.

    Both inputs annotate ``def.avi``: frames 10 and 12 in the first input,
    frames 10 and 11 in the second.  The test expects one video item with
    frames 10, 11, 12 in that order, and frame 10 holding the annotations of
    both inputs.
    """
    def point():
        return {'class': 'point', 'x': 10, 'y': 100}

    def frame(num, timestamp):
        return {'class': 'frame', 'num': num, 'timestamp': timestamp,
                'annotations': [point()]}

    first = [
        {'class': 'image', 'filename': 'abc.jpg', 'annotations': []},
        {'class': 'video', 'filename': 'def.avi',
         'frames': [frame(10, 100.0), frame(12, 102.0)]},
    ]
    second = [
        {'class': 'video', 'filename': 'def.avi',
         'frames': [frame(10, 100.0), frame(11, 101.0)]},
        {'class': 'image', 'filename': 'def2.jpg', 'annotations': [point()]},
    ]

    command = MergeFilesCommand()
    merged = command.merge_annotations(first, second)

    # abc.jpg, def.avi (combined) and def2.jpg.
    assert len(merged) == 3

    videos = [entry for entry in merged if entry['filename'] == 'def.avi']
    video = videos[0]
    frames = video['frames']
    assert len(frames) == 3
    assert frames[0]['num'] == 10
    assert frames[1]['num'] == 11
    assert frames[2]['num'] == 12
    # Frame 10 exists in both inputs, so it carries both point annotations.
    assert len(frames[0]['annotations']) == 2
def test_merge_command_same_file(tmpdir):
    """Run the whole command on one label file merged with itself.

    The result is written as JSON into pytest's ``tmpdir`` and read back.
    The test expects the same two items as in the input, each with twice
    as many annotations (4 and 2).

    NOTE(review): the input path is relative, so this presumably has to be
    run from the repository root -- confirm against the test setup.
    """
    import json

    class LabelToolMockup:
        # Minimal stand-in providing only the container factory that
        # MergeFilesCommand accesses through ``self.labeltool``.
        container_config = (('*', 'sloth.annotations.container.JsonContainer'),)
        _container_factory = AnnotationContainerFactory(container_config)

    mc = MergeFilesCommand()
    mc.labeltool = LabelToolMockup()
    output_fname = str(tmpdir.join('output.json'))
    mc.handle('tests/data/example1_labels.json',
              'tests/data/example1_labels.json',
              output_fname)

    # Close the file deterministically instead of leaking the handle.
    with open(output_fname) as output_file:
        merged_annotations = json.load(output_file)

    assert len(merged_annotations) == 2
    assert len(merged_annotations[0]['annotations']) == 4
    assert len(merged_annotations[1]['annotations']) == 2
+5 -5
View File
@@ -1,31 +1,31 @@
[
{
"type": "image",
"class": "image",
"annotations": [
{
"height": 60.0,
"width": 46.0,
"y": 105.0,
"x": 346.0,
"type": "rect"
"class": "rect"
},
{
"height": 58.0,
"width": 56.0,
"y": 119.0,
"x": 636.0,
"type": "rect"
"class": "rect"
}
],
"filename": "image1.jpg"
},
{
"type": "image",
"class": "image",
"annotations": [
{
"y": 155.0,
"x": 409.0,
"type": "point"
"class": "point"
}
],
"filename": "image2.jpg"