Enable function_descriptor in backend to replace the function_id (#3028)

This commit is contained in:
Yuhong Guo
2018-12-19 07:53:59 +08:00
committed by Robert Nishihara
parent 3822b20319
commit fb33fa9097
20 changed files with 557 additions and 282 deletions
+278 -46
View File
@@ -5,6 +5,7 @@ from __future__ import print_function
import hashlib
import inspect
import json
import logging
import sys
import time
import traceback
@@ -18,6 +19,7 @@ from ray import profiling
from ray import ray_constants
from ray import cloudpickle as pickle
from ray.utils import (
binary_to_hex,
is_cython,
is_function_or_method,
is_class_method,
@@ -31,6 +33,228 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
["function", "function_name", "max_calls"])
"""FunctionExecutionInfo: A named tuple storing remote function information."""
logger = logging.getLogger(__name__)
class FunctionDescriptor(object):
    """A class used to describe a python function.

    Text fields are converted to and from bytes as UTF-8. For pure-ASCII
    names and source code this yields exactly the same bytes (and therefore
    the same function id) as an ASCII encoding would, while names or source
    code containing non-ASCII characters no longer raise a UnicodeError.

    Attributes:
        module_name: the module name that the function belongs to.
        class_name: the class name that the function belongs to if exists.
            It could be empty if the function is not a class method.
        function_name: the function name of the function.
        function_hash: the hash code of the function source code if the
            function code is available.
        function_id: the function id calculated from this descriptor.
        is_for_driver_task: whether this descriptor is for driver task.
    """

    def __init__(self,
                 module_name,
                 function_name,
                 class_name="",
                 function_source_hash=b""):
        """Store the descriptor fields and precompute the function id.

        Args:
            module_name: the module name that the function belongs to.
            function_name: the function name of the function.
            class_name: the class name that the function belongs to, or an
                empty string if the function is not a class method.
            function_source_hash: the hash (bytes) of the function source
                code, or empty bytes if the source is not available.
        """
        self._module_name = module_name
        self._class_name = class_name
        self._function_name = function_name
        self._function_source_hash = function_source_hash
        # The id depends only on the four fields above, so compute it once.
        self._function_id = self._get_function_id()

    def __repr__(self):
        return ("FunctionDescriptor:" + self._module_name + "." +
                self._class_name + "." + self._function_name + "." +
                binary_to_hex(self._function_source_hash))

    @classmethod
    def from_bytes_list(cls, function_descriptor_list):
        """Create a FunctionDescriptor instance from list of bytes.

        This function is used to create the function descriptor from
        backend data.

        Args:
            cls: Current class which is required argument for classmethod.
            function_descriptor_list: list of bytes to represent the
                function descriptor. An empty list denotes a driver task;
                otherwise it holds the module name, class name and function
                name, optionally followed by the function source hash.

        Returns:
            The FunctionDescriptor instance created from the bytes list.

        Raises:
            Exception: if the list does not have 0, 3 or 4 elements.
        """
        assert isinstance(function_descriptor_list, list)
        num_fields = len(function_descriptor_list)
        if num_fields == 0:
            # This is a function descriptor of driver task.
            return cls.for_driver_task()
        if num_fields not in (3, 4):
            raise Exception(
                "Invalid input for FunctionDescriptor.from_bytes_list")

        module_name = function_descriptor_list[0].decode("utf-8")
        class_name = function_descriptor_list[1].decode("utf-8")
        function_name = function_descriptor_list[2].decode("utf-8")
        if num_fields == 4:
            # The fourth element, when present, is the raw source hash.
            return cls(module_name, function_name, class_name,
                       function_descriptor_list[3])
        return cls(module_name, function_name, class_name)

    @classmethod
    def from_function(cls, function):
        """Create a FunctionDescriptor from a function instance.

        This function is used to create the function descriptor from
        a python function. If a function is a class function, it should
        not be used by this function.

        Args:
            cls: Current class which is required argument for classmethod.
            function: the python function used to create the function
                descriptor.

        Returns:
            The FunctionDescriptor instance created according to the function.
        """
        module_name = function.__module__
        function_name = function.__name__
        class_name = ""
        function_source_hash = cls._compute_source_hash(function)
        return cls(module_name, function_name, class_name,
                   function_source_hash)

    @staticmethod
    def _compute_source_hash(function):
        """Return the SHA-1 digest of the function's source code.

        Args:
            function: the python function whose source should be hashed.

        Returns:
            The digest bytes, or empty bytes if the source code cannot be
            retrieved.
        """
        try:
            # If we are running a script or are in IPython, include the source
            # code in the hash.
            source = inspect.getsource(function)
        except (IOError, OSError, TypeError):
            # Source code may not be available:
            # e.g. Cython or Python interpreter.
            return b""
        if not isinstance(source, bytes):
            # Text source must be turned into bytes before hashing. UTF-8
            # matches the ASCII encoding for ASCII-only source and also
            # accepts source that contains non-ASCII characters.
            source = source.encode("utf-8")
        return hashlib.sha1(source).digest()

    @classmethod
    def from_class(cls, target_class):
        """Create a FunctionDescriptor from a class.

        Args:
            cls: Current class which is required argument for classmethod.
            target_class: the python class used to create the function
                descriptor.

        Returns:
            The FunctionDescriptor instance created according to the class.
        """
        module_name = target_class.__module__
        class_name = target_class.__name__
        return cls(module_name, "__init__", class_name)

    @classmethod
    def for_driver_task(cls):
        """Create a FunctionDescriptor instance for a driver task."""
        return cls("", "", "", b"")

    @property
    def is_for_driver_task(self):
        """See whether this function descriptor is for a driver or not.

        Returns:
            True if this function descriptor is for driver tasks, i.e. the
            module name, class name and function name are all empty.
        """
        return all(
            len(x) == 0
            for x in [self.module_name, self.class_name, self.function_name])

    @property
    def module_name(self):
        """Get the module name of current function descriptor.

        Returns:
            The module name of the function descriptor.
        """
        return self._module_name

    @property
    def class_name(self):
        """Get the class name of current function descriptor.

        Returns:
            The class name of the function descriptor. It could be
                empty if the function is not a class method.
        """
        return self._class_name

    @property
    def function_name(self):
        """Get the function name of current function descriptor.

        Returns:
            The function name of the function descriptor.
        """
        return self._function_name

    @property
    def function_hash(self):
        """Get the hash code of the function source code.

        Returns:
            The bytes with length of ray_constants.ID_SIZE if the source
                code is available. Otherwise, the bytes length will be 0.
        """
        return self._function_source_hash

    @property
    def function_id(self):
        """Get the function id calculated from this descriptor.

        Returns:
            The value of ray.ObjectID that represents the function id.
        """
        return ray.ObjectID(self._function_id)

    def _get_function_id(self):
        """Calculate the function id of current function descriptor.

        This function id is calculated from all the fields of function
        descriptor.

        Returns:
            bytes with length of ray_constants.ID_SIZE.
        """
        if self.is_for_driver_task:
            return ray_constants.NIL_FUNCTION_ID.id()
        function_id_hash = hashlib.sha1()
        # Include the function module and name in the hash.
        function_id_hash.update(self.module_name.encode("utf-8"))
        function_id_hash.update(self.function_name.encode("utf-8"))
        function_id_hash.update(self.class_name.encode("utf-8"))
        function_id_hash.update(self._function_source_hash)
        # Compute the function ID.
        function_id = function_id_hash.digest()
        assert len(function_id) == ray_constants.ID_SIZE
        return function_id

    def get_function_descriptor_list(self):
        """Return a list of bytes representing the function descriptor.

        This function is used to pass this function descriptor to backend.

        Returns:
            A list of bytes. It is empty for a driver task; otherwise it
            holds the encoded module name, class name and function name,
            followed by the source hash when that hash is not empty.
        """
        if self.is_for_driver_task:
            # Driver task returns an empty list.
            return []
        descriptor_list = [
            self.module_name.encode("utf-8"),
            self.class_name.encode("utf-8"),
            self.function_name.encode("utf-8"),
        ]
        if len(self._function_source_hash) != 0:
            descriptor_list.append(self._function_source_hash)
        return descriptor_list
class FunctionActorManager(object):
"""A class used to export/load remote functions and actors.
@@ -45,6 +269,8 @@ class FunctionActorManager(object):
and execution_info.
_num_task_executions: The map from driver_id to function
execution times.
imported_actor_classes: The set of actor classes keys (format:
ActorClass:function_id) that are already in GCS.
"""
def __init__(self, worker):
@@ -58,11 +284,17 @@ class FunctionActorManager(object):
# workers that execute remote functions.
self._function_execution_info = defaultdict(lambda: {})
self._num_task_executions = defaultdict(lambda: {})
# A set of all of the actor class keys that have been imported by the
# import thread. It is safe to convert this worker into an actor of
# these types.
self.imported_actor_classes = set()
def increase_task_counter(self, driver_id, function_id):
def increase_task_counter(self, driver_id, function_descriptor):
    """Add one to the execution count of a function for a driver.

    Args:
        driver_id: key of the driver in the task execution map.
        function_descriptor: descriptor whose function id bytes select
            the counter to increment.
    """
    counter_key = function_descriptor.function_id.id()
    per_driver_counts = self._num_task_executions[driver_id]
    per_driver_counts[counter_key] += 1
def get_task_counter(self, driver_id, function_id):
def get_task_counter(self, driver_id, function_descriptor):
    """Look up the execution count of a function for a driver.

    Args:
        driver_id: key of the driver in the task execution map.
        function_descriptor: descriptor whose function id bytes select
            the counter to read.

    Returns:
        The value stored for that driver and function id.
    """
    counter_key = function_descriptor.function_id.id()
    per_driver_counts = self._num_task_executions[driver_id]
    return per_driver_counts[counter_key]
def export_cached(self):
@@ -124,13 +356,13 @@ class FunctionActorManager(object):
check_oversized_pickle(pickled_function,
remote_function._function_name,
"remote function", self._worker)
key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" +
remote_function._function_id)
remote_function._function_descriptor.function_id.id())
self._worker.redis_client.hmset(
key, {
"driver_id": self._worker.task_driver_id.id(),
"function_id": remote_function._function_id,
"function_id": remote_function._function_descriptor.
function_id.id(),
"name": remote_function._function_name,
"module": function.__module__,
"function": pickled_function,
@@ -193,24 +425,28 @@ class FunctionActorManager(object):
self._worker.redis_client.rpush(
b"FunctionTable:" + function_id.id(), self._worker.worker_id)
def get_execution_info(self, driver_id, function_id):
def get_execution_info(self, driver_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
Args:
driver_id: ID of the driver that the function belongs to.
function_id: ID of the function to get.
function_descriptor: The FunctionDescriptor of the function to get.
Returns:
A FunctionExecutionInfo object.
"""
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_id, driver_id)
return self._function_execution_info[driver_id][function_id.id()]
function_id = function_descriptor.function_id.id()
def _wait_for_function(self, function_id, driver_id, timeout=10):
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
# we spend too long in this loop.
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function", worker=self._worker):
self._wait_for_function(function_descriptor, driver_id)
return self._function_execution_info[driver_id][function_id]
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
This method will simply loop until the import thread has imported the
@@ -221,7 +457,8 @@ class FunctionActorManager(object):
been defined.
Args:
function_id (str): The ID of the function that we want to execute.
function_descriptor : The FunctionDescriptor of the function that
we want to execute.
driver_id (str): The ID of the driver to push the error message to
if this times out.
"""
@@ -231,7 +468,7 @@ class FunctionActorManager(object):
while True:
with self._worker.lock:
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
and (function_id.id() in
and (function_descriptor.function_id.id() in
self._function_execution_info[driver_id])):
break
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
@@ -251,24 +488,6 @@ class FunctionActorManager(object):
warning_sent = True
time.sleep(0.001)
@classmethod
def compute_actor_method_function_id(cls, class_name, attr):
    """Derive the function ID of an actor method from its names.

    The ID is the SHA-1 digest of the ASCII-encoded class name followed
    by the ASCII-encoded attribute name.

    Args:
        class_name (str): The class name of the actor.
        attr (str): The attribute name of the method.

    Returns:
        Function ID corresponding to the method.
    """
    hashed_payload = class_name.encode("ascii") + attr.encode("ascii")
    digest = hashlib.sha1(hashed_payload).digest()
    assert len(digest) == ray_constants.ID_SIZE
    return ray.ObjectID(digest)
def _publish_actor_class_to_key(self, key, actor_class_info):
"""Push an actor class definition to Redis.
@@ -287,9 +506,10 @@ class FunctionActorManager(object):
self._worker.redis_client.hmset(key, actor_class_info)
self._worker.redis_client.rpush("Exports", key)
def export_actor_class(self, class_id, Class, actor_method_names,
def export_actor_class(self, Class, actor_method_names,
checkpoint_interval):
key = b"ActorClass:" + class_id
function_descriptor = FunctionDescriptor.from_class(Class)
key = b"ActorClass:" + function_descriptor.function_id.id()
actor_class_info = {
"class_name": Class.__name__,
"module": Class.__module__,
@@ -318,6 +538,17 @@ class FunctionActorManager(object):
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor(self, driver_id, function_descriptor):
    """Wait for an actor class to be imported, then register it.

    Args:
        driver_id: not used in the body of this method.
        function_descriptor: descriptor whose function id bytes are
            appended to b"ActorClass:" to form the actor class key.
    """
    actor_class_key = b"ActorClass:" + function_descriptor.function_id.id()
    # Poll until the import thread has recorded the actor class key.
    # TODO(rkn): It shouldn't be possible to end up in an infinite loop
    # here, but we should push an error to the driver if too much time
    # is spent here.
    while True:
        if actor_class_key in self.imported_actor_classes:
            break
        time.sleep(0.001)
    # Registration is performed while holding the worker lock.
    with self._worker.lock:
        self.fetch_and_register_actor(actor_class_key)
def fetch_and_register_actor(self, actor_class_key):
"""Import an actor.
@@ -330,11 +561,10 @@ class FunctionActorManager(object):
worker: The worker to use.
"""
actor_id_str = self._worker.actor_id
(driver_id, class_id, class_name, module, pickled_class,
checkpoint_interval,
(driver_id, class_name, module, pickled_class, checkpoint_interval,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_id", "class_name", "module", "class",
"driver_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names"
])
@@ -368,9 +598,9 @@ class FunctionActorManager(object):
# Register the actor method executors.
for actor_method_name in actor_method_names:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id.id()
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
@@ -409,9 +639,9 @@ class FunctionActorManager(object):
actor_methods = inspect.getmembers(
unpickled_class, predicate=is_function_or_method)
for actor_method_name, actor_method in actor_methods:
function_id = (
FunctionActorManager.compute_actor_method_function_id(
class_name, actor_method_name).id())
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id.id()
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
@@ -452,7 +682,9 @@ class FunctionActorManager(object):
# If this is the first task to execute on the actor, try to resume
# from a checkpoint.
if actor_imported and self._worker.actor_task_counter == 1:
# __init__ is now always invoked first by default, so the first real
# method call runs with actor_task_counter == 2.
if actor_imported and self._worker.actor_task_counter == 2:
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
self._worker, actor)
if checkpoint_resumed: