ENH CRAZY speedup

This commit is contained in:
Andreas Mueller
2012-06-16 23:39:48 +02:00
parent b977d59c1b
commit be4b44bc63
2 changed files with 13 additions and 6 deletions
+2 -2
View File
@@ -9,8 +9,8 @@ from IPython.core.debugger import Tracer
tracer = Tracer()
img = img_as_float(lena())[::3, ::3, :].copy("C")
segments = quickshift(img, sigma=2)
img = img_as_float(lena())[::2, ::2, :].copy("C")
segments = quickshift(img)
segments = np.unique(segments, return_inverse=True)[1].reshape(img.shape[:2])
plt.subplot(131, title="original")
+11 -4
View File
@@ -47,20 +47,26 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
cdef int channels = image.shape[2]
cdef float closest, dist
cdef int x, y, xx, yy, x_, y_
cdef np.float_t* image_p = <np.float_t*> image.data
cdef np.float_t* current_pixel_p = image_p
cdef np.float_t* current_entry_p
cdef np.ndarray[dtype=np.float_t, ndim=2] densities = np.zeros((width, height))
start = time()
# compute densities
for x, y in product(xrange(width), xrange(height)):
current_pixel = image[x, y, :]
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
x_, y_ = x + xx, y + yy
if 0 <= x_ < width and 0 <= y_ < height:
dist = 0
current_entry_p = current_pixel_p
for c in xrange(channels):
dist += (current_pixel[c] - image[x_, y_, c])**2
dist += (current_pixel_p[c] - image[x_, y_, c])**2
dist += (x - x_)**2 + (y - y_)**2
densities[x, y] += float(exp(-dist / sigma))
current_pixel_p += channels
print("densities: %f" % (time() - start))
# this will break ties that otherwise would give us headache
@@ -71,9 +77,9 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
cdef np.ndarray[dtype=np.float_t, ndim=2] dist_parent = np.zeros((width, height))
start = time()
# find nearest node with higher density
current_pixel_p = image_p
for x, y in product(xrange(width), xrange(height)):
current_density = densities[x, y]
current_pixel = image[x, y, :]
closest = np.inf
for xx, yy in product(xrange(-w / 2, w / 2 + 1), repeat=2):
x_, y_ = x + xx, y + yy
@@ -81,12 +87,13 @@ def quickshift(np.ndarray[dtype=np.float_t, ndim=3, mode="c"] image, sigma=5, ta
if densities[x_, y_] > current_density:
dist = 0
for c in xrange(channels):
dist += (current_pixel[c] - image[x_, y_, c])**2
dist += (current_pixel_p[c] - image[x_, y_, c])**2
dist += (x - x_)**2 + (y - y_)**2
if dist < closest:
closest = dist
parent[x, y] = x_ * width + y_
dist_parent[x, y] = closest
current_pixel_p += channels
print("parents: %f" % (time() - start))
start = time()