From bd6e549d6c830c47b583f07629ed51f25eff64ba Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 19 Jun 2016 18:04:55 -0700 Subject: [PATCH] isinstance -> subclass to better support functionality in typing module (#112) --- lib/python/ray/arrays/distributed/linalg.py | 2 +- lib/python/ray/arrays/remote/core.py | 10 ++++++++-- lib/python/ray/arrays/remote/linalg.py | 4 ---- lib/python/ray/worker.py | 6 +++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/python/ray/arrays/distributed/linalg.py b/lib/python/ray/arrays/distributed/linalg.py index 257a68da7..a3d509416 100644 --- a/lib/python/ray/arrays/distributed/linalg.py +++ b/lib/python/ray/arrays/distributed/linalg.py @@ -176,7 +176,7 @@ def qr(a): for r in range(i, a.num_blocks[0]): y_ri = y_val.objrefs[r - i, 0] W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c])) - W_c = ra.linalg.sum_list(*W_rcs) + W_c = ra.sum_list(*W_rcs) for r in range(i, a.num_blocks[0]): y_ri = y_val.objrefs[r - i, 0] A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c) diff --git a/lib/python/ray/arrays/remote/core.py b/lib/python/ray/arrays/remote/core.py index 65f529c2e..a629a7985 100644 --- a/lib/python/ray/arrays/remote/core.py +++ b/lib/python/ray/arrays/remote/core.py @@ -1,8 +1,8 @@ -from typing import List +from typing import List, Any import numpy as np import ray -__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape"] +__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"] @ray.remote([List[int], str, str], [np.ndarray]) def zeros(shape, dtype_name="float", order="C"): @@ -75,3 +75,9 @@ def sum(x, axis=-1): @ray.remote([np.ndarray], [tuple]) def shape(a): return np.shape(a) + +# We use Any to allow different numerical types as well as numpy arrays. +# TODO(rkn):this isn't in the numpy API, so be careful about exposing this. +@ray.remote([Any], [Any]) +def sum_list(*xs): + return np.sum(xs, axis=0) diff --git a/lib/python/ray/arrays/remote/linalg.py b/lib/python/ray/arrays/remote/linalg.py index 332394b63..f79e94684 100644 --- a/lib/python/ray/arrays/remote/linalg.py +++ b/lib/python/ray/arrays/remote/linalg.py @@ -86,7 +86,3 @@ def matrix_rank(M): @ray.remote([np.ndarray], [np.ndarray]) def multi_dot(*a): raise NotImplementedError - -@ray.remote([np.ndarray], [np.ndarray]) -def sum_list(*xs): - return np.sum(xs, axis=0) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index ca094c1da..b49430810 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -164,7 +164,7 @@ def check_return_values(function, result): if len(result) != len(function.return_types): raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result))) for i in range(len(result)): - if (not isinstance(result[i], function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): + if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i])) # helper method, this should not be called by the user @@ -187,7 +187,7 @@ def check_arguments(function, args): # TODO(rkn): When we have type information in the ObjRef, do type checking here. pass else: - if not isinstance(arg, expected_type): # TODO(rkn): This check doesn't really work, e.g., isinstance([1,2,3], typing.List[str]) == True + if not issubclass(type(arg), expected_type): # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type)) # helper method, this should not be called by the user @@ -219,7 +219,7 @@ def get_arguments_for_execution(function, args, worker=global_worker): # pass the argument by value argument = arg - if not isinstance(argument, expected_type): + if not issubclass(type(argument), expected_type): raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type)) arguments.append(argument) return arguments