diff --git a/skimage/future/graph/rag.py b/skimage/future/graph/rag.py index 2975a027..6a45447b 100644 --- a/skimage/future/graph/rag.py +++ b/skimage/future/graph/rag.py @@ -1,16 +1,6 @@ -try: - import networkx as nx -except ImportError: - msg = "Graph functions require networkx, which is not installed" - - class nx: - class Graph: - def __init__(self, *args, **kwargs): - raise ImportError(msg) - import warnings - warnings.warn(msg) - +import networkx as nx import numpy as np +from numpy.lib.stride_tricks import as_strided from scipy import ndimage as ndi import math from ... import draw, measure, segmentation, util, color @@ -51,21 +41,79 @@ def min_weight(graph, src, dst, n): return min(w1, w2) +def _add_edge_filter(values, graph): + """Create edge in `graph` between central element of `values` and the rest. + + Add an edge between the middle element in `values` and + all other elements of `values` into `graph`. ``values[len(values) // 2]`` + is expected to be the central value of the footprint used. + + Parameters + ---------- + values : array + The array to process. + graph : RAG + The graph to add edges in. + + Returns + ------- + 0 : float + Always returns 0. The return value is required so that `generic_filter` + can put it in the output array, but it is ignored by this filter. + """ + values = values.astype(int) + center = values[len(values) // 2] + for value in values: + if value != center and not graph.has_edge(center, value): + graph.add_edge(center, value) + return 0. + + class RAG(nx.Graph): """ The Region Adjacency Graph (RAG) of an image, subclasses `networx.Graph `_ + + Parameters + ---------- + label_image : array of int + An initial segmentation, with each region labeled as a different + integer. Every unique value in ``label_image`` will correspond to + a node in the graph. + connectivity : int in {1, ..., ``label_image.ndim``}, optional + The connectivity between pixels in ``label_image``. For a 2D image, + a connectivity of 1 corresponds to immediate neighbors up, down, + left, and right, while a connectivity of 2 also includes diagonal + neighbors. See `scipy.ndimage.generate_binary_structure`. + data : networkx Graph specification, optional + Initial or additional edges to pass to the NetworkX Graph + constructor. See `networkx.Graph`. Valid edge specifications + include edge list (list of tuples), NumPy arrays, and SciPy + sparse matrices. + **attr : keyword arguments, optional + Additional attributes to add to the graph. """ - def __init__(self, data=None, **attr): + def __init__(self, label_image=None, connectivity=1, data=None, **attr): super(RAG, self).__init__(data, **attr) - try: - self.max_id = max(self.nodes_iter()) - except ValueError: - # Empty sequence + if self.number_of_nodes() == 0: self.max_id = 0 + else: + self.max_id = max(self.nodes_iter()) + + if label_image is not None: + fp = ndi.generate_binary_structure(label_image.ndim, connectivity) + ndi.generic_filter( + label_image, + function=_add_edge_filter, + footprint=fp, + mode='nearest', + output=as_strided(np.empty((1,), dtype=np.float_), + shape=label_image.shape, + strides=((0,) * label_image.ndim)), + extra_arguments=(self,)) def merge_nodes(self, src, dst, weight_func=min_weight, in_place=True, extra_arguments=[], extra_keywords={}): @@ -172,36 +220,6 @@ class RAG(nx.Graph): super(RAG, self).add_node(n) -def _add_edge_filter(values, graph): - """Create edge in `g` between the first element of `values` and the rest. - - Add an edge between the first element in `values` and - all other elements of `values` in the graph `g`. `values[0]` - is expected to be the central value of the footprint used. - - Parameters - ---------- - values : array - The array to process. - graph : RAG - The graph to add edges in. - - Returns - ------- - 0 : int - Always returns 0. The return value is required so that `generic_filter` - can put it in the output array. - - """ - values = values.astype(int) - current = values[0] - for value in values[1:]: - if value != current: - graph.add_edge(current, value) - - return 0 - - def rag_mean_color(image, labels, connectivity=2, mode='distance', sigma=255.0): """Compute the Region Adjacency Graph using mean colors. @@ -224,7 +242,7 @@ def rag_mean_color(image, labels, connectivity=2, mode='distance', Pixels with a squared distance less than `connectivity` from each other are considered adjacent. It can range from 1 to `labels.ndim`. Its behavior is the same as `connectivity` parameter in - `scipy.ndimage.filters.generate_binary_structure`. + `scipy.ndimage.generate_binary_structure`. mode : {'distance', 'similarity'}, optional The strategy to assign edge weights. @@ -263,35 +281,7 @@ def rag_mean_color(image, labels, connectivity=2, mode='distance', http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.11.5274 """ - graph = RAG() - - # The footprint is constructed in such a way that the first - # element in the array being passed to _add_edge_filter is - # the central value. - fp = ndi.generate_binary_structure(labels.ndim, connectivity) - for d in range(fp.ndim): - fp = fp.swapaxes(0, d) - fp[0, ...] = 0 - fp = fp.swapaxes(0, d) - - # For example - # if labels.ndim = 2 and connectivity = 1 - # fp = [[0,0,0], - # [0,1,1], - # [0,1,0]] - # - # if labels.ndim = 2 and connectivity = 2 - # fp = [[0,0,0], - # [0,1,1], - # [0,1,1]] - - ndi.generic_filter( - labels, - function=_add_edge_filter, - footprint=fp, - mode='nearest', - output=np.zeros(labels.shape, dtype=np.uint8), - extra_arguments=(graph,)) + graph = RAG(labels, connectivity=connectivity) for n in graph: graph.node[n].update({'labels': [n], diff --git a/skimage/future/graph/tests/test_rag.py b/skimage/future/graph/tests/test_rag.py index 35c22218..e5882e20 100644 --- a/skimage/future/graph/tests/test_rag.py +++ b/skimage/future/graph/tests/test_rag.py @@ -178,3 +178,21 @@ def test_ncut_stable_subgraph(): new_labels, _, _ = segmentation.relabel_sequential(new_labels) assert new_labels.max() == 0 + + +def test_generic_rag_2d(): + labels = np.array([[1, 2], [3, 4]], dtype=np.uint8) + g = graph.RAG(labels) + assert g.has_edge(1, 2) and g.has_edge(2, 4) and not g.has_edge(1, 4) + h = graph.RAG(labels, connectivity=2) + assert h.has_edge(1, 2) and h.has_edge(1, 4) and h.has_edge(2, 3) + + +def test_generic_rag_3d(): + labels = np.arange(8, dtype=np.uint8).reshape((2, 2, 2)) + g = graph.RAG(labels) + assert g.has_edge(0, 1) and g.has_edge(1, 3) and not g.has_edge(0, 3) + h = graph.RAG(labels, connectivity=2) + assert h.has_edge(0, 1) and h.has_edge(0, 3) and not h.has_edge(0, 7) + k = graph.RAG(labels, connectivity=3) + assert k.has_edge(0, 1) and k.has_edge(1, 2) and k.has_edge(2, 5)