raise exception if user tries to pass large object by value (#276)

This commit is contained in:
Robert Nishihara
2016-07-16 17:17:48 -07:00
committed by Philipp Moritz
parent 8465df1146
commit ced5ce4924
2 changed files with 21 additions and 2 deletions
+20 -1
View File
@@ -138,6 +138,7 @@ PyObject* make_pyobjref(PyObject* worker_capsule, ObjRef objref) {
// Error handling
static PyObject *RayError;
static PyObject *RaySizeError;
// Pass arguments from Python to C++
@@ -574,6 +575,20 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
RAY_LOG(RAY_REFCOUNT, "In serialize_task, calling increment_reference_count for contained objrefs");
worker->increment_reference_count(objrefs);
}
std::string output;
task->SerializeToString(&output);
int task_size = output.length();
if (task_size > 1024) {
// Large objects should not be passed to tasks by value. Instead, they
// should be placed in the object store and passed by object
// reference.
RAY_LOG(RAY_INFO, "Warning: attempting to serialize a task with size " << task_size << ".");
PyErr_SetString(RaySizeError, "serialize_task: This task is too large (greater than 1024 bytes). "
"Please do not pass large objects by value to remote functions. "
"Instead, put large objects in the object store and pass them by "
"object reference to the remote function.");
return NULL;
}
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
}
@@ -935,9 +950,13 @@ PyMODINIT_FUNC initlibraylib(void) {
Py_INCREF(&PyObjRefType);
PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType);
char ray_error[] = "ray.error";
char ray_size_error[] = "ray_size.error";
RayError = PyErr_NewException(ray_error, NULL, NULL);
RaySizeError = PyErr_NewException(ray_size_error, NULL, NULL);
Py_INCREF(RayError);
PyModule_AddObject(m, "error", RayError);
Py_INCREF(RaySizeError);
PyModule_AddObject(m, "ray_error", RayError);
PyModule_AddObject(m, "ray_size_error", RaySizeError);
import_array();
}
+1 -1
View File
@@ -25,7 +25,7 @@ class RemoteArrayTest(unittest.TestCase):
self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5])))
# test qr - pass by value
val_a = np.random.normal(size=[10, 13])
val_a = np.random.normal(size=[10, 11])
ref_q, ref_r = ra.linalg.qr(val_a)
val_q = ray.get(ref_q)
val_r = ray.get(ref_r)