mirror of
https://github.com/wassname/IndicoIo-python.git
synced 2026-06-27 16:10:34 +08:00
gs
This commit is contained in:
@@ -2,6 +2,7 @@ import inspect, json, getpass, os.path, base64, StringIO, re, warnings
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from indicoio.utils.errors import IndicoError
|
||||
from indicoio import JSON_HEADERS
|
||||
from indicoio import config
|
||||
|
||||
@@ -29,13 +30,13 @@ def api_handler(arg, cloud, api, url_params = {"batch":False, "api_key":None}, *
|
||||
|
||||
response = requests.post(url, data=json_data, headers=JSON_HEADERS)
|
||||
if response.status_code == 503 and cloud != None:
|
||||
raise Exception("Private cloud '%s' does not include api '%s'" % (cloud, api))
|
||||
raise IndicoError("Private cloud '%s' does not include api '%s'" % (cloud, api))
|
||||
|
||||
json_results = response.json()
|
||||
results = json_results.get('results', False)
|
||||
if results is False:
|
||||
error = json_results.get('error')
|
||||
raise ValueError(error)
|
||||
raise IndicoError(error)
|
||||
return results
|
||||
|
||||
|
||||
@@ -97,7 +98,7 @@ def image_preprocess(image, size=(48,48), batch=False):
|
||||
elif B64_PATTERN.match(b64_str) is not None:
|
||||
return b64_str
|
||||
else:
|
||||
raise ValueError("Snose tring provided must be a valid filepath or base64 encoded string")
|
||||
raise IndicoError("Snose tring provided must be a valid filepath or base64 encoded string")
|
||||
|
||||
elif isinstance(image, list): # image passed in is a list and not np.array
|
||||
warnings.warn(
|
||||
@@ -110,7 +111,7 @@ def image_preprocess(image, size=(48,48), batch=False):
|
||||
elif type(image).__name__ == "ndarray": # image is from numpy/scipy
|
||||
out_image = Image.fromarray(image)
|
||||
else:
|
||||
raise ValueError("Image must be a filepath, base64 encoded string, or a numpy array")
|
||||
raise IndicoError("Image must be a filepath, base64 encoded string, or a numpy array")
|
||||
|
||||
# image resizing
|
||||
outImage = outImage.resize(size)
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
class IndicoError(ValueError):
|
||||
pass
|
||||
+11
-13
@@ -1,4 +1,5 @@
|
||||
from indicoio.config import TEXT_APIS, IMAGE_APIS, API_NAMES
|
||||
from indicoio.utils.errors import IndicoError
|
||||
from indicoio.utils import api_handler
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ def multi(data, type, apis, available, batch=False, **kwargs):
|
||||
# Client side api name checking - strictly only accept func name api
|
||||
invalid_apis = [api for api in apis if api not in available]
|
||||
if invalid_apis:
|
||||
raise ValueError("%s are not valid %s APIs. Please reference the available APIs below:\n%s"
|
||||
raise IndicoError("%s are not valid %s APIs. Please reference the available APIs below:\n%s"
|
||||
% (", ".join(invalid_apis), type, ", ".join(available))
|
||||
)
|
||||
# Convert client api names to server names before sending request
|
||||
@@ -31,15 +32,9 @@ def multi(data, type, apis, available, batch=False, **kwargs):
|
||||
return handle_response(result)
|
||||
|
||||
def handle_response(result):
|
||||
try:
|
||||
# Parse out the results to a dicionary of api: result
|
||||
return dict((SERVER_CLIENT_MAP[api], parsed_response(res))
|
||||
for api, res in result.iteritems())
|
||||
except KeyError:
|
||||
for api in result:
|
||||
if "error" in result[api]:
|
||||
raise ValueError(result[api]["error"])
|
||||
raise Exception("Sorry, %s API returned an unexpected response:\n%s" % (api, result[api]))
|
||||
# Parse out the results to a dicionary of api: result
|
||||
return dict((SERVER_CLIENT_MAP[api], parsed_response(res))
|
||||
for api, res in result.iteritems())
|
||||
|
||||
|
||||
def predict_text(input_text, apis=TEXT_APIS, cloud=None, batch=False, api_key=None, **kwargs):
|
||||
@@ -110,8 +105,11 @@ def predict_image(image, apis=IMAGE_APIS, cloud=None, batch=False, api_key=None,
|
||||
apis=apis,
|
||||
**kwargs)
|
||||
|
||||
def parsed_response(response):
|
||||
result = response.get('results') or response.get('error', False)
|
||||
def parsed_response(api, response):
|
||||
result = response.get('results', False)
|
||||
if result:
|
||||
return result
|
||||
raise KeyError
|
||||
raise IndicoError(
|
||||
"Sorry, the %s API returned an unexpected response.\n\t%s"
|
||||
% (api, response.get('error', ""))
|
||||
)
|
||||
|
||||
@@ -134,7 +134,6 @@ class BatchAPIRun(unittest.TestCase):
|
||||
|
||||
self.assertTrue(isinstance(response, dict))
|
||||
self.assertTrue(set(response.keys()) == set(config.IMAGE_APIS))
|
||||
self.assertTrue("results" in response["fer"])
|
||||
|
||||
|
||||
def test_batch_multi_api_text(self):
|
||||
@@ -143,7 +142,6 @@ class BatchAPIRun(unittest.TestCase):
|
||||
|
||||
self.assertTrue(isinstance(response, dict))
|
||||
self.assertTrue(set(response.keys()) == set(config.TEXT_APIS))
|
||||
self.assertTrue("results" in response["sentiment"])
|
||||
|
||||
def test_default_multi_api_text(self):
|
||||
test_data = ['clearly an english sentence']
|
||||
|
||||
Reference in New Issue
Block a user