From d803846cf0498eb603eea8e0ba7b5ef61f40ad43 Mon Sep 17 00:00:00 2001 From: Chris Lee Date: Fri, 5 Jun 2015 17:18:46 -0400 Subject: [PATCH] FIX: accept PIL Image as input --- indicoio/utils/__init__.py | 2 ++ tests/test_remote.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/indicoio/utils/__init__.py b/indicoio/utils/__init__.py index b7fb5eb..31abcc5 100644 --- a/indicoio/utils/__init__.py +++ b/indicoio/utils/__init__.py @@ -107,6 +107,8 @@ def image_preprocess(image, size=(48,48), batch=False): DeprecationWarning ) outImage = process_list_image(image) + elif isinstance(image, Image.Image): + outImage = image elif type(image).__name__ == "ndarray": # image is from numpy/scipy out_image = Image.fromarray(image) else: diff --git a/tests/test_remote.py b/tests/test_remote.py index 32f4b3f..749034e 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -62,6 +62,12 @@ class BatchAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, list)) self.assertTrue(isinstance(response[0], dict)) + def test_batch_fer_pil_image(self): + test_data = [Image.open(os.path.normpath(os.path.join(DIR, "data/fear.png")))] + response = batch_fer(test_data, api_key=self.api_key) + self.assertTrue(isinstance(response, list)) + self.assertTrue(isinstance(response[0], dict)) + def test_batch_fer_nonexistant_filepath(self): test_data = ["data/unhappy.png"] self.assertRaises(ValueError, batch_fer, test_data, api_key=self.api_key) @@ -184,6 +190,12 @@ class FullAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, dict)) self.assertTrue(response['Happy'] > 0.5) + def test_happy_fer_pil(self): + test_face = Image.open(os.path.normpath(os.path.join(DIR, "data/happy.png"))).convert('L'); + response = fer(test_face) + self.assertTrue(isinstance(response, dict)) + self.assertTrue(response['Happy'] > 0.5) + def test_fear_fer(self): test_face = self.load_image("data/fear.png", as_grey=True) response = fer(test_face)