mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:49:16 +08:00
Enable function_descriptor in backend to replace the function_id (#3028)
This commit is contained in:
committed by
Robert Nishihara
parent
3822b20319
commit
fb33fa9097
@@ -3,11 +3,9 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
import ray.signature
|
||||
|
||||
# Default parameters for remote functions.
|
||||
@@ -18,33 +16,6 @@ DEFAULT_REMOTE_FUNCTION_MAX_CALLS = 0
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def compute_function_id(function):
|
||||
"""Compute an function ID for a function.
|
||||
|
||||
Args:
|
||||
func: The actual function.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the function id
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
# Include the function module and name in the hash.
|
||||
function_id_hash.update(function.__module__.encode("ascii"))
|
||||
function_id_hash.update(function.__name__.encode("ascii"))
|
||||
try:
|
||||
# If we are running a script or are in IPython, include the source code
|
||||
# in the hash.
|
||||
source = inspect.getsource(function).encode("ascii")
|
||||
function_id_hash.update(source)
|
||||
except (IOError, OSError, TypeError):
|
||||
# Source code may not be available: e.g. Cython or Python interpreter.
|
||||
pass
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return function_id
|
||||
|
||||
|
||||
class RemoteFunction(object):
|
||||
"""A remote function.
|
||||
|
||||
@@ -52,7 +23,7 @@ class RemoteFunction(object):
|
||||
|
||||
Attributes:
|
||||
_function: The original function.
|
||||
_function_id: The ID of the function.
|
||||
_function_descriptor: The function descriptor.
|
||||
_function_name: The module and function name.
|
||||
_num_cpus: The default number of CPUs to use for invocations of this
|
||||
remote function.
|
||||
@@ -70,10 +41,7 @@ class RemoteFunction(object):
|
||||
def __init__(self, function, num_cpus, num_gpus, resources,
|
||||
num_return_vals, max_calls):
|
||||
self._function = function
|
||||
# TODO(rkn): We store the function ID as a string, so that
|
||||
# RemoteFunction objects can be pickled. We should undo this when
|
||||
# we allow ObjectIDs to be pickled.
|
||||
self._function_id = compute_function_id(function)
|
||||
self._function_descriptor = FunctionDescriptor.from_function(function)
|
||||
self._function_name = (
|
||||
self._function.__module__ + '.' + self._function.__name__)
|
||||
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
|
||||
@@ -147,7 +115,7 @@ class RemoteFunction(object):
|
||||
result = self._function(*copy.deepcopy(args))
|
||||
return result
|
||||
object_ids = worker.submit_task(
|
||||
ray.ObjectID(self._function_id),
|
||||
self._function_descriptor,
|
||||
args,
|
||||
num_return_vals=num_return_vals,
|
||||
resources=resources)
|
||||
|
||||
Reference in New Issue
Block a user