Fix bug in which remote function redefinition doesn't happen. (#6175)

This commit is contained in:
Robert Nishihara
2019-11-26 09:19:19 -08:00
committed by Edward Oakes
parent 7f8de61441
commit ffb9c0ecae
6 changed files with 256 additions and 47 deletions
+47 -18
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import dis
import hashlib
import importlib
import inspect
@@ -102,7 +103,7 @@ class FunctionDescriptor(object):
"Invalid input for FunctionDescriptor.from_bytes_list")
@classmethod
def from_function(cls, function):
def from_function(cls, function, pickled_function):
"""Create a FunctionDescriptor from a function instance.
This function is used to create the function descriptor from
@@ -113,6 +114,9 @@ class FunctionDescriptor(object):
cls: Current class which is required argument for classmethod.
function: the python function used to create the function
descriptor.
pickled_function: This is factored in to ensure that any
modifications to the function result in a different function
descriptor.
Returns:
The FunctionDescriptor instance created according to the function.
@@ -121,22 +125,10 @@ class FunctionDescriptor(object):
function_name = function.__name__
class_name = ""
function_source_hasher = hashlib.sha1()
try:
# If we are running a script or are in IPython, include the source
# code in the hash.
source = inspect.getsource(function)
if sys.version_info[0] >= 3:
source = source.encode()
function_source_hasher.update(source)
function_source_hash = function_source_hasher.digest()
except (IOError, OSError, TypeError):
# Source code may not be available:
# e.g. Cython or Python interpreter.
function_source_hash = b""
pickled_function_hash = hashlib.sha1(pickled_function).digest()
return cls(module_name, function_name, class_name,
function_source_hash)
pickled_function_hash)
@classmethod
def from_class(cls, target_class):
@@ -315,6 +307,40 @@ class FunctionActorManager(object):
job_id = ray.JobID.nil()
return self._num_task_executions[job_id][function_id]
def compute_collision_identifier(self, function_or_class):
    """Compute an identifier used to detect excessive duplicate exports.

    The identifier is used to determine when the same function or class is
    exported many times. This can yield false positives.

    On Python 3 the identifier is built from the object's ``__name__``
    followed by its disassembly as printed by ``dis.dis``; on Python 2 only
    the ``__name__`` is used. The result is hashed so that its size is
    bounded.

    Args:
        function_or_class: The function or class to compute an identifier
            for.

    Returns:
        The SHA-1 digest (bytes) of the identifier. Note that different
            functions or classes can give rise to the same identifier.
            However, the same function should hopefully always give rise
            to the same identifier. TODO(rkn): verify if this is actually
            the case. Note that if the identifier is incorrect in any way,
            then we may give warnings unnecessarily or fail to give
            warnings, but the application's behavior won't change.
    """
    collision_identifier = function_or_class.__name__

    if sys.version_info[0] >= 3:
        import io

        string_file = io.StringIO()
        # Compare the full version tuple rather than the minor number alone
        # so the check stays correct across major versions.
        if sys.version_info >= (3, 7):
            # The ``depth`` argument is only passed on 3.7 and later.
            dis.dis(function_or_class, file=string_file, depth=2)
        else:
            dis.dis(function_or_class, file=string_file)
        collision_identifier += ":" + string_file.getvalue()

    # Return a hash of the identifier in case it is too large. Encode as
    # UTF-8 rather than ASCII: the disassembly echoes constants and names,
    # which may legitimately contain non-ASCII characters, and an ASCII
    # encode would raise UnicodeEncodeError for them.
    return hashlib.sha1(collision_identifier.encode("utf-8")).digest()
def export(self, remote_function):
"""Pickle a remote function and export it to redis.
@@ -339,9 +365,11 @@ class FunctionActorManager(object):
"job_id": self._worker.current_job_id.binary(),
"function_id": remote_function._function_descriptor.
function_id.binary(),
"name": remote_function._function_name,
"function_name": remote_function._function_name,
"module": function.__module__,
"function": pickled_function,
"collision_identifier": self.compute_collision_identifier(
function),
"max_calls": remote_function._max_calls
})
self._worker.redis_client.rpush("Exports", key)
@@ -351,8 +379,8 @@ class FunctionActorManager(object):
(job_id_str, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self._worker.redis_client.hmget(key, [
"job_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
"job_id", "function_id", "function_name", "function",
"num_return_vals", "module", "resources", "max_calls"
])
function_id = ray.FunctionID(function_id_str)
job_id = ray.JobID(job_id_str)
@@ -549,6 +577,7 @@ class FunctionActorManager(object):
"module": Class.__module__,
"class": pickle.dumps(Class),
"job_id": job_id.binary(),
"collision_identifier": self.compute_collision_identifier(Class),
"actor_method_names": json.dumps(list(actor_method_names))
}