mirror of
https://github.com/wassname/scikit-image.git
synced 2026-06-30 05:41:30 +08:00
ENH CRAZY speedup
This commit is contained in:
@@ -9,8 +9,8 @@ from IPython.core.debugger import Tracer
|
||||
tracer = Tracer()
|
||||
|
||||
|
||||
img = img_as_float(lena())[::3, ::3, :].copy("C")
|
||||
segments = quickshift(img, sigma=2)
|
||||
img = img_as_float(lena())[::2, ::2, :].copy("C")
|
||||
segments = quickshift(img)
|
||||
segments = np.unique(segments, return_inverse=True)[1].reshape(img.shape[:2])
|
||||
|
||||
plt.subplot(131, title="original")
|
||||
|
||||
@@ -47,20 +47,26 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
cdef int channels = image.shape[2]
|
||||
cdef float closest, dist
|
||||
cdef int x, y, xx, yy, x_, y_
|
||||
|
||||
cdef np.float_t* image_p = <np.float_t*> image.data
|
||||
cdef np.float_t* current_pixel_p = image_p
|
||||
cdef np.float_t* current_entry_p
|
||||
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=2] densities = np.zeros((width, height))
|
||||
start = time()
|
||||
# compute densities
|
||||
for x, y in product(xrange(width), xrange(height)):
|
||||
current_pixel = image[x, y, :]
|
||||
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
|
||||
x_, y_ = x + xx, y + yy
|
||||
if 0 <= x_ < width and 0 <= y_ < height:
|
||||
dist = 0
|
||||
current_entry_p = current_pixel_p
|
||||
for c in xrange(channels):
|
||||
dist += (current_pixel[c] - image[x_, y_, c])**2
|
||||
dist += (current_pixel_p[c] - image[x_, y_, c])**2
|
||||
dist += (x - x_)**2 + (y - y_)**2
|
||||
densities[x, y] += float(exp(-dist / sigma))
|
||||
current_pixel_p += channels
|
||||
|
||||
print("densities: %f" % (time() - start))
|
||||
|
||||
# this will break ties that otherwise would give us headache
|
||||
@@ -71,9 +77,9 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
cdef np.ndarray[dtype=np.float_t, ndim=2] dist_parent = np.zeros((width, height))
|
||||
start = time()
|
||||
# find nearest node with higher density
|
||||
current_pixel_p = image_p
|
||||
for x, y in product(xrange(width), xrange(height)):
|
||||
current_density = densities[x, y]
|
||||
current_pixel = image[x, y, :]
|
||||
closest = np.inf
|
||||
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
|
||||
x_, y_ = x + xx, y + yy
|
||||
@@ -81,12 +87,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
|
||||
if densities[x_, y_] > current_density:
|
||||
dist = 0
|
||||
for c in xrange(channels):
|
||||
dist += (current_pixel[c] - image[x_, y_, c])**2
|
||||
dist += (current_pixel_p[c] - image[x_, y_, c])**2
|
||||
dist += (x - x_)**2 + (y - y_)**2
|
||||
if dist < closest:
|
||||
closest = dist
|
||||
parent[x, y] = x_ * width + y_
|
||||
dist_parent[x, y] = closest
|
||||
current_pixel_p += channels
|
||||
print("parents: %f" % (time() - start))
|
||||
|
||||
start = time()
|
||||
|
||||
Reference in New Issue
Block a user