Fix LineModelND test cases

This commit is contained in:
Johannes Schönberger
2015-12-04 23:05:59 -05:00
parent 07f18c6067
commit f88694e5a3
2 changed files with 29 additions and 31 deletions
+1 -1
View File
@@ -275,7 +275,7 @@ class LineModelND(BaseModel):
l = (x - X0[axis]) / u[axis]
return X0 + l[..., np.newaxis] * u
def predict_x(self, y, params=None, new_params=None):
def predict_x(self, y, params=None):
"""Predict x-coordinates for 2D lines using the estimated model.
Alias for::
+28 -30
View File
@@ -1,18 +1,18 @@
import numpy as np
from numpy.testing import assert_equal, assert_raises, assert_almost_equal
from skimage.measure import LineModel, LineModelND, CircleModel, EllipseModel, ransac
from skimage.measure import LineModelND, CircleModel, EllipseModel, ransac
from skimage.transform import AffineTransform
from skimage.measure.fit import _dynamic_max_trials
from skimage._shared._warnings import expected_warnings
def test_line_model_invalid_input():
assert_raises(ValueError, LineModel().estimate, np.empty((5, 3)))
assert_raises(ValueError, LineModelND().estimate, np.empty((1, 3)))
def test_line_model_predict():
model = LineModel()
model.params = (10, 1)
model = LineModelND()
model.params = ((0, 0), (1, 1))
x = np.arange(-10, 10)
y = model.predict_y(x)
assert_almost_equal(x, model.predict_x(y))
@@ -20,38 +20,36 @@ def test_line_model_predict():
def test_line_model_estimate():
# generate original data without noise
model0 = LineModel()
model0.params = (10, 1)
model0 = LineModelND()
model0.params = ((0, 0), (1, 1))
x0 = np.arange(-100, 100)
y0 = model0.predict_y(x0)
data0 = np.column_stack([x0, y0])
# add gaussian noise to data
np.random.seed(1234)
data = data0 + np.random.normal(size=data0.shape)
data = np.column_stack([x0, y0])
# estimate parameters of noisy data
model_est = LineModel()
model_est = LineModelND()
model_est.estimate(data)
# test whether estimated parameters almost equal original parameters
assert_almost_equal(model0.params, model_est.params, 1)
x = np.random.rand(100, 2)
assert_almost_equal(model0.predict(x), model_est.predict(x), 1)
def test_line_model_residuals():
model = LineModel()
model.params = (0, 0)
assert_equal(abs(model.residuals(np.array([[0, 0]]))), 0)
assert_equal(abs(model.residuals(np.array([[0, 10]]))), 0)
assert_equal(abs(model.residuals(np.array([[10, 0]]))), 10)
model.params = (5, np.pi / 4)
assert_equal(abs(model.residuals(np.array([[0, 0]]))), 5)
assert_almost_equal(abs(model.residuals(np.array([[np.sqrt(50), 0]]))), 0)
model = LineModelND()
model.params = (np.array([0, 0]), np.array([0, 1]))
assert_equal(model.residuals(np.array([[0, 0]])), 0)
assert_equal(model.residuals(np.array([[0, 10]])), 0)
assert_equal(model.residuals(np.array([[10, 0]])), 10)
model.params = (np.array([-2, 0]), np.array([1, 1]) / np.sqrt(2))
assert_equal(model.residuals(np.array([[0, 0]])), np.sqrt(2))
assert_almost_equal(model.residuals(np.array([[-4, 0]])), np.sqrt(2))
def test_line_model_under_determined():
data = np.empty((1, 2))
assert_raises(ValueError, LineModel().estimate, data)
assert_raises(ValueError, LineModelND().estimate, data)
def test_line_modelND_invalid_input():
@@ -60,7 +58,7 @@ def test_line_modelND_invalid_input():
def test_line_modelND_predict():
model = LineModelND()
model.params = (np.array([0,0]), np.array([0.2,0.98]))
model.params = (np.array([0, 0]), np.array([0.2, 0.98]))
x = np.arange(-10, 10)
y = model.predict_y(x)
assert_almost_equal(x, model.predict_x(y))
@@ -99,10 +97,10 @@ def test_line_modelND_estimate():
def test_line_modelND_residuals():
model = LineModelND()
model.params = (np.array([0,0,0]), np.array([0,0,1]))
assert_equal(abs(model.residuals(np.array([[0, 0,0]]))), 0)
assert_equal(abs(model.residuals(np.array([[0,0,1]]))), 0)
assert_equal(abs(model.residuals(np.array([[10, 0,0]]))), 10)
model.params = (np.array([0, 0, 0]), np.array([0, 0, 1]))
assert_equal(abs(model.residuals(np.array([[0, 0, 0]]))), 0)
assert_equal(abs(model.residuals(np.array([[0, 0, 1]]))), 0)
assert_equal(abs(model.residuals(np.array([[10, 0, 0]]))), 10)
def test_line_modelND_under_determined():
@@ -245,7 +243,7 @@ def test_ransac_is_data_valid():
np.random.seed(1)
is_data_valid = lambda data: data.shape[0] > 2
model, inliers = ransac(np.empty((10, 2)), LineModel, 2, np.inf,
model, inliers = ransac(np.empty((10, 2)), LineModelND, 2, np.inf,
is_data_valid=is_data_valid)
assert_equal(model, None)
assert_equal(inliers, None)
@@ -256,7 +254,7 @@ def test_ransac_is_model_valid():
def is_model_valid(model, data):
return False
model, inliers = ransac(np.empty((10, 2)), LineModel, 2, np.inf,
model, inliers = ransac(np.empty((10, 2)), LineModelND, 2, np.inf,
is_model_valid=is_model_valid)
assert_equal(model, None)
assert_equal(inliers, None)
@@ -304,8 +302,8 @@ def test_ransac_invalid_input():
def test_deprecated_params_attribute():
model = LineModel()
model.params = (10, 1)
model = LineModelND()
model.params = ((0, 0), (1, 1))
x = np.arange(-10, 10)
y = model.predict_y(x)
with expected_warnings(['`_params`']):