mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-29 03:37:34 +08:00
Replace is_degenerate with is_model_valid and is_data_valid in ransac
This commit is contained in:
+24
-10
@@ -502,6 +502,7 @@ class EllipseModel(BaseModel):
|
||||
|
||||
|
||||
def ransac(data, model_class, min_samples, residual_threshold,
|
||||
is_data_valid=None, is_model_valid=None,
|
||||
max_trials=100, stop_sample_num=np.inf, stop_residuals_sum=0):
|
||||
"""Fits a model to data with the RANSAC (random sample consensus) algorithm.
|
||||
|
||||
@@ -510,11 +511,13 @@ def ransac(data, model_class, min_samples, residual_threshold,
|
||||
the following tasks:
|
||||
|
||||
1. Select `min_samples` random samples from the original data and check
|
||||
whether the set of points is valid (`is_degenerate(*data)`).
|
||||
2. Estimate a model to the random subset (`estimate(*data[random_subset]`).
|
||||
whether the set of data is valid (see `is_data_valid`).
|
||||
2. Estimate a model to the random subset
|
||||
(`model_cls.estimate(*data[random_subset]`) and check whether the
|
||||
estimated model is valid (see `is_model_valid`).
|
||||
3. Classify all data as inliers or outliers by calculating the residuals to
|
||||
the estimated model (`residuals(*data)`) - all data samples with
|
||||
residuals smaller than the `residual_threshold` are considered as
|
||||
the estimated model (`model_cls.residuals(*data)`) - all data samples
|
||||
with residuals smaller than the `residual_threshold` are considered as
|
||||
inliers.
|
||||
4. Save estimated model as best model if number of inlier samples is
|
||||
maximal. In case the current estimated model has the same number of
|
||||
@@ -537,16 +540,21 @@ def ransac(data, model_class, min_samples, residual_threshold,
|
||||
`is_degenerate(*data)` must all take each data array as separate
|
||||
arguments.
|
||||
model_class : object
|
||||
Object with the following methods implemented:
|
||||
Object with the following object methods:
|
||||
|
||||
* `estimate(*data)`
|
||||
* `residuals(*data)`
|
||||
* `is_degenerate(*data)`
|
||||
|
||||
min_samples : int
|
||||
The minimum number of data points to fit a model.
|
||||
residual_threshold : float
|
||||
Maximum distance for a data point to be classified as an inlier.
|
||||
is_data_valid : function, optional
|
||||
This function is called with the randomly selected data before the
|
||||
model is fitted to it: `is_data_valid(*random_data)`.
|
||||
is_model_valid : function, optional
|
||||
This function is called with the estimated model and the randomly
|
||||
selected data: `is_model_valid(model, *random_data)`, .
|
||||
max_trials : int, optional
|
||||
Maximum number of iterations for random sample selection.
|
||||
stop_sample_num : int, optional
|
||||
@@ -640,19 +648,25 @@ def ransac(data, model_class, min_samples, residual_threshold,
|
||||
|
||||
for _ in range(max_trials):
|
||||
|
||||
# choose random sample
|
||||
# choose random sample set
|
||||
samples = []
|
||||
random_idxs = np.random.randint(0, N, min_samples)
|
||||
for d in data:
|
||||
samples.append(d[random_idxs])
|
||||
|
||||
# check if random sample is degenerate
|
||||
if model_class.is_degenerate(*samples):
|
||||
# check if random sample set is valid
|
||||
if is_data_valid is not None and not is_data_valid(*samples):
|
||||
continue
|
||||
|
||||
# create new instance of model class for current sample
|
||||
# estimate model for current random sample set
|
||||
sample_model = model_class()
|
||||
sample_model.estimate(*samples)
|
||||
|
||||
# check if estimated model is valid
|
||||
if is_model_valid is not None and not is_model_valid(sample_model,
|
||||
*samples):
|
||||
continue
|
||||
|
||||
sample_model_residuals = np.abs(sample_model.residuals(*data))
|
||||
# consensus set / inliers
|
||||
sample_model_inliers = data_idxs[sample_model_residuals
|
||||
|
||||
Reference in New Issue
Block a user