mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-27 19:32:26 +08:00
Fix RANSAC for invalid model estimation and confidence corner case
Previously, estimators did not return whether the model estimation was successful. RANSAC now tests whether the estimation was successful and skips invalid models. When the confidence/stop_probability of RANSAC was set to 1, the iteration was falsely terminated early instead of running for the maximum number of iterations.
This commit is contained in:
+49
-15
@@ -4,9 +4,6 @@ import numpy as np
|
||||
from scipy import optimize
|
||||
|
||||
|
||||
_EPSILON = np.spacing(1)
|
||||
|
||||
|
||||
def _check_data_dim(data, dim):
|
||||
if data.ndim != 2 or data.shape[1] != dim:
|
||||
raise ValueError('Input data must have shape (N, %d).' % dim)
|
||||
@@ -57,6 +54,11 @@ class LineModel(BaseModel):
|
||||
data : (N, 2) array
|
||||
N points with ``(x, y)`` coordinates, respectively.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
_check_data_dim(data, dim=2)
|
||||
@@ -81,6 +83,8 @@ class LineModel(BaseModel):
|
||||
|
||||
self.params = (dist, theta)
|
||||
|
||||
return True
|
||||
|
||||
def residuals(self, data):
|
||||
"""Determine residuals of data to model.
|
||||
|
||||
@@ -182,6 +186,11 @@ class CircleModel(BaseModel):
|
||||
data : (N, 2) array
|
||||
N points with ``(x, y)`` coordinates, respectively.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
_check_data_dim(data, dim=2)
|
||||
@@ -217,6 +226,8 @@ class CircleModel(BaseModel):
|
||||
|
||||
self.params = params
|
||||
|
||||
return True
|
||||
|
||||
def residuals(self, data):
|
||||
"""Determine residuals of data to model.
|
||||
|
||||
@@ -313,6 +324,11 @@ class EllipseModel(BaseModel):
|
||||
data : (N, 2) array
|
||||
N points with ``(x, y)`` coordinates, respectively.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
_check_data_dim(data, dim=2)
|
||||
@@ -373,6 +389,8 @@ class EllipseModel(BaseModel):
|
||||
|
||||
self.params = params[:5]
|
||||
|
||||
return True
|
||||
|
||||
def residuals(self, data):
|
||||
"""Determine residuals of data to model.
|
||||
|
||||
@@ -471,7 +489,6 @@ class EllipseModel(BaseModel):
|
||||
def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
|
||||
"""Determine number trials such that at least one outlier-free subset is
|
||||
sampled for the given inlier/outlier ratio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_inliers : int
|
||||
@@ -482,21 +499,31 @@ def _dynamic_max_trials(n_inliers, n_samples, min_samples, probability):
|
||||
Minimum number of samples chosen randomly from original data.
|
||||
probability : float
|
||||
Probability (confidence) that one outlier-free sample is generated.
|
||||
|
||||
Returns
|
||||
-------
|
||||
trials : int
|
||||
Number of trials.
|
||||
|
||||
"""
|
||||
inlier_ratio = n_inliers / float(n_samples)
|
||||
nom = max(_EPSILON, 1 - probability)
|
||||
denom = max(_EPSILON, 1 - inlier_ratio ** min_samples)
|
||||
if nom == 1:
|
||||
return 0
|
||||
if denom == 1:
|
||||
if n_inliers == 0:
|
||||
return float('inf')
|
||||
return abs(float(np.ceil(np.log(nom) / np.log(denom))))
|
||||
|
||||
nom = 1 - probability
|
||||
if nom == 0:
|
||||
return float('inf')
|
||||
|
||||
inlier_ratio = n_inliers / float(n_samples)
|
||||
denom = 1 - inlier_ratio ** min_samples
|
||||
if denom == 0:
|
||||
return 1
|
||||
elif denom == 1:
|
||||
return float('inf')
|
||||
|
||||
nom = np.log(nom)
|
||||
denom = np.log(denom)
|
||||
if denom == 0:
|
||||
return 0
|
||||
|
||||
return int(np.ceil(nom / denom))
|
||||
|
||||
|
||||
def ransac(data, model_class, min_samples, residual_threshold,
|
||||
@@ -542,9 +569,11 @@ def ransac(data, model_class, min_samples, residual_threshold,
|
||||
model_class : object
|
||||
Object with the following object methods:
|
||||
|
||||
* ``estimate(*data)``
|
||||
* ``success = estimate(*data)``
|
||||
* ``residuals(*data)``
|
||||
|
||||
where `success` indicates whether the model estimation succeeded
|
||||
(`True` or `None` for success, `False` for failure).
|
||||
min_samples : int
|
||||
The minimum number of data points to fit a model to.
|
||||
residual_threshold : float
|
||||
@@ -688,7 +717,12 @@ def ransac(data, model_class, min_samples, residual_threshold,
|
||||
|
||||
# estimate model for current random sample set
|
||||
sample_model = model_class()
|
||||
sample_model.estimate(*samples)
|
||||
|
||||
success = sample_model.estimate(*samples)
|
||||
|
||||
if success is not None: # backwards compatibility
|
||||
if not success:
|
||||
continue
|
||||
|
||||
# check if estimated model is valid
|
||||
if is_model_valid is not None and not is_model_valid(sample_model,
|
||||
|
||||
@@ -233,9 +233,9 @@ def test_ransac_dynamic_max_trials():
|
||||
# e = 50%, min_samples = 8
|
||||
assert_equal(_dynamic_max_trials(50, 100, 8, 0.99), 1177)
|
||||
|
||||
# e = 0%, min_samples = 10
|
||||
assert_equal(_dynamic_max_trials(1, 100, 10, 0), 0)
|
||||
assert_equal(_dynamic_max_trials(1, 100, 10, 1), float('inf'))
|
||||
# e = 0%, min_samples = 5
|
||||
assert_equal(_dynamic_max_trials(1, 100, 5, 0), 0)
|
||||
assert_equal(_dynamic_max_trials(1, 100, 5, 1), float('inf'))
|
||||
|
||||
|
||||
def test_ransac_invalid_input():
|
||||
|
||||
@@ -263,6 +263,11 @@ class ProjectiveTransform(GeometricTransform):
|
||||
dst : (N, 2) array
|
||||
Destination coordinates.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -270,7 +275,7 @@ class ProjectiveTransform(GeometricTransform):
|
||||
dst_matrix, dst = _center_and_normalize_points(dst)
|
||||
except ZeroDivisionError:
|
||||
self.params = np.nan * np.empty((3, 3))
|
||||
return
|
||||
return False
|
||||
|
||||
xs = src[:, 0]
|
||||
ys = src[:, 1]
|
||||
@@ -309,6 +314,8 @@ class ProjectiveTransform(GeometricTransform):
|
||||
|
||||
self.params = H
|
||||
|
||||
return True
|
||||
|
||||
def __add__(self, other):
|
||||
"""Combine this transformation with another.
|
||||
|
||||
@@ -459,6 +466,11 @@ class PiecewiseAffineTransform(GeometricTransform):
|
||||
dst : (N, 2) array
|
||||
Destination coordinates.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
# forward piecewise affine
|
||||
@@ -481,6 +493,8 @@ class PiecewiseAffineTransform(GeometricTransform):
|
||||
affine.estimate(dst[tri, :], src[tri, :])
|
||||
self.inverse_affines.append(affine)
|
||||
|
||||
return True
|
||||
|
||||
def __call__(self, coords):
|
||||
"""Apply forward transformation.
|
||||
|
||||
@@ -658,6 +672,11 @@ class SimilarityTransform(ProjectiveTransform):
|
||||
dst : (N, 2) array
|
||||
Destination coordinates.
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -665,7 +684,7 @@ class SimilarityTransform(ProjectiveTransform):
|
||||
dst_matrix, dst = _center_and_normalize_points(dst)
|
||||
except ZeroDivisionError:
|
||||
self.params = np.nan * np.empty((3, 3))
|
||||
return
|
||||
return False
|
||||
|
||||
xs = src[:, 0]
|
||||
ys = src[:, 1]
|
||||
@@ -699,6 +718,7 @@ class SimilarityTransform(ProjectiveTransform):
|
||||
|
||||
self.params = S
|
||||
|
||||
return True
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
@@ -798,6 +818,11 @@ class PolynomialTransform(GeometricTransform):
|
||||
order : int, optional
|
||||
Polynomial order (number of coefficients is order + 1).
|
||||
|
||||
Returns
|
||||
-------
|
||||
success : bool
|
||||
True, if model estimation succeeds.
|
||||
|
||||
"""
|
||||
xs = src[:, 0]
|
||||
ys = src[:, 1]
|
||||
@@ -828,6 +853,8 @@ class PolynomialTransform(GeometricTransform):
|
||||
|
||||
self.params = params.reshape((2, u // 2))
|
||||
|
||||
return True
|
||||
|
||||
def __call__(self, coords):
|
||||
"""Apply forward transformation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user