mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 17:37:07 +08:00
isinstance -> subclass to better support functionality in typing module (#112)
This commit is contained in:
committed by
Philipp Moritz
parent
b885c2f9aa
commit
bd6e549d6c
@@ -176,7 +176,7 @@ def qr(a):
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
W_rcs.append(qr_helper2(y_ri, a_work.objrefs[r, c]))
|
||||
W_c = ra.linalg.sum_list(*W_rcs)
|
||||
W_c = ra.sum_list(*W_rcs)
|
||||
for r in range(i, a.num_blocks[0]):
|
||||
y_ri = y_val.objrefs[r - i, 0]
|
||||
A_rc = qr_helper1(a_work.objrefs[r, c], y_ri, t, W_c)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import List
|
||||
from typing import List, Any
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape"]
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote([List[int], str, str], [np.ndarray])
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
@@ -75,3 +75,9 @@ def sum(x, axis=-1):
|
||||
@ray.remote([np.ndarray], [tuple])
|
||||
def shape(a):
|
||||
return np.shape(a)
|
||||
|
||||
# We use Any to allow different numerical types as well as numpy arrays.
|
||||
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
|
||||
@ray.remote([Any], [Any])
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -86,7 +86,3 @@ def matrix_rank(M):
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -164,7 +164,7 @@ def check_return_values(function, result):
|
||||
if len(result) != len(function.return_types):
|
||||
raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result)))
|
||||
for i in range(len(result)):
|
||||
if (not isinstance(result[i], function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
|
||||
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
|
||||
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i]))
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
@@ -187,7 +187,7 @@ def check_arguments(function, args):
|
||||
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
|
||||
pass
|
||||
else:
|
||||
if not isinstance(arg, expected_type): # TODO(rkn): This check doesn't really work, e.g., isinstance([1,2,3], typing.List[str]) == True
|
||||
if not issubclass(type(arg), expected_type): # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
@@ -219,7 +219,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
if not isinstance(argument, expected_type):
|
||||
if not issubclass(type(argument), expected_type):
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type))
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
Reference in New Issue
Block a user