diff --git a/indicoio/utils/__init__.py b/indicoio/utils/__init__.py index fefa922..657a2e6 100644 --- a/indicoio/utils/__init__.py +++ b/indicoio/utils/__init__.py @@ -2,6 +2,7 @@ import inspect, json, getpass, os.path, base64, StringIO, re, warnings import requests from PIL import Image +from indicoio.utils.errors import IndicoError from indicoio import JSON_HEADERS from indicoio import config @@ -29,13 +30,13 @@ def api_handler(arg, cloud, api, url_params = {"batch":False, "api_key":None}, * response = requests.post(url, data=json_data, headers=JSON_HEADERS) if response.status_code == 503 and cloud != None: - raise Exception("Private cloud '%s' does not include api '%s'" % (cloud, api)) + raise IndicoError("Private cloud '%s' does not include api '%s'" % (cloud, api)) json_results = response.json() results = json_results.get('results', False) if results is False: error = json_results.get('error') - raise ValueError(error) + raise IndicoError(error) return results @@ -97,7 +98,7 @@ def image_preprocess(image, size=(48,48), batch=False): elif B64_PATTERN.match(b64_str) is not None: return b64_str else: - raise ValueError("Snose tring provided must be a valid filepath or base64 encoded string") + raise IndicoError("Snose tring provided must be a valid filepath or base64 encoded string") elif isinstance(image, list): # image passed in is a list and not np.array warnings.warn( @@ -110,7 +111,7 @@ def image_preprocess(image, size=(48,48), batch=False): elif type(image).__name__ == "ndarray": # image is from numpy/scipy out_image = Image.fromarray(image) else: - raise ValueError("Image must be a filepath, base64 encoded string, or a numpy array") + raise IndicoError("Image must be a filepath, base64 encoded string, or a numpy array") # image resizing outImage = outImage.resize(size) diff --git a/indicoio/utils/errors.py b/indicoio/utils/errors.py new file mode 100644 index 0000000..e7d0d5d --- /dev/null +++ b/indicoio/utils/errors.py @@ -0,0 +1,2 @@ +class IndicoError(ValueError): + pass diff --git a/indicoio/utils/multi.py b/indicoio/utils/multi.py index e8f033f..b60c624 100644 --- a/indicoio/utils/multi.py +++ b/indicoio/utils/multi.py @@ -1,4 +1,5 @@ from indicoio.config import TEXT_APIS, IMAGE_APIS, API_NAMES +from indicoio.utils.errors import IndicoError from indicoio.utils import api_handler @@ -22,7 +23,7 @@ def multi(data, type, apis, available, batch=False, **kwargs): # Client side api name checking - strictly only accept func name api invalid_apis = [api for api in apis if api not in available] if invalid_apis: - raise ValueError("%s are not valid %s APIs. Please reference the available APIs below:\n%s" + raise IndicoError("%s are not valid %s APIs. Please reference the available APIs below:\n%s" % (", ".join(invalid_apis), type, ", ".join(available)) ) # Convert client api names to server names before sending request @@ -31,15 +32,9 @@ def multi(data, type, apis, available, batch=False, **kwargs): return handle_response(result) def handle_response(result): - try: - # Parse out the results to a dicionary of api: result - return dict((SERVER_CLIENT_MAP[api], parsed_response(res)) - for api, res in result.iteritems()) - except KeyError: - for api in result: - if "error" in result[api]: - raise ValueError(result[api]["error"]) - raise Exception("Sorry, %s API returned an unexpected response:\n%s" % (api, result[api])) + # Parse out the results to a dicionary of api: result + return dict((SERVER_CLIENT_MAP[api], parsed_response(res)) + for api, res in result.iteritems()) def predict_text(input_text, apis=TEXT_APIS, cloud=None, batch=False, api_key=None, **kwargs): @@ -110,8 +105,11 @@ def predict_image(image, apis=IMAGE_APIS, cloud=None, batch=False, api_key=None, apis=apis, **kwargs) -def parsed_response(response): - result = response.get('results') or response.get('error', False) +def parsed_response(api, response): + result = response.get('results', False) if result: return result - raise KeyError + raise IndicoError( + "Sorry, the %s API returned an unexpected response.\n\t%s" + % (api, response.get('error', "")) + ) diff --git a/tests/test_remote.py b/tests/test_remote.py index 2d8c9b7..7ef7a58 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -134,7 +134,6 @@ class BatchAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, dict)) self.assertTrue(set(response.keys()) == set(config.IMAGE_APIS)) - self.assertTrue("results" in response["fer"]) def test_batch_multi_api_text(self): @@ -143,7 +142,6 @@ class BatchAPIRun(unittest.TestCase): self.assertTrue(isinstance(response, dict)) self.assertTrue(set(response.keys()) == set(config.TEXT_APIS)) - self.assertTrue("results" in response["sentiment"]) def test_default_multi_api_text(self): test_data = ['clearly an english sentence']