Merge pull request #1826 from jni/rag-generic

Allow construction of simple RAGs from label images
This commit is contained in:
Emmanuelle Gouillart
2015-12-22 17:37:02 +01:00
2 changed files with 85 additions and 77 deletions
+67 -77
View File
@@ -1,16 +1,6 @@
try:
import networkx as nx
except ImportError:
msg = "Graph functions require networkx, which is not installed"
class nx:
class Graph:
def __init__(self, *args, **kwargs):
raise ImportError(msg)
import warnings
warnings.warn(msg)
import networkx as nx
import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy import ndimage as ndi
import math
from ... import draw, measure, segmentation, util, color
@@ -51,21 +41,79 @@ def min_weight(graph, src, dst, n):
return min(w1, w2)
def _add_edge_filter(values, graph):
"""Create edge in `graph` between central element of `values` and the rest.
Add an edge between the middle element in `values` and
all other elements of `values` into `graph`. ``values[len(values) // 2]``
is expected to be the central value of the footprint used.
Parameters
----------
values : array
The array to process.
graph : RAG
The graph to add edges in.
Returns
-------
0 : float
Always returns 0. The return value is required so that `generic_filter`
can put it in the output array, but it is ignored by this filter.
"""
values = values.astype(int)
center = values[len(values) // 2]
for value in values:
if value != center and not graph.has_edge(center, value):
graph.add_edge(center, value)
return 0.
class RAG(nx.Graph):
"""
The Region Adjacency Graph (RAG) of an image, subclasses
`networx.Graph <http://networkx.github.io/documentation/latest/reference/classes.graph.html>`_
Parameters
----------
label_image : array of int
An initial segmentation, with each region labeled as a different
integer. Every unique value in ``label_image`` will correspond to
a node in the graph.
connectivity : int in {1, ..., ``label_image.ndim``}, optional
The connectivity between pixels in ``label_image``. For a 2D image,
a connectivity of 1 corresponds to immediate neighbors up, down,
left, and right, while a connectivity of 2 also includes diagonal
neighbors. See `scipy.ndimage.generate_binary_structure`.
data : networkx Graph specification, optional
Initial or additional edges to pass to the NetworkX Graph
constructor. See `networkx.Graph`. Valid edge specifications
include edge list (list of tuples), NumPy arrays, and SciPy
sparse matrices.
**attr : keyword arguments, optional
Additional attributes to add to the graph.
"""
def __init__(self, data=None, **attr):
def __init__(self, label_image=None, connectivity=1, data=None, **attr):
super(RAG, self).__init__(data, **attr)
try:
self.max_id = max(self.nodes_iter())
except ValueError:
# Empty sequence
if self.number_of_nodes() == 0:
self.max_id = 0
else:
self.max_id = max(self.nodes_iter())
if label_image is not None:
fp = ndi.generate_binary_structure(label_image.ndim, connectivity)
ndi.generic_filter(
label_image,
function=_add_edge_filter,
footprint=fp,
mode='nearest',
output=as_strided(np.empty((1,), dtype=np.float_),
shape=label_image.shape,
strides=((0,) * label_image.ndim)),
extra_arguments=(self,))
def merge_nodes(self, src, dst, weight_func=min_weight, in_place=True,
extra_arguments=[], extra_keywords={}):
@@ -172,36 +220,6 @@ class RAG(nx.Graph):
super(RAG, self).add_node(n)
def _add_edge_filter(values, graph):
"""Create edge in `g` between the first element of `values` and the rest.
Add an edge between the first element in `values` and
all other elements of `values` in the graph `g`. `values[0]`
is expected to be the central value of the footprint used.
Parameters
----------
values : array
The array to process.
graph : RAG
The graph to add edges in.
Returns
-------
0 : int
Always returns 0. The return value is required so that `generic_filter`
can put it in the output array.
"""
values = values.astype(int)
current = values[0]
for value in values[1:]:
if value != current:
graph.add_edge(current, value)
return 0
def rag_mean_color(image, labels, connectivity=2, mode='distance',
sigma=255.0):
"""Compute the Region Adjacency Graph using mean colors.
@@ -224,7 +242,7 @@ def rag_mean_color(image, labels, connectivity=2, mode='distance',
Pixels with a squared distance less than `connectivity` from each other
are considered adjacent. It can range from 1 to `labels.ndim`. Its
behavior is the same as `connectivity` parameter in
`scipy.ndimage.filters.generate_binary_structure`.
`scipy.ndimage.generate_binary_structure`.
mode : {'distance', 'similarity'}, optional
The strategy to assign edge weights.
@@ -263,35 +281,7 @@ def rag_mean_color(image, labels, connectivity=2, mode='distance',
http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.11.5274
"""
graph = RAG()
# The footprint is constructed in such a way that the first
# element in the array being passed to _add_edge_filter is
# the central value.
fp = ndi.generate_binary_structure(labels.ndim, connectivity)
for d in range(fp.ndim):
fp = fp.swapaxes(0, d)
fp[0, ...] = 0
fp = fp.swapaxes(0, d)
# For example
# if labels.ndim = 2 and connectivity = 1
# fp = [[0,0,0],
# [0,1,1],
# [0,1,0]]
#
# if labels.ndim = 2 and connectivity = 2
# fp = [[0,0,0],
# [0,1,1],
# [0,1,1]]
ndi.generic_filter(
labels,
function=_add_edge_filter,
footprint=fp,
mode='nearest',
output=np.zeros(labels.shape, dtype=np.uint8),
extra_arguments=(graph,))
graph = RAG(labels, connectivity=connectivity)
for n in graph:
graph.node[n].update({'labels': [n],
+18
View File
@@ -178,3 +178,21 @@ def test_ncut_stable_subgraph():
new_labels, _, _ = segmentation.relabel_sequential(new_labels)
assert new_labels.max() == 0
def test_generic_rag_2d():
labels = np.array([[1, 2], [3, 4]], dtype=np.uint8)
g = graph.RAG(labels)
assert g.has_edge(1, 2) and g.has_edge(2, 4) and not g.has_edge(1, 4)
h = graph.RAG(labels, connectivity=2)
assert h.has_edge(1, 2) and h.has_edge(1, 4) and h.has_edge(2, 3)
def test_generic_rag_3d():
labels = np.arange(8, dtype=np.uint8).reshape((2, 2, 2))
g = graph.RAG(labels)
assert g.has_edge(0, 1) and g.has_edge(1, 3) and not g.has_edge(0, 3)
h = graph.RAG(labels, connectivity=2)
assert h.has_edge(0, 1) and h.has_edge(0, 3) and not h.has_edge(0, 7)
k = graph.RAG(labels, connectivity=3)
assert k.has_edge(0, 1) and k.has_edge(1, 2) and k.has_edge(2, 5)