diff --git a/src/raylib.cc b/src/raylib.cc index 1bea09c91..d0f9fc4cd 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -138,6 +138,7 @@ PyObject* make_pyobjref(PyObject* worker_capsule, ObjRef objref) { // Error handling static PyObject *RayError; +static PyObject *RaySizeError; // Pass arguments from Python to C++ @@ -574,6 +575,20 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) { RAY_LOG(RAY_REFCOUNT, "In serialize_task, calling increment_reference_count for contained objrefs"); worker->increment_reference_count(objrefs); } + std::string output; + task->SerializeToString(&output); + int task_size = output.length(); + if (task_size > 1024) { + // Large objects should not be passed to tasks by value. Instead, they + // should be placed in the object store and passed by object + // reference. + RAY_LOG(RAY_INFO, "Warning: attempting to serialize a task with size " << task_size << "."); + PyErr_SetString(RaySizeError, "serialize_task: This task is too large (greater than 1024 bytes). " + "Please do not pass large objects by value to remote functions. " + "Instead, put large objects in the object store and pass them by " + "object reference to the remote function."); + return NULL; + } return PyCapsule_New(static_cast(task), "task", &TaskCapsule_Destructor); } @@ -935,9 +950,13 @@ PyMODINIT_FUNC initlibraylib(void) { Py_INCREF(&PyObjRefType); PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType); char ray_error[] = "ray.error"; + char ray_size_error[] = "ray_size.error"; RayError = PyErr_NewException(ray_error, NULL, NULL); + RaySizeError = PyErr_NewException(ray_size_error, NULL, NULL); Py_INCREF(RayError); - PyModule_AddObject(m, "error", RayError); + Py_INCREF(RaySizeError); + PyModule_AddObject(m, "ray_error", RayError); + PyModule_AddObject(m, "ray_size_error", RaySizeError); import_array(); } diff --git a/test/array_test.py b/test/array_test.py index 2d2b922f1..de47d298b 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -25,7 +25,7 @@ class RemoteArrayTest(unittest.TestCase): self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5]))) # test qr - pass by value - val_a = np.random.normal(size=[10, 13]) + val_a = np.random.normal(size=[10, 11]) ref_q, ref_r = ra.linalg.qr(val_a) val_q = ray.get(ref_q) val_r = ray.get(ref_r)