Add option of load_code_from_local which is required in cross-language ray call. (#3675)

This commit is contained in:
Yuhong Guo
2019-02-21 12:37:17 +08:00
committed by Hao Chen
parent e3066d1fa5
commit 1f864a02bc
9 changed files with 270 additions and 105 deletions
+171 -92
View File
@@ -3,9 +3,11 @@ from __future__ import division
from __future__ import print_function
import hashlib
import importlib
import inspect
import json
import logging
import six
import sys
import time
import traceback
@@ -87,9 +89,9 @@ class FunctionDescriptor(object):
return FunctionDescriptor.for_driver_task()
elif (len(function_descriptor_list) == 3
or len(function_descriptor_list) == 4):
module_name = function_descriptor_list[0].decode()
class_name = function_descriptor_list[1].decode()
function_name = function_descriptor_list[2].decode()
module_name = six.ensure_str(function_descriptor_list[0])
class_name = six.ensure_str(function_descriptor_list[1])
function_name = six.ensure_str(function_descriptor_list[2])
if len(function_descriptor_list) == 4:
return cls(module_name, function_name, class_name,
function_descriptor_list[3])
@@ -256,6 +258,14 @@ class FunctionDescriptor(object):
descriptor_list.append(self._function_source_hash)
return descriptor_list
def is_actor_method(self):
"""Wether this function descriptor is an actor method.
Returns:
True if it's an actor method, False if it's a normal function.
"""
return len(self._class_name) > 0
class FunctionActorManager(object):
"""A class used to export/load remote functions and actors.
@@ -289,13 +299,18 @@ class FunctionActorManager(object):
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
self._loaded_actor_classes = {}
def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
self._num_task_executions[driver_id][function_id] += 1
def get_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
return self._num_task_executions[driver_id][function_id]
def export_cached(self):
@@ -336,6 +351,8 @@ class FunctionActorManager(object):
Args:
remote_function: the RemoteFunction object.
"""
if self._worker.load_code_from_local:
return
# Work around limitations of Python pickling.
function = remote_function._function
function_name_global_valid = function.__name__ in function.__globals__
@@ -436,16 +453,24 @@ class FunctionActorManager(object):
Returns:
A FunctionExecutionInfo object.
"""
function_id = function_descriptor.function_id
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
# we spend too long in this loop.
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function"):
self._wait_for_function(function_descriptor, driver_id)
if self._worker.load_code_from_local:
# Load function from local code.
# Currently, we don't support isolating code by drivers,
# thus always set driver ID to NIL here.
driver_id = ray.DriverID.nil()
if not function_descriptor.is_actor_method():
self._load_function_from_local(driver_id, function_descriptor)
else:
# Load function from GCS.
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
# we spend too long in this loop.
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function"):
self._wait_for_function(function_descriptor, driver_id)
try:
function_id = function_descriptor.function_id
info = self._function_execution_info[driver_id][function_id]
except KeyError as e:
message = ("Error occurs in get_execution_info: "
@@ -454,6 +479,33 @@ class FunctionActorManager(object):
raise KeyError(message)
return info
def _load_function_from_local(self, driver_id, function_descriptor):
assert not function_descriptor.is_actor_method()
function_id = function_descriptor.function_id
if (driver_id in self._function_execution_info
and function_id in self._function_execution_info[function_id]):
return
module_name, function_name = (
function_descriptor.module_name,
function_descriptor.function_name,
)
try:
module = importlib.import_module(module_name)
function = getattr(module, function_name)._function
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=0,
))
self._num_task_executions[driver_id][function_id] = 0
except Exception:
logger.exception(
"Failed to load function %s.".format(function_name))
raise Exception(
"Function {} failed to be loaded from local code.".format(
function_descriptor))
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
@@ -513,6 +565,8 @@ class FunctionActorManager(object):
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, Class, actor_method_names):
if self._worker.load_code_from_local:
return
function_descriptor = FunctionDescriptor.from_class(Class)
# `task_driver_id` shouldn't be NIL, unless:
# 1) This worker isn't an actor;
@@ -553,7 +607,87 @@ class FunctionActorManager(object):
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor(self, driver_id, function_descriptor):
def load_actor_class(self, driver_id, function_descriptor):
"""Load the actor class.
Args:
driver_id: Driver ID of the actor.
function_descriptor: Function descriptor of the actor constructor.
Returns:
The actor class.
"""
function_id = function_descriptor.function_id
# Check if the actor class already exists in the cache.
actor_class = self._loaded_actor_classes.get(function_id, None)
if actor_class is None:
# Load actor class.
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
# Load actor class from local code.
actor_class = self._load_actor_from_local(
driver_id, function_descriptor)
else:
# Load actor class from GCS.
actor_class = self._load_actor_class_from_gcs(
driver_id, function_descriptor)
# Save the loaded actor class in cache.
self._loaded_actor_classes[function_id] = actor_class
# Generate execution info for the methods of this actor class.
module_name = function_descriptor.module_name
actor_class_name = function_descriptor.class_name
actor_methods = inspect.getmembers(
actor_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
method_descriptor = FunctionDescriptor(
module_name, actor_method_name, actor_class_name)
method_id = method_descriptor.function_id
executor = self._make_actor_method_executor(
actor_method_name,
actor_method,
actor_imported=True,
)
self._function_execution_info[driver_id][method_id] = (
FunctionExecutionInfo(
function=executor,
function_name=actor_method_name,
max_calls=0,
))
self._num_task_executions[driver_id][method_id] = 0
self._num_task_executions[driver_id][function_id] = 0
return actor_class
def _load_actor_from_local(self, driver_id, function_descriptor):
"""Load actor class from local code."""
module_name, class_name = (function_descriptor.module_name,
function_descriptor.class_name)
try:
module = importlib.import_module(module_name)
return getattr(module, class_name)._modified_class
except Exception:
logger.exception(
"Failed to load actor_class %s.".format(class_name))
raise Exception(
"Actor {} failed to be imported from local code.".format(
class_name))
def _create_fake_actor_class(self, actor_class_name, actor_method_names):
class TemporaryActor(object):
pass
def temporary_actor_method(*xs):
raise Exception(
"The actor with name {} failed to be imported, "
"and so cannot execute this method.".format(actor_class_name))
for method in actor_method_names:
setattr(TemporaryActor, method, temporary_actor_method)
return TemporaryActor
def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
"""Load actor class from GCS."""
key = (b"ActorClass:" + driver_id.binary() + b":" +
function_descriptor.function_id.binary())
# Wait for the actor class key to have been imported by the
@@ -562,74 +696,32 @@ class FunctionActorManager(object):
# the driver if too much time is spent here.
while key not in self.imported_actor_classes:
time.sleep(0.001)
with self._worker.lock:
self.fetch_and_register_actor(key)
def fetch_and_register_actor(self, actor_class_key):
"""Import an actor.
This will be called by the worker's import thread when the worker
receives the actor_class export, assuming that the worker is an actor
for that class.
Args:
actor_class_key: The key in Redis to use to fetch the actor.
"""
actor_id = self._worker.actor_id
# Fetch raw data from GCS.
(driver_id_str, class_name, module, pickled_class,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
key, [
"driver_id", "class_name", "module", "class",
"actor_method_names"
])
class_name = decode(class_name)
module = decode(module)
class_name = six.ensure_str(class_name)
module_name = six.ensure_str(module)
driver_id = ray.DriverID(driver_id_str)
actor_method_names = json.loads(decode(actor_method_names))
# In Python 2, json loads strings as unicode, so convert them back to
# strings.
if sys.version_info < (3, 0):
actor_method_names = [
method_name.encode("ascii")
for method_name in actor_method_names
]
# Create a temporary actor with some temporary methods so that if
# the actor fails to be unpickled, the temporary actor can be used
# (just to produce error messages and to prevent the driver from
# hanging).
class TemporaryActor(object):
pass
self._worker.actors[actor_id] = TemporaryActor()
def temporary_actor_method(*xs):
raise Exception(
"The actor with name {} failed to be imported, "
"and so cannot execute this method".format(class_name))
# Register the actor method executors.
for actor_method_name in actor_method_names:
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
actor_imported=False)
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=temporary_executor,
function_name=actor_method_name,
max_calls=0))
self._num_task_executions[driver_id][function_id] = 0
actor_method_names = json.loads(six.ensure_str(actor_method_names))
actor_class = None
try:
unpickled_class = pickle.loads(pickled_class)
self._worker.actor_class = unpickled_class
with self._worker.lock:
actor_class = pickle.loads(pickled_class)
except Exception:
logger.exception(
"Failed to load actor class %s.".format(class_name))
# The actor class failed to be unpickled, create a fake actor
# class instead (just to produce error messages and to prevent
# the driver from hanging).
actor_class = self._create_fake_actor_class(
class_name, actor_method_names)
# If an exception was thrown when the actor was imported, we record
# the traceback and notify the scheduler of the failure.
traceback_str = ray.utils.format_error_message(
@@ -638,33 +730,20 @@ class FunctionActorManager(object):
push_error_to_driver(
self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
"Failed to unpickle actor class '{}' for actor ID {}. "
"Traceback:\n{}".format(class_name, actor_id.hex(),
"Traceback:\n{}".format(class_name,
self._worker.actor_id.hex(),
traceback_str), driver_id)
# TODO(rkn): In the future, it might make sense to have the worker
# exit here. However, currently that would lead to hanging if
# someone calls ray.get on a method invoked on the actor.
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
self._worker.actors[actor_id] = unpickled_class.__new__(
unpickled_class)
actor_methods = inspect.getmembers(
unpickled_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=executor,
function_name=actor_method_name,
max_calls=0))
# We do not set function_properties[driver_id][function_id]
# because we currently do need the actor worker to submit new
# tasks for the actor.
# The below line is necessary. Because in the driver process,
# if the function is defined in the file where the python script
# was started from, its module is `__main__`.
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
actor_class.__module__ = module_name
return actor_class
def _make_actor_method_executor(self, method_name, method, actor_imported):
"""Make an executor that wraps a user-defined actor method.