mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 21:04:35 +08:00
raise exception if user tries to pass large object by value (#276)
This commit is contained in:
committed by
Philipp Moritz
parent
8465df1146
commit
ced5ce4924
+20
-1
@@ -138,6 +138,7 @@ PyObject* make_pyobjref(PyObject* worker_capsule, ObjRef objref) {
|
||||
// Error handling
|
||||
|
||||
static PyObject *RayError;
|
||||
static PyObject *RaySizeError;
|
||||
|
||||
// Pass arguments from Python to C++
|
||||
|
||||
@@ -574,6 +575,20 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
|
||||
RAY_LOG(RAY_REFCOUNT, "In serialize_task, calling increment_reference_count for contained objrefs");
|
||||
worker->increment_reference_count(objrefs);
|
||||
}
|
||||
std::string output;
|
||||
task->SerializeToString(&output);
|
||||
int task_size = output.length();
|
||||
if (task_size > 1024) {
|
||||
// Large objects should not be passed to tasks by value. Instead, they
|
||||
// should be placed in the object store and passed by object
|
||||
// reference.
|
||||
RAY_LOG(RAY_INFO, "Warning: attempting to serialize a task with size " << task_size << ".");
|
||||
PyErr_SetString(RaySizeError, "serialize_task: This task is too large (greater than 1024 bytes). "
|
||||
"Please do not pass large objects by value to remote functions. "
|
||||
"Instead, put large objects in the object store and pass them by "
|
||||
"object reference to the remote function.");
|
||||
return NULL;
|
||||
}
|
||||
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
|
||||
}
|
||||
|
||||
@@ -935,9 +950,13 @@ PyMODINIT_FUNC initlibraylib(void) {
|
||||
Py_INCREF(&PyObjRefType);
|
||||
PyModule_AddObject(m, "ObjRef", (PyObject *)&PyObjRefType);
|
||||
char ray_error[] = "ray.error";
|
||||
char ray_size_error[] = "ray_size.error";
|
||||
RayError = PyErr_NewException(ray_error, NULL, NULL);
|
||||
RaySizeError = PyErr_NewException(ray_size_error, NULL, NULL);
|
||||
Py_INCREF(RayError);
|
||||
PyModule_AddObject(m, "error", RayError);
|
||||
Py_INCREF(RaySizeError);
|
||||
PyModule_AddObject(m, "ray_error", RayError);
|
||||
PyModule_AddObject(m, "ray_size_error", RaySizeError);
|
||||
import_array();
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -25,7 +25,7 @@ class RemoteArrayTest(unittest.TestCase):
|
||||
self.assertTrue(np.alltrue(val == np.zeros([3, 4, 5])))
|
||||
|
||||
# test qr - pass by value
|
||||
val_a = np.random.normal(size=[10, 13])
|
||||
val_a = np.random.normal(size=[10, 11])
|
||||
ref_q, ref_r = ra.linalg.qr(val_a)
|
||||
val_q = ray.get(ref_q)
|
||||
val_r = ray.get(ref_r)
|
||||
|
||||
Reference in New Issue
Block a user