mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 16:14:06 +08:00
int and long should be treated similarly (#220)
This commit is contained in:
@@ -11,6 +11,9 @@ import ray
|
||||
class Int(int):
|
||||
pass
|
||||
|
||||
class Long(long):
|
||||
pass
|
||||
|
||||
class Float(float):
|
||||
pass
|
||||
|
||||
|
||||
@@ -70,6 +70,8 @@ class Worker(object):
|
||||
result = serialization.deserialize(self.handle, object_capsule)
|
||||
if isinstance(result, int):
|
||||
result = serialization.Int(result)
|
||||
elif isinstance(result, long):
|
||||
result = serialization.Long(result)
|
||||
elif isinstance(result, float):
|
||||
result = serialization.Float(result)
|
||||
elif isinstance(result, bool):
|
||||
@@ -396,6 +398,17 @@ def check_return_values(function, result):
|
||||
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
|
||||
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i]))
|
||||
|
||||
def typecheck_arg(arg, expected_type, i, function):
|
||||
if issubclass(type(arg), expected_type):
|
||||
# Passed the type-checck
|
||||
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
|
||||
pass
|
||||
elif isinstance(arg, long) and issubclass(int, expected_type):
|
||||
# TODO(mehrdadn): Should long really be convertible to int?
|
||||
pass
|
||||
else:
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
def check_arguments(function, args):
|
||||
# check the number of args
|
||||
@@ -416,8 +429,7 @@ def check_arguments(function, args):
|
||||
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
|
||||
pass
|
||||
else:
|
||||
if not issubclass(type(arg), expected_type): # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
|
||||
typecheck_arg(arg, expected_type, i, function)
|
||||
|
||||
# helper method, this should not be called by the user
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
@@ -448,8 +460,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
if not issubclass(type(argument), expected_type):
|
||||
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type))
|
||||
typecheck_arg(argument, expected_type, i, function)
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
@@ -4,6 +4,10 @@ message Int {
|
||||
int64 data = 1;
|
||||
}
|
||||
|
||||
message Long {
|
||||
int64 data = 1;
|
||||
}
|
||||
|
||||
message String {
|
||||
string data = 1;
|
||||
}
|
||||
@@ -32,6 +36,7 @@ message PyObj {
|
||||
message Obj {
|
||||
String string_data = 1;
|
||||
Int int_data = 2;
|
||||
Long long_data = 12;
|
||||
Double double_data = 3;
|
||||
Bool bool_data = 10;
|
||||
Tuple tuple_data = 7;
|
||||
|
||||
@@ -225,6 +225,7 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec
|
||||
|
||||
// serialize will serialize the python object val into the protocol buffer
|
||||
// object obj, returns 0 if successful and something else if not
|
||||
// NOTE: If some primitive types are added here, they may also need to be handled in serialization.py
|
||||
// FIXME(pcm): This currently only works for contiguous arrays
|
||||
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
|
||||
int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<ObjRef> &objrefs) {
|
||||
@@ -240,6 +241,14 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||
Int* data = obj->mutable_int_data();
|
||||
long d = PyInt_AsLong(val);
|
||||
data->set_data(d);
|
||||
} else if (PyLong_Check(val)) {
|
||||
// TODO(mehrdadn): We do not currently support arbitrary long values.
|
||||
int overflow = 0;
|
||||
Long* data = obj->mutable_long_data();
|
||||
data->set_data(PyLong_AsLongLongAndOverflow(val, &overflow));
|
||||
if (overflow) {
|
||||
PyErr_SetString(RayError, "serialization: long overflow");
|
||||
}
|
||||
} else if (PyFloat_Check(val)) {
|
||||
Double* data = obj->mutable_double_data();
|
||||
double d = PyFloat_AsDouble(val);
|
||||
@@ -359,6 +368,8 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
|
||||
static PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
|
||||
if (obj.has_int_data()) {
|
||||
return PyInt_FromLong(obj.int_data().data());
|
||||
} else if (obj.has_long_data()) {
|
||||
return PyLong_FromLongLong(obj.long_data().data());
|
||||
} else if (obj.has_double_data()) {
|
||||
return PyFloat_FromDouble(obj.double_data().data());
|
||||
} else if (obj.has_bool_data()) {
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@ import test_functions
|
||||
import ray.array.remote as ra
|
||||
import ray.array.distributed as da
|
||||
|
||||
RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0,
|
||||
RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, 43L, "hello world", 42.0, 1L << 62,
|
||||
(1.0, "hi"), None, (None, None), ("hello", None),
|
||||
True, False, (True, False),
|
||||
{True: "hello", False: "world"},
|
||||
|
||||
Reference in New Issue
Block a user