int and long should be treated similarly (#220)

This commit is contained in:
mehrdadn
2016-07-07 03:31:58 +03:00
committed by Philipp Moritz
parent 5412d3c773
commit 199b4efd50
5 changed files with 35 additions and 5 deletions
+3
View File
@@ -11,6 +11,9 @@ import ray
class Int(int):
pass
class Long(long):
pass
class Float(float):
pass
+15 -4
View File
@@ -70,6 +70,8 @@ class Worker(object):
result = serialization.deserialize(self.handle, object_capsule)
if isinstance(result, int):
result = serialization.Int(result)
elif isinstance(result, long):
result = serialization.Long(result)
elif isinstance(result, float):
result = serialization.Float(result)
elif isinstance(result, bool):
@@ -396,6 +398,17 @@ def check_return_values(function, result):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)):
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i]))
def typecheck_arg(arg, expected_type, i, function):
if issubclass(type(arg), expected_type):
# Passed the type-checck
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
pass
elif isinstance(arg, long) and issubclass(int, expected_type):
# TODO(mehrdadn): Should long really be convertible to int?
pass
else:
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
# helper method, this should not be called by the user
def check_arguments(function, args):
# check the number of args
@@ -416,8 +429,7 @@ def check_arguments(function, args):
# TODO(rkn): When we have type information in the ObjRef, do type checking here.
pass
else:
if not issubclass(type(arg), expected_type): # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type))
typecheck_arg(arg, expected_type, i, function)
# helper method, this should not be called by the user
def get_arguments_for_execution(function, args, worker=global_worker):
@@ -448,8 +460,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
# pass the argument by value
argument = arg
if not issubclass(type(argument), expected_type):
raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type))
typecheck_arg(argument, expected_type, i, function)
arguments.append(argument)
return arguments
+5
View File
@@ -4,6 +4,10 @@ message Int {
int64 data = 1;
}
message Long {
int64 data = 1;
}
message String {
string data = 1;
}
@@ -32,6 +36,7 @@ message PyObj {
message Obj {
String string_data = 1;
Int int_data = 2;
Long long_data = 12;
Double double_data = 3;
Bool bool_data = 10;
Tuple tuple_data = 7;
+11
View File
@@ -225,6 +225,7 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec
// serialize will serialize the python object val into the protocol buffer
// object obj, returns 0 if successful and something else if not
// NOTE: If some primitive types are added here, they may also need to be handled in serialization.py
// FIXME(pcm): This currently only works for contiguous arrays
// This method will push all of the object references contained in `obj` to the `objrefs` vector.
int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<ObjRef> &objrefs) {
@@ -240,6 +241,14 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
Int* data = obj->mutable_int_data();
long d = PyInt_AsLong(val);
data->set_data(d);
} else if (PyLong_Check(val)) {
// TODO(mehrdadn): We do not currently support arbitrary long values.
int overflow = 0;
Long* data = obj->mutable_long_data();
data->set_data(PyLong_AsLongLongAndOverflow(val, &overflow));
if (overflow) {
PyErr_SetString(RayError, "serialization: long overflow");
}
} else if (PyFloat_Check(val)) {
Double* data = obj->mutable_double_data();
double d = PyFloat_AsDouble(val);
@@ -359,6 +368,8 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector<Obj
static PyObject* deserialize(PyObject* worker_capsule, const Obj& obj, std::vector<ObjRef> &objrefs) {
if (obj.has_int_data()) {
return PyInt_FromLong(obj.int_data().data());
} else if (obj.has_long_data()) {
return PyLong_FromLongLong(obj.long_data().data());
} else if (obj.has_double_data()) {
return PyFloat_FromDouble(obj.double_data().data());
} else if (obj.has_bool_data()) {
+1 -1
View File
@@ -12,7 +12,7 @@ import test_functions
import ray.array.remote as ra
import ray.array.distributed as da
RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0,
RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, 43L, "hello world", 42.0, 1L << 62,
(1.0, "hi"), None, (None, None), ("hello", None),
True, False, (True, False),
{True: "hello", False: "world"},