diff --git a/skimage/measure/fit.py b/skimage/measure/fit.py index ef81f8f9..8fde1ca2 100644 --- a/skimage/measure/fit.py +++ b/skimage/measure/fit.py @@ -502,6 +502,7 @@ class EllipseModel(BaseModel): def ransac(data, model_class, min_samples, residual_threshold, + is_data_valid=None, is_model_valid=None, max_trials=100, stop_sample_num=np.inf, stop_residuals_sum=0): """Fits a model to data with the RANSAC (random sample consensus) algorithm. @@ -510,11 +511,13 @@ def ransac(data, model_class, min_samples, residual_threshold, the following tasks: 1. Select `min_samples` random samples from the original data and check - whether the set of points is valid (`is_degenerate(*data)`). - 2. Estimate a model to the random subset (`estimate(*data[random_subset]`). + whether the set of data is valid (see `is_data_valid`). + 2. Estimate a model to the random subset + (`model_cls.estimate(*data[random_subset]`) and check whether the + estimated model is valid (see `is_model_valid`). 3. Classify all data as inliers or outliers by calculating the residuals to - the estimated model (`residuals(*data)`) - all data samples with - residuals smaller than the `residual_threshold` are considered as + the estimated model (`model_cls.residuals(*data)`) - all data samples + with residuals smaller than the `residual_threshold` are considered as inliers. 4. Save estimated model as best model if number of inlier samples is maximal. In case the current estimated model has the same number of @@ -537,16 +540,21 @@ def ransac(data, model_class, min_samples, residual_threshold, `is_degenerate(*data)` must all take each data array as separate arguments. model_class : object - Object with the following methods implemented: + Object with the following object methods: * `estimate(*data)` * `residuals(*data)` - * `is_degenerate(*data)` min_samples : int The minimum number of data points to fit a model. residual_threshold : float Maximum distance for a data point to be classified as an inlier. + is_data_valid : function, optional + This function is called with the randomly selected data before the + model is fitted to it: `is_data_valid(*random_data)`. + is_model_valid : function, optional + This function is called with the estimated model and the randomly + selected data: `is_model_valid(model, *random_data)`, . max_trials : int, optional Maximum number of iterations for random sample selection. stop_sample_num : int, optional @@ -640,19 +648,25 @@ def ransac(data, model_class, min_samples, residual_threshold, for _ in range(max_trials): - # choose random sample + # choose random sample set samples = [] random_idxs = np.random.randint(0, N, min_samples) for d in data: samples.append(d[random_idxs]) - # check if random sample is degenerate - if model_class.is_degenerate(*samples): + # check if random sample set is valid + if is_data_valid is not None and not is_data_valid(*samples): continue - # create new instance of model class for current sample + # estimate model for current random sample set sample_model = model_class() sample_model.estimate(*samples) + + # check if estimated model is valid + if is_model_valid is not None and not is_model_valid(sample_model, + *samples): + continue + sample_model_residuals = np.abs(sample_model.residuals(*data)) # consensus set / inliers sample_model_inliers = data_idxs[sample_model_residuals