diff --git a/indicoio/utils/multi.py b/indicoio/utils/multi.py index ae7b704..78ca691 100644 --- a/indicoio/utils/multi.py +++ b/indicoio/utils/multi.py @@ -1,6 +1,6 @@ from indicoio.config import TEXT_APIS, IMAGE_APIS, API_NAMES from indicoio.utils.errors import IndicoError -from indicoio.utils import api_handler +from indicoio.utils import api_handler, image_preprocess CLIENT_SERVER_MAP = dict((api, api.strip().replace("_", "").lower()) for api in API_NAMES) @@ -96,7 +96,7 @@ def predict_image(image, apis=IMAGE_APIS, cloud=None, batch=False, api_key=None, return multi( api="apis", - data=image, + data=image_preprocess(image, batch=batch), type="image", available=IMAGE_APIS, cloud=cloud, diff --git a/tests/test_remote.py b/tests/test_remote.py index 32736db..dcb0f81 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -115,27 +115,13 @@ class BatchAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, list)) self.assertTrue(response[0]['English'] > 0.25) - def test_multi_api_image(self): - test_data = generate_array((48,48)) - response = predict_image(test_data, apis=config.IMAGE_APIS, api_key=self.api_key) - - self.assertTrue(isinstance(response, dict)) - self.assertTrue(set(response.keys()) == set(config.IMAGE_APIS)) - - def test_multi_api_text(self): - test_data = 'clearly an english sentence' - response = predict_text(test_data, apis=config.TEXT_APIS, api_key=self.api_key) - - self.assertTrue(isinstance(response, dict)) - self.assertTrue(set(response.keys()) == set(config.TEXT_APIS)) - def test_batch_multi_api_image(self): - test_data = [generate_array((48,48))] + test_data = [generate_array((48,48)), generate_int_array((48,48))] response = batch_predict_image(test_data, apis=config.IMAGE_APIS, api_key=self.api_key) self.assertTrue(isinstance(response, dict)) self.assertTrue(set(response.keys()) == set(config.IMAGE_APIS)) - + self.assertTrue(isinstance(response["fer"], list)) def test_batch_multi_api_text(self): test_data = ['clearly an english sentence'] @@ -179,6 +165,9 @@ class BatchAPIRun(unittest.TestCase): class FullAPIRun(unittest.TestCase): + def setUp(self): + self.api_key = config.api_key + def load_image(self, relpath, as_grey=False): im = Image.open(os.path.normpath(os.path.join(DIR, relpath))).convert('L'); pixels = list(im.getdata()) @@ -239,6 +228,14 @@ class FullAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, dict)) self.assertEqual(fer_set, set(response.keys())) + def test_good_int_array_fer(self): + fer_set = set(['Angry', 'Sad', 'Neutral', 'Surprise', 'Fear', 'Happy']) + test_face = generate_int_array((48,48)) + response = fer(test_face) + + self.assertTrue(isinstance(response, dict)) + self.assertEqual(fer_set, set(response.keys())) + def test_happy_fer(self): test_face = self.load_image("data/happy.png", as_grey=True) response = fer(test_face) @@ -273,6 +270,15 @@ class FullAPIRun(unittest.TestCase): self.assertEqual(len(response), 48) self.check_range(response) + def test_good_int_array_facial_features(self): + fer_set = set(['Angry', 'Sad', 'Neutral', 'Surprise', 'Fear', 'Happy']) + test_face = generate_int_array((48,48)) + response = facial_features(test_face) + + self.assertTrue(isinstance(response, list)) + self.assertEqual(len(response), 48) + self.check_range(response) + # TODO: uncomment this test once the remote server is updated to # deal with image_urls # def test_image_url(self): @@ -299,6 +305,21 @@ class FullAPIRun(unittest.TestCase): self.assertEqual(len(response), 2048) self.check_range(response) + def test_multi_api_image(self): + test_data = generate_array((48,48)) + response = predict_image(test_data, apis=config.IMAGE_APIS, api_key=self.api_key) + + self.assertTrue(isinstance(response, dict)) + self.assertTrue(set(response.keys()) == set(config.IMAGE_APIS)) + + def test_multi_api_text(self): + test_data = 'clearly an english sentence' + response = predict_text(test_data, apis=config.TEXT_APIS, api_key=self.api_key) + + self.assertTrue(isinstance(response, dict)) + self.assertTrue(set(response.keys()) == set(config.TEXT_APIS)) + + def test_language(self): language_set = set([ 'English', @@ -390,6 +411,9 @@ def flatten(container): def generate_array(size): return [[random.random() for _ in xrange(size[0])] for _ in xrange(size[1])] +def generate_int_array(size): + return [[random.randint(0, 50) for _ in xrange(size[0])] for _ in xrange(size[1])] + if __name__ == "__main__": unittest.main()