mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:02:56 +08:00
Enable function_descriptor in backend to replace the function_id (#3028)
This commit is contained in:
committed by
Robert Nishihara
parent
3822b20319
commit
fb33fa9097
+278
-46
@@ -5,6 +5,7 @@ from __future__ import print_function
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
@@ -18,6 +19,7 @@ from ray import profiling
|
||||
from ray import ray_constants
|
||||
from ray import cloudpickle as pickle
|
||||
from ray.utils import (
|
||||
binary_to_hex,
|
||||
is_cython,
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
@@ -31,6 +33,228 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo",
|
||||
["function", "function_name", "max_calls"])
|
||||
"""FunctionExecutionInfo: A named tuple storing remote function information."""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionDescriptor(object):
|
||||
"""A class used to describe a python function.
|
||||
|
||||
Attributes:
|
||||
module_name: the module name that the function belongs to.
|
||||
class_name: the class name that the function belongs to if exists.
|
||||
It could be empty is the function is not a class method.
|
||||
function_name: the function name of the function.
|
||||
function_hash: the hash code of the function source code if the
|
||||
function code is available.
|
||||
function_id: the function id calculated from this descriptor.
|
||||
is_for_driver_task: whether this descriptor is for driver task.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module_name,
|
||||
function_name,
|
||||
class_name="",
|
||||
function_source_hash=b""):
|
||||
self._module_name = module_name
|
||||
self._class_name = class_name
|
||||
self._function_name = function_name
|
||||
self._function_source_hash = function_source_hash
|
||||
self._function_id = self._get_function_id()
|
||||
|
||||
def __repr__(self):
|
||||
return ("FunctionDescriptor:" + self._module_name + "." +
|
||||
self._class_name + "." + self._function_name + "." +
|
||||
binary_to_hex(self._function_source_hash))
|
||||
|
||||
@classmethod
|
||||
def from_bytes_list(cls, function_descriptor_list):
|
||||
"""Create a FunctionDescriptor instance from list of bytes.
|
||||
|
||||
This function is used to create the function descriptor from
|
||||
backend data.
|
||||
|
||||
Args:
|
||||
cls: Current class which is required argument for classmethod.
|
||||
function_descriptor_list: list of bytes to represent the
|
||||
function descriptor.
|
||||
|
||||
Returns:
|
||||
The FunctionDescriptor instance created from the bytes list.
|
||||
"""
|
||||
assert isinstance(function_descriptor_list, list)
|
||||
if len(function_descriptor_list) == 0:
|
||||
# This is a function descriptor of driver task.
|
||||
return FunctionDescriptor.for_driver_task()
|
||||
elif (len(function_descriptor_list) == 3
|
||||
or len(function_descriptor_list) == 4):
|
||||
module_name = function_descriptor_list[0].decode()
|
||||
class_name = function_descriptor_list[1].decode()
|
||||
function_name = function_descriptor_list[2].decode()
|
||||
if len(function_descriptor_list) == 4:
|
||||
return cls(module_name, function_name, class_name,
|
||||
function_descriptor_list[3])
|
||||
else:
|
||||
return cls(module_name, function_name, class_name)
|
||||
else:
|
||||
raise Exception(
|
||||
"Invalid input for FunctionDescriptor.from_bytes_list")
|
||||
|
||||
@classmethod
|
||||
def from_function(cls, function):
|
||||
"""Create a FunctionDescriptor from a function instance.
|
||||
|
||||
This function is used to create the function descriptor from
|
||||
a python function. If a function is a class function, it should
|
||||
not be used by this function.
|
||||
|
||||
Args:
|
||||
cls: Current class which is required argument for classmethod.
|
||||
function: the python function used to create the function
|
||||
descriptor.
|
||||
|
||||
Returns:
|
||||
The FunctionDescriptor instance created according to the function.
|
||||
"""
|
||||
module_name = function.__module__
|
||||
function_name = function.__name__
|
||||
class_name = ""
|
||||
|
||||
function_source_hasher = hashlib.sha1()
|
||||
try:
|
||||
# If we are running a script or are in IPython, include the source
|
||||
# code in the hash.
|
||||
source = inspect.getsource(function).encode("ascii")
|
||||
function_source_hasher.update(source)
|
||||
function_source_hash = function_source_hasher.digest()
|
||||
except (IOError, OSError, TypeError):
|
||||
# Source code may not be available:
|
||||
# e.g. Cython or Python interpreter.
|
||||
function_source_hash = b""
|
||||
|
||||
return cls(module_name, function_name, class_name,
|
||||
function_source_hash)
|
||||
|
||||
@classmethod
|
||||
def from_class(cls, target_class):
|
||||
"""Create a FunctionDescriptor from a class.
|
||||
|
||||
Args:
|
||||
cls: Current class which is required argument for classmethod.
|
||||
target_class: the python class used to create the function
|
||||
descriptor.
|
||||
|
||||
Returns:
|
||||
The FunctionDescriptor instance created according to the class.
|
||||
"""
|
||||
module_name = target_class.__module__
|
||||
class_name = target_class.__name__
|
||||
return cls(module_name, "__init__", class_name)
|
||||
|
||||
@classmethod
|
||||
def for_driver_task(cls):
|
||||
"""Create a FunctionDescriptor instance for a driver task."""
|
||||
return cls("", "", "", b"")
|
||||
|
||||
@property
|
||||
def is_for_driver_task(self):
|
||||
"""See whether this function descriptor is for a driver or not.
|
||||
|
||||
Returns:
|
||||
True if this function descriptor is for driver tasks.
|
||||
"""
|
||||
return all(
|
||||
len(x) == 0
|
||||
for x in [self.module_name, self.class_name, self.function_name])
|
||||
|
||||
@property
|
||||
def module_name(self):
|
||||
"""Get the module name of current function descriptor.
|
||||
|
||||
Returns:
|
||||
The module name of the function descriptor.
|
||||
"""
|
||||
return self._module_name
|
||||
|
||||
@property
|
||||
def class_name(self):
|
||||
"""Get the class name of current function descriptor.
|
||||
|
||||
Returns:
|
||||
The class name of the function descriptor. It could be
|
||||
empty if the function is not a class method.
|
||||
"""
|
||||
return self._class_name
|
||||
|
||||
@property
|
||||
def function_name(self):
|
||||
"""Get the function name of current function descriptor.
|
||||
|
||||
Returns:
|
||||
The function name of the function descriptor.
|
||||
"""
|
||||
return self._function_name
|
||||
|
||||
@property
|
||||
def function_hash(self):
|
||||
"""Get the hash code of the function source code.
|
||||
|
||||
Returns:
|
||||
The bytes with length of ray_constants.ID_SIZE if the source
|
||||
code is available. Otherwise, the bytes length will be 0.
|
||||
"""
|
||||
return self._function_source_hash
|
||||
|
||||
@property
|
||||
def function_id(self):
|
||||
"""Get the function id calculated from this descriptor.
|
||||
|
||||
Returns:
|
||||
The value of ray.ObjectID that represents the function id.
|
||||
"""
|
||||
return ray.ObjectID(self._function_id)
|
||||
|
||||
def _get_function_id(self):
|
||||
"""Calculate the function id of current function descriptor.
|
||||
|
||||
This function id is calculated from all the fields of function
|
||||
descriptor.
|
||||
|
||||
Returns:
|
||||
bytes with length of ray_constants.ID_SIZE.
|
||||
"""
|
||||
if self.is_for_driver_task:
|
||||
return ray_constants.NIL_FUNCTION_ID.id()
|
||||
function_id_hash = hashlib.sha1()
|
||||
# Include the function module and name in the hash.
|
||||
function_id_hash.update(self.module_name.encode("ascii"))
|
||||
function_id_hash.update(self.function_name.encode("ascii"))
|
||||
function_id_hash.update(self.class_name.encode("ascii"))
|
||||
function_id_hash.update(self._function_source_hash)
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return function_id
|
||||
|
||||
def get_function_descriptor_list(self):
|
||||
"""Return a list of bytes representing the function descriptor.
|
||||
|
||||
This function is used to pass this function descriptor to backend.
|
||||
|
||||
Returns:
|
||||
A list of bytes.
|
||||
"""
|
||||
descriptor_list = []
|
||||
if self.is_for_driver_task:
|
||||
# Driver task returns an empty list.
|
||||
return descriptor_list
|
||||
else:
|
||||
descriptor_list.append(self.module_name.encode("ascii"))
|
||||
descriptor_list.append(self.class_name.encode("ascii"))
|
||||
descriptor_list.append(self.function_name.encode("ascii"))
|
||||
if len(self._function_source_hash) != 0:
|
||||
descriptor_list.append(self._function_source_hash)
|
||||
return descriptor_list
|
||||
|
||||
|
||||
class FunctionActorManager(object):
|
||||
"""A class used to export/load remote functions and actors.
|
||||
@@ -45,6 +269,8 @@ class FunctionActorManager(object):
|
||||
and execution_info.
|
||||
_num_task_executions: The map from driver_id to function
|
||||
execution times.
|
||||
imported_actor_classes: The set of actor classes keys (format:
|
||||
ActorClass:function_id) that are already in GCS.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
@@ -58,11 +284,17 @@ class FunctionActorManager(object):
|
||||
# workers that execute remote functions.
|
||||
self._function_execution_info = defaultdict(lambda: {})
|
||||
self._num_task_executions = defaultdict(lambda: {})
|
||||
# A set of all of the actor class keys that have been imported by the
|
||||
# import thread. It is safe to convert this worker into an actor of
|
||||
# these types.
|
||||
self.imported_actor_classes = set()
|
||||
|
||||
def increase_task_counter(self, driver_id, function_id):
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id.id()
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_id):
|
||||
def get_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id.id()
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
@@ -124,13 +356,13 @@ class FunctionActorManager(object):
|
||||
check_oversized_pickle(pickled_function,
|
||||
remote_function._function_name,
|
||||
"remote function", self._worker)
|
||||
|
||||
key = (b"RemoteFunction:" + self._worker.task_driver_id.id() + b":" +
|
||||
remote_function._function_id)
|
||||
remote_function._function_descriptor.function_id.id())
|
||||
self._worker.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self._worker.task_driver_id.id(),
|
||||
"function_id": remote_function._function_id,
|
||||
"function_id": remote_function._function_descriptor.
|
||||
function_id.id(),
|
||||
"name": remote_function._function_name,
|
||||
"module": function.__module__,
|
||||
"function": pickled_function,
|
||||
@@ -193,24 +425,28 @@ class FunctionActorManager(object):
|
||||
self._worker.redis_client.rpush(
|
||||
b"FunctionTable:" + function_id.id(), self._worker.worker_id)
|
||||
|
||||
def get_execution_info(self, driver_id, function_id):
|
||||
def get_execution_info(self, driver_id, function_descriptor):
|
||||
"""Get the FunctionExecutionInfo of a remote function.
|
||||
|
||||
Args:
|
||||
driver_id: ID of the driver that the function belongs to.
|
||||
function_id: ID of the function to get.
|
||||
function_descriptor: The FunctionDescriptor of the function to get.
|
||||
|
||||
Returns:
|
||||
A FunctionExecutionInfo object.
|
||||
"""
|
||||
# Wait until the function to be executed has actually been registered
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with profiling.profile("wait_for_function", worker=self._worker):
|
||||
self._wait_for_function(function_id, driver_id)
|
||||
return self._function_execution_info[driver_id][function_id.id()]
|
||||
function_id = function_descriptor.function_id.id()
|
||||
|
||||
def _wait_for_function(self, function_id, driver_id, timeout=10):
|
||||
# Wait until the function to be executed has actually been
|
||||
# registered on this worker. We will push warnings to the user if
|
||||
# we spend too long in this loop.
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function", worker=self._worker):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
return self._function_execution_info[driver_id][function_id]
|
||||
|
||||
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
@@ -221,7 +457,8 @@ class FunctionActorManager(object):
|
||||
been defined.
|
||||
|
||||
Args:
|
||||
function_id (str): The ID of the function that we want to execute.
|
||||
function_descriptor : The FunctionDescriptor of the function that
|
||||
we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
@@ -231,7 +468,7 @@ class FunctionActorManager(object):
|
||||
while True:
|
||||
with self._worker.lock:
|
||||
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
|
||||
and (function_id.id() in
|
||||
and (function_descriptor.function_id.id() in
|
||||
self._function_execution_info[driver_id])):
|
||||
break
|
||||
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
|
||||
@@ -251,24 +488,6 @@ class FunctionActorManager(object):
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
@classmethod
|
||||
def compute_actor_method_function_id(cls, class_name, attr):
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
|
||||
Args:
|
||||
class_name (str): The class name of the actor.
|
||||
attr (str): The attribute name of the method.
|
||||
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(class_name.encode("ascii"))
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return ray.ObjectID(function_id)
|
||||
|
||||
def _publish_actor_class_to_key(self, key, actor_class_info):
|
||||
"""Push an actor class definition to Redis.
|
||||
|
||||
@@ -287,9 +506,10 @@ class FunctionActorManager(object):
|
||||
self._worker.redis_client.hmset(key, actor_class_info)
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def export_actor_class(self, class_id, Class, actor_method_names,
|
||||
def export_actor_class(self, Class, actor_method_names,
|
||||
checkpoint_interval):
|
||||
key = b"ActorClass:" + class_id
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
key = b"ActorClass:" + function_descriptor.function_id.id()
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
@@ -318,6 +538,17 @@ class FunctionActorManager(object):
|
||||
# within tasks. I tried to disable this, but it may be necessary
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def load_actor(self, driver_id, function_descriptor):
|
||||
key = b"ActorClass:" + function_descriptor.function_id.id()
|
||||
# Wait for the actor class key to have been imported by the
|
||||
# import thread. TODO(rkn): It shouldn't be possible to end
|
||||
# up in an infinite loop here, but we should push an error to
|
||||
# the driver if too much time is spent here.
|
||||
while key not in self.imported_actor_classes:
|
||||
time.sleep(0.001)
|
||||
with self._worker.lock:
|
||||
self.fetch_and_register_actor(key)
|
||||
|
||||
def fetch_and_register_actor(self, actor_class_key):
|
||||
"""Import an actor.
|
||||
|
||||
@@ -330,11 +561,10 @@ class FunctionActorManager(object):
|
||||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = self._worker.actor_id
|
||||
(driver_id, class_id, class_name, module, pickled_class,
|
||||
checkpoint_interval,
|
||||
(driver_id, class_name, module, pickled_class, checkpoint_interval,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_id", "class_name", "module", "class",
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names"
|
||||
])
|
||||
|
||||
@@ -368,9 +598,9 @@ class FunctionActorManager(object):
|
||||
|
||||
# Register the actor method executors.
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = (
|
||||
FunctionActorManager.compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id())
|
||||
function_descriptor = FunctionDescriptor(module, actor_method_name,
|
||||
class_name)
|
||||
function_id = function_descriptor.function_id.id()
|
||||
temporary_executor = self._make_actor_method_executor(
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
@@ -409,9 +639,9 @@ class FunctionActorManager(object):
|
||||
actor_methods = inspect.getmembers(
|
||||
unpickled_class, predicate=is_function_or_method)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_id = (
|
||||
FunctionActorManager.compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id())
|
||||
function_descriptor = FunctionDescriptor(
|
||||
module, actor_method_name, class_name)
|
||||
function_id = function_descriptor.function_id.id()
|
||||
executor = self._make_actor_method_executor(
|
||||
actor_method_name, actor_method, actor_imported=True)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
@@ -452,7 +682,9 @@ class FunctionActorManager(object):
|
||||
|
||||
# If this is the first task to execute on the actor, try to resume
|
||||
# from a checkpoint.
|
||||
if actor_imported and self._worker.actor_task_counter == 1:
|
||||
# Current __init__ will be called by default. So the real function
|
||||
# call will start from 2.
|
||||
if actor_imported and self._worker.actor_task_counter == 2:
|
||||
checkpoint_resumed = ray.actor.restore_and_log_checkpoint(
|
||||
self._worker, actor)
|
||||
if checkpoint_resumed:
|
||||
|
||||
Reference in New Issue
Block a user