From 199b4efd50aab2880bcff8dd9f6dc414f05348d1 Mon Sep 17 00:00:00 2001 From: mehrdadn Date: Thu, 7 Jul 2016 03:31:58 +0300 Subject: [PATCH] int and long should be treated similarly (#220) --- lib/python/ray/serialization.py | 3 +++ lib/python/ray/worker.py | 19 +++++++++++++++---- protos/types.proto | 5 +++++ src/raylib.cc | 11 +++++++++++ test/runtest.py | 2 +- 5 files changed, 35 insertions(+), 5 deletions(-) diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index 8b99cdacd..5079c406a 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -11,6 +11,9 @@ import ray class Int(int): pass +class Long(long): + pass + class Float(float): pass diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 60afa4ad9..5fc653da8 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -70,6 +70,8 @@ class Worker(object): result = serialization.deserialize(self.handle, object_capsule) if isinstance(result, int): result = serialization.Int(result) + elif isinstance(result, long): + result = serialization.Long(result) elif isinstance(result, float): result = serialization.Float(result) elif isinstance(result, bool): @@ -396,6 +398,17 @@ def check_return_values(function, result): if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjRef)): raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjRef.".format(i, function.__name__, type(result[i]), function.return_types[i])) +def typecheck_arg(arg, expected_type, i, function): + if issubclass(type(arg), expected_type): + # Passed the type-checck + # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True + pass + elif isinstance(arg, long) and issubclass(int, expected_type): + # TODO(mehrdadn): Should long really be convertible to int? + pass + else: + raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type)) + # helper method, this should not be called by the user def check_arguments(function, args): # check the number of args @@ -416,8 +429,7 @@ def check_arguments(function, args): # TODO(rkn): When we have type information in the ObjRef, do type checking here. pass else: - if not issubclass(type(arg), expected_type): # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True - raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(arg), expected_type)) + typecheck_arg(arg, expected_type, i, function) # helper method, this should not be called by the user def get_arguments_for_execution(function, args, worker=global_worker): @@ -448,8 +460,7 @@ def get_arguments_for_execution(function, args, worker=global_worker): # pass the argument by value argument = arg - if not issubclass(type(argument), expected_type): - raise Exception("Argument {} for function {} has type {} but an argument of type {} was expected.".format(i, function.__name__, type(argument), expected_type)) + typecheck_arg(argument, expected_type, i, function) arguments.append(argument) return arguments diff --git a/protos/types.proto b/protos/types.proto index 2a91abad3..cbd87d82e 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -4,6 +4,10 @@ message Int { int64 data = 1; } +message Long { + int64 data = 1; +} + message String { string data = 1; } @@ -32,6 +36,7 @@ message PyObj { message Obj { String string_data = 1; Int int_data = 2; + Long long_data = 12; Double double_data = 3; Bool bool_data = 10; Tuple tuple_data = 7; diff --git a/src/raylib.cc b/src/raylib.cc index f808d62f4..1bea09c91 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -225,6 +225,7 @@ void set_dict_item_and_transfer_ownership(PyObject* dict, PyObject* key, PyObjec // serialize will serialize the python object val into the protocol buffer // object obj, returns 0 if successful and something else if not +// NOTE: If some primitive types are added here, they may also need to be handled in serialization.py // FIXME(pcm): This currently only works for contiguous arrays // This method will push all of the object references contained in `obj` to the `objrefs` vector. int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector &objrefs) { @@ -240,6 +241,14 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vectormutable_int_data(); long d = PyInt_AsLong(val); data->set_data(d); + } else if (PyLong_Check(val)) { + // TODO(mehrdadn): We do not currently support arbitrary long values. + int overflow = 0; + Long* data = obj->mutable_long_data(); + data->set_data(PyLong_AsLongLongAndOverflow(val, &overflow)); + if (overflow) { + PyErr_SetString(RayError, "serialization: long overflow"); + } } else if (PyFloat_Check(val)) { Double* data = obj->mutable_double_data(); double d = PyFloat_AsDouble(val); @@ -359,6 +368,8 @@ int serialize(PyObject* worker_capsule, PyObject* val, Obj* obj, std::vector &objrefs) { if (obj.has_int_data()) { return PyInt_FromLong(obj.int_data().data()); + } else if (obj.has_long_data()) { + return PyLong_FromLongLong(obj.long_data().data()); } else if (obj.has_double_data()) { return PyFloat_FromDouble(obj.double_data().data()); } else if (obj.has_bool_data()) { diff --git a/test/runtest.py b/test/runtest.py index ab46631e1..8745355fb 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -12,7 +12,7 @@ import test_functions import ray.array.remote as ra import ray.array.distributed as da -RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, "hello world", 42.0, +RAY_TEST_OBJECTS = [[1, "hello", 3.0], 42, 43L, "hello world", 42.0, 1L << 62, (1.0, "hi"), None, (None, None), ("hello", None), True, False, (True, False), {True: "hello", False: "world"},