Allow Ray API to be used from multiple threads (#2422)

This commit is contained in:
Hao Chen
2018-07-21 06:39:01 +08:00
committed by Robert Nishihara
parent 4b6157ed09
commit 05f485e274
5 changed files with 167 additions and 93 deletions
+56
View File
@@ -3,10 +3,12 @@ from __future__ import division
from __future__ import print_function
import binascii
import functools
import hashlib
import numpy as np
import os
import sys
import threading
import time
import uuid
@@ -295,3 +297,57 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id.id())
class _ThreadSafeProxy(object):
"""This class is used to create a thread-safe proxy for a given object.
Every method call will be guarded with a lock.
Attributes:
orig_obj (object): the original object.
lock (threading.Lock): the lock object.
_wrapper_cache (dict): a cache from original object's methods to
the proxy methods.
"""
def __init__(self, orig_obj, lock):
self.orig_obj = orig_obj
self.lock = lock
self._wrapper_cache = {}
def __getattr__(self, attr):
orig_attr = getattr(self.orig_obj, attr)
if not callable(orig_attr):
# If the original attr is a field, just return it.
return orig_attr
else:
# If the orginal attr is a method,
# return a wrapper that guards the original method with a lock.
wrapper = self._wrapper_cache.get(attr)
if wrapper is None:
@functools.wraps(orig_attr)
def _wrapper(*args, **kwargs):
with self.lock:
return orig_attr(*args, **kwargs)
self._wrapper_cache[attr] = _wrapper
wrapper = _wrapper
return wrapper
def thread_safe_client(client, lock=None):
"""Create a thread-safe proxy which locks every method call
for the given client.
Args:
client: the client object to be guarded.
lock: the lock object that will be used to lock client's methods.
If None, a new lock will be used.
Returns:
A thread-safe proxy for the given client.
"""
if lock is None:
lock = threading.Lock()
return _ThreadSafeProxy(client, lock)