Extract the type and shape checks into a common function.

This commit is contained in:
Ralf Gommers
2009-10-19 15:14:21 +02:00
parent 6097565337
commit 844f14bca0
+14 -18
View File
@@ -13,6 +13,18 @@ __docformat__ = "restructuredtext en"
import numpy as np
def _prepare_colorarray(arr, dtype="float32"):
"""Check the shape of the array, and give it the requested type"""
if type(arr) != np.ndarray:
raise TypeError, "the input array must be a numpy.ndarray"
if arr.ndim != 3 or arr.shape[2] != 3:
msg = "the input array must be have a shape == (.,.,3))"
raise ValueError, msg
return arr.astype(dtype)
def rgb2hsv(rgb):
"""RGB to HSV color space conversion.
@@ -49,15 +61,7 @@ def rgb2hsv(rgb):
>>> lena = imread(os.path.join(data_dir, 'lena.png'))
>>> lena_hsv = color.rgb2hsv(lena)
"""
if type(rgb) != np.ndarray:
raise TypeError, "the input array 'rgb' must be a numpy.ndarray"
if rgb.ndim != 3 or rgb.shape[2] != 3:
msg = "the input array 'rgb' must be have a shape == (.,.,3))"
raise ValueError, msg
arr = rgb.astype("float32")
arr = _prepare_colorarray(rgb)
out = np.empty_like(arr)
# -- V channel
@@ -130,15 +134,7 @@ def hsv2rgb(hsv):
>>> lena_hsv = rgb2hsv(lena)
>>> lena_rgb = hsv2rgb(lena_hsv)
"""
if type(hsv) != np.ndarray:
raise TypeError, "the input array 'hsv' must be a numpy.ndarray"
if hsv.ndim != 3 or hsv.shape[2] != 3:
msg = "the input array 'hsv' must be have a shape == (.,.,3))"
raise ValueError, msg
arr = hsv.astype("float32")
arr = _prepare_colorarray(hsv)
hi = np.floor(arr[:,:,0] * 6)
f = arr[:,:,0] * 6 - hi