mirror of
https://github.com/wassname/scikit-image.git
synced 2026-07-02 01:37:54 +08:00
Extract the type and shape checks into a common function.
This commit is contained in:
@@ -13,6 +13,18 @@ __docformat__ = "restructuredtext en"
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _prepare_colorarray(arr, dtype="float32"):
|
||||
"""Check the shape of the array, and give it the requested type"""
|
||||
if type(arr) != np.ndarray:
|
||||
raise TypeError, "the input array must be a numpy.ndarray"
|
||||
|
||||
if arr.ndim != 3 or arr.shape[2] != 3:
|
||||
msg = "the input array must be have a shape == (.,.,3))"
|
||||
raise ValueError, msg
|
||||
|
||||
return arr.astype(dtype)
|
||||
|
||||
def rgb2hsv(rgb):
|
||||
"""RGB to HSV color space conversion.
|
||||
|
||||
@@ -49,15 +61,7 @@ def rgb2hsv(rgb):
|
||||
>>> lena = imread(os.path.join(data_dir, 'lena.png'))
|
||||
>>> lena_hsv = color.rgb2hsv(lena)
|
||||
"""
|
||||
|
||||
if type(rgb) != np.ndarray:
|
||||
raise TypeError, "the input array 'rgb' must be a numpy.ndarray"
|
||||
|
||||
if rgb.ndim != 3 or rgb.shape[2] != 3:
|
||||
msg = "the input array 'rgb' must be have a shape == (.,.,3))"
|
||||
raise ValueError, msg
|
||||
|
||||
arr = rgb.astype("float32")
|
||||
arr = _prepare_colorarray(rgb)
|
||||
out = np.empty_like(arr)
|
||||
|
||||
# -- V channel
|
||||
@@ -130,15 +134,7 @@ def hsv2rgb(hsv):
|
||||
>>> lena_hsv = rgb2hsv(lena)
|
||||
>>> lena_rgb = hsv2rgb(lena_hsv)
|
||||
"""
|
||||
|
||||
if type(hsv) != np.ndarray:
|
||||
raise TypeError, "the input array 'hsv' must be a numpy.ndarray"
|
||||
|
||||
if hsv.ndim != 3 or hsv.shape[2] != 3:
|
||||
msg = "the input array 'hsv' must be have a shape == (.,.,3))"
|
||||
raise ValueError, msg
|
||||
|
||||
arr = hsv.astype("float32")
|
||||
arr = _prepare_colorarray(hsv)
|
||||
|
||||
hi = np.floor(arr[:,:,0] * 6)
|
||||
f = arr[:,:,0] * 6 - hi
|
||||
|
||||
Reference in New Issue
Block a user