Replace is_degenerate with is_model_valid and is_data_valid in ransac

This commit is contained in:
Johannes Schönberger
2013-05-02 18:12:23 +02:00
parent 28bda25da8
commit a73076157c
+24 -10
View File
@@ -502,6 +502,7 @@ class EllipseModel(BaseModel):
def ransac(data, model_class, min_samples, residual_threshold,
is_data_valid=None, is_model_valid=None,
max_trials=100, stop_sample_num=np.inf, stop_residuals_sum=0):
"""Fits a model to data with the RANSAC (random sample consensus) algorithm.
@@ -510,11 +511,13 @@ def ransac(data, model_class, min_samples, residual_threshold,
the following tasks:
1. Select `min_samples` random samples from the original data and check
whether the set of points is valid (`is_degenerate(*data)`).
2. Estimate a model to the random subset (`estimate(*data[random_subset]`).
whether the set of data is valid (see `is_data_valid`).
2. Estimate a model to the random subset
(`model_cls.estimate(*data[random_subset]`) and check whether the
estimated model is valid (see `is_model_valid`).
3. Classify all data as inliers or outliers by calculating the residuals to
the estimated model (`residuals(*data)`) - all data samples with
residuals smaller than the `residual_threshold` are considered as
the estimated model (`model_cls.residuals(*data)`) - all data samples
with residuals smaller than the `residual_threshold` are considered as
inliers.
4. Save estimated model as best model if number of inlier samples is
maximal. In case the current estimated model has the same number of
@@ -537,16 +540,21 @@ def ransac(data, model_class, min_samples, residual_threshold,
`is_degenerate(*data)` must all take each data array as separate
arguments.
model_class : object
Object with the following methods implemented:
Object with the following object methods:
* `estimate(*data)`
* `residuals(*data)`
* `is_degenerate(*data)`
min_samples : int
The minimum number of data points to fit a model.
residual_threshold : float
Maximum distance for a data point to be classified as an inlier.
is_data_valid : function, optional
This function is called with the randomly selected data before the
model is fitted to it: `is_data_valid(*random_data)`.
is_model_valid : function, optional
This function is called with the estimated model and the randomly
selected data: `is_model_valid(model, *random_data)`, .
max_trials : int, optional
Maximum number of iterations for random sample selection.
stop_sample_num : int, optional
@@ -640,19 +648,25 @@ def ransac(data, model_class, min_samples, residual_threshold,
for _ in range(max_trials):
# choose random sample
# choose random sample set
samples = []
random_idxs = np.random.randint(0, N, min_samples)
for d in data:
samples.append(d[random_idxs])
# check if random sample is degenerate
if model_class.is_degenerate(*samples):
# check if random sample set is valid
if is_data_valid is not None and not is_data_valid(*samples):
continue
# create new instance of model class for current sample
# estimate model for current random sample set
sample_model = model_class()
sample_model.estimate(*samples)
# check if estimated model is valid
if is_model_valid is not None and not is_model_valid(sample_model,
*samples):
continue
sample_model_residuals = np.abs(sample_model.residuals(*data))
# consensus set / inliers
sample_model_inliers = data_idxs[sample_model_residuals