mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:38:16 +08:00
Simplify imports and exports and provide driver isolation for remote functions. (#288)
* Remove import counter and export counter. * Provide isolation between drivers for remote functions. * Add test for driver function isolation. * Hash source code into function ID to reduce likelihood of collisions. * Fix failure test example. * Replace assertTrue with assertIn to improve failure messages in tests. * Fix failure test.
This commit is contained in:
committed by
Philipp Moritz
parent
883f945db4
commit
88a5b4e77b
+21
-23
@@ -27,18 +27,19 @@ def get_actor_method_function_id(attr):
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id = hashlib.sha1()
|
||||
function_id.update(attr.encode("ascii"))
|
||||
return photon.ObjectID(function_id.digest())
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == 20
|
||||
return photon.ObjectID(function_id)
|
||||
|
||||
def fetch_and_register_actor(key, worker):
|
||||
"""Import an actor."""
|
||||
driver_id, actor_id_str, actor_name, module, pickled_class, class_export_counter = \
|
||||
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class", "class_export_counter"])
|
||||
driver_id, actor_id_str, actor_name, module, pickled_class = \
|
||||
worker.redis_client.hmget(key, ["driver_id", "actor_id", "name", "module", "class"])
|
||||
actor_id = photon.ObjectID(actor_id_str)
|
||||
actor_name = actor_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
class_export_counter = int(class_export_counter)
|
||||
try:
|
||||
unpickled_class = pickling.loads(pickled_class)
|
||||
except:
|
||||
@@ -49,16 +50,17 @@ def fetch_and_register_actor(key, worker):
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
for (k, v) in inspect.getmembers(unpickled_class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x))):
|
||||
function_id = get_actor_method_function_id(k).id()
|
||||
worker.function_names[function_id] = k
|
||||
worker.functions[function_id] = v
|
||||
worker.functions[driver_id][function_id] = (k, v)
|
||||
# We do not set worker.function_properties[driver_id][function_id] because
|
||||
# we currently do need the actor worker to submit new tasks for the actor.
|
||||
|
||||
def export_actor(actor_id, Class, worker):
|
||||
def export_actor(actor_id, Class, actor_method_names, worker):
|
||||
"""Export an actor to redis.
|
||||
|
||||
Args:
|
||||
actor_id: The ID of the actor.
|
||||
Class: Name of the class to be exported as an actor.
|
||||
worker: The worker class
|
||||
actor_method_names (list): A list of the names of this actor's methods.
|
||||
"""
|
||||
ray.worker.check_main_thread()
|
||||
if worker.mode is None:
|
||||
@@ -66,28 +68,25 @@ def export_actor(actor_id, Class, worker):
|
||||
key = "Actor:{}".format(actor_id.id())
|
||||
pickled_class = pickling.dumps(Class)
|
||||
|
||||
# For now, all actor methods have 1 return value and require 0 CPUs and GPUs.
|
||||
driver_id = worker.task_driver_id.id()
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.function_properties[driver_id][function_id] = (1, 0, 0)
|
||||
|
||||
# Select a local scheduler for the actor.
|
||||
local_schedulers = state.get_local_schedulers()
|
||||
local_scheduler_id = random.choice(local_schedulers)
|
||||
|
||||
worker.redis_client.publish("actor_notifications", actor_id.id() + local_scheduler_id)
|
||||
|
||||
# The export counter is computed differently depending on whether we are
|
||||
# currently in a driver or a worker.
|
||||
if worker.mode in [ray.SCRIPT_MODE, ray.SILENT_MODE]:
|
||||
export_counter = worker.driver_export_counter
|
||||
elif worker.mode == ray.WORKER_MODE:
|
||||
# We don't actually need export counters for actors.
|
||||
export_counter = 0
|
||||
d = {"driver_id": worker.task_driver_id.id(),
|
||||
d = {"driver_id": driver_id,
|
||||
"actor_id": actor_id.id(),
|
||||
"name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickled_class,
|
||||
"class_export_counter": export_counter}
|
||||
"class": pickled_class}
|
||||
worker.redis_client.hmset(key, d)
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
def actor(Class):
|
||||
# The function actor_method_call gets called if somebody tries to call a
|
||||
@@ -105,7 +104,6 @@ def actor(Class):
|
||||
num_cpus = 0
|
||||
num_gpus = 0
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
|
||||
num_cpus, num_gpus,
|
||||
actor_id=actor_id)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
@@ -116,7 +114,7 @@ def actor(Class):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._ray_actor_id = random_actor_id()
|
||||
self._ray_actor_methods = {k: v for (k, v) in inspect.getmembers(Class, predicate=(lambda x: inspect.isfunction(x) or inspect.ismethod(x)))}
|
||||
export_actor(self._ray_actor_id, Class, ray.worker.global_worker)
|
||||
export_actor(self._ray_actor_id, Class, self._ray_actor_methods, ray.worker.global_worker)
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in self._ray_actor_methods.keys():
|
||||
actor_method_call(self._ray_actor_id, "__init__", *args, **kwargs)
|
||||
|
||||
+79
-127
@@ -2,22 +2,23 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import copy
|
||||
import collections
|
||||
import funcsigs
|
||||
import numpy as np
|
||||
import colorama
|
||||
import atexit
|
||||
import collections
|
||||
import colorama
|
||||
import copy
|
||||
import funcsigs
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import redis
|
||||
import threading
|
||||
import string
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# Ray modules
|
||||
import ray.pickling as pickling
|
||||
@@ -373,61 +374,27 @@ class Worker(object):
|
||||
that connect has been called already.
|
||||
cached_functions_to_run (List): A list of functions to run on all of the
|
||||
workers that should be exported as soon as connect is called.
|
||||
driver_export_counter (int): The number of exports that the driver has
|
||||
exported. This is only used on the driver.
|
||||
worker_import_counter (int): The number of exports that the worker has
|
||||
imported so far. This is only used on the workers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize a Worker object."""
|
||||
self.functions = {}
|
||||
# Use a defaultdict for the number of return values. If this is accessed
|
||||
# with a missing key, the default value of 1 is returned, and that key value
|
||||
# pair is added to the dict.
|
||||
self.num_return_vals = collections.defaultdict(lambda: 1)
|
||||
self.function_names = {}
|
||||
self.function_export_counters = {}
|
||||
# The functions field is a dictionary that maps a driver ID to a dictionary
|
||||
# of functions that have been registered for that driver (this inner
|
||||
# dictionary maps function IDs to a tuple of the function name and the
|
||||
# function itself). This should only be used on workers that execute remote
|
||||
# functions.
|
||||
self.functions = collections.defaultdict(lambda: {})
|
||||
# The function_properties field is a dictionary that maps a driver ID to a
|
||||
# dictionary of functions that have been registered for that driver (this
|
||||
# inner dictionary maps function IDs to a tuple of the number of values
|
||||
# returned by that function, the number of CPUs required by that function,
|
||||
# and the number of GPUs required by that function). This is used when
|
||||
# submitting a function (which can be done both on workers and on drivers).
|
||||
self.function_properties = collections.defaultdict(lambda: {})
|
||||
self.connected = False
|
||||
self.mode = None
|
||||
self.cached_remote_functions = []
|
||||
self.cached_functions_to_run = []
|
||||
# The driver_export_counter and worker_import_counter are used to make sure
|
||||
# that no task executes before everything it needs is present. For example,
|
||||
# if we define a remote function f, a worker cannot execute a task for f
|
||||
# until the worker has imported the function f.
|
||||
# - When a remote function, a reusable variable, or a function to run is
|
||||
# exported, the driver_export_counter is incremented. These exports must
|
||||
# take place from the driver.
|
||||
# - When an actor is created, the driver_export_counter is NOT
|
||||
# incremented. Note that an actor can be created from a driver or from
|
||||
# any worker.
|
||||
# - When a worker imports a remote function, a reusable variable, or a
|
||||
# function to run, its worker_import_counter is incremented.
|
||||
# - Notably, when an actor is imported, its worker_import_counter is NOT
|
||||
# incremented.
|
||||
# - Whenever a remote function is DEFINED on the driver, it records the
|
||||
# value of the driver_export_counter and a worker will not execute that
|
||||
# remote function until it has imported that many exports (excluding
|
||||
# actors).
|
||||
# - When an actor is defined.
|
||||
# a) If the actor is created on a driver, it records the
|
||||
# driver_export_counter.
|
||||
# b) If the actor is created inside a task on a regular worker, it
|
||||
# records the driver_export_counter associated with the function in
|
||||
# task creating the actor.
|
||||
# c) If the actor is created inside a task on an actor worker, it
|
||||
# records
|
||||
# The worker that ultimately runs the actor will not execute any tasks
|
||||
# until it has imported that many imports.
|
||||
#
|
||||
# TODO(rkn): These counters must be tracked separately for each driver.
|
||||
# TODO(rkn): Maybe none of these counters are necessary? When executing a
|
||||
# regular task, workers can just wait until the function ID is present. When
|
||||
# executing an actor task, the actor worker can just wait until the actor
|
||||
# has been defined.
|
||||
self.driver_export_counter = 0
|
||||
self.worker_import_counter = 0
|
||||
self.fetch_and_register = {}
|
||||
self.actors = {}
|
||||
# Use a defaultdict for the actor counts. If this is accessed with a missing
|
||||
@@ -526,7 +493,7 @@ class Worker(object):
|
||||
assert final_results[i][0] == object_ids[i].id()
|
||||
return [result[1][0] for result in final_results]
|
||||
|
||||
def submit_task(self, function_id, func_name, args, num_cpus, num_gpus, actor_id=photon.ObjectID(NIL_ACTOR_ID)):
|
||||
def submit_task(self, function_id, func_name, args, actor_id=photon.ObjectID(NIL_ACTOR_ID)):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with name
|
||||
@@ -538,8 +505,6 @@ class Worker(object):
|
||||
args (List[Any]): The arguments to pass into the function. Arguments can
|
||||
be object IDs or they can be values. If they are values, they
|
||||
must be serializable objecs.
|
||||
num_cpus (int): The number of cpu cores this task requires to run.
|
||||
num_gpus (int): The number of gpus this task requires to run.
|
||||
"""
|
||||
with log_span("ray:submit_task", worker=self):
|
||||
check_main_thread()
|
||||
@@ -554,11 +519,14 @@ class Worker(object):
|
||||
else:
|
||||
args_for_photon.append(put(arg))
|
||||
|
||||
# Look up the various function properties.
|
||||
num_return_vals, num_cpus, num_gpus = self.function_properties[self.task_driver_id.id()][function_id.id()]
|
||||
|
||||
# Submit the task to Photon.
|
||||
task = photon.Task(self.task_driver_id,
|
||||
photon.ObjectID(function_id.id()),
|
||||
args_for_photon,
|
||||
self.num_return_vals[function_id.id()],
|
||||
num_return_vals,
|
||||
self.current_task_id,
|
||||
self.task_index,
|
||||
actor_id, self.actor_counters[actor_id],
|
||||
@@ -604,7 +572,6 @@ class Worker(object):
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickling.dumps(function)})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
self.driver_export_counter += 1
|
||||
|
||||
def push_error_to_driver(self, driver_id, error_type, message, data=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
@@ -963,8 +930,6 @@ def cleanup(worker=global_worker):
|
||||
|
||||
disconnect(worker)
|
||||
worker.set_mode(None)
|
||||
worker.driver_export_counter = 0
|
||||
worker.worker_import_counter = 0
|
||||
if hasattr(worker, "plasma_client"):
|
||||
worker.plasma_client.shutdown()
|
||||
services.cleanup()
|
||||
@@ -1037,14 +1002,13 @@ If this driver is hanging, start a new one with
|
||||
|
||||
def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
"""Import a remote function."""
|
||||
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter, num_cpus, num_gpus = \
|
||||
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, num_cpus, num_gpus = \
|
||||
worker.redis_client.hmget(key, ["driver_id",
|
||||
"function_id",
|
||||
"name",
|
||||
"function",
|
||||
"num_return_vals",
|
||||
"module",
|
||||
"function_export_counter",
|
||||
"num_cpus",
|
||||
"num_gpus"])
|
||||
function_id = photon.ObjectID(function_id_str)
|
||||
@@ -1053,19 +1017,14 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
num_cpus = int(num_cpus)
|
||||
num_gpus = int(num_gpus)
|
||||
module = module.decode("ascii")
|
||||
function_export_counter = int(function_export_counter)
|
||||
|
||||
worker.function_names[function_id.id()] = function_name
|
||||
worker.num_return_vals[function_id.id()] = num_return_vals
|
||||
worker.function_export_counters[function_id.id()] = function_export_counter
|
||||
# This is a placeholder in case the function can't be unpickled. This will be
|
||||
# overwritten if the function is unpickled successfully.
|
||||
# overwritten if the function is successfully registered.
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
|
||||
function_id=function_id,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)(lambda *xs: f())
|
||||
remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f())
|
||||
worker.functions[driver_id][function_id.id()] = (function_name, remote_f_placeholder)
|
||||
worker.function_properties[driver_id][function_id.id()] = (num_return_vals, num_cpus, num_gpus)
|
||||
|
||||
try:
|
||||
function = pickling.loads(serialized_function)
|
||||
@@ -1081,10 +1040,7 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
else:
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
|
||||
function_id=function_id,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)(function)
|
||||
worker.functions[driver_id][function_id.id()] = (function_name, remote(function_id=function_id)(function))
|
||||
# Add the function to the function table.
|
||||
worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id)
|
||||
|
||||
@@ -1133,10 +1089,7 @@ def import_thread(worker):
|
||||
# in the loop.
|
||||
worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports")
|
||||
worker_info_key = "WorkerInfo:{}".format(worker.worker_id)
|
||||
worker.redis_client.hset(worker_info_key, "export_counter", 0)
|
||||
worker.worker_import_counter = 0
|
||||
# The number of imports is similar to the worker_import_counter except that it
|
||||
# also counts actors.
|
||||
# Keep track of the number of imports that we've imported.
|
||||
num_imported = 0
|
||||
|
||||
# Get the exports that occurred before the call to psubscribe.
|
||||
@@ -1157,10 +1110,6 @@ def import_thread(worker):
|
||||
worker.fetch_and_register["Actor"](key, worker)
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
# Actors do not contribute to the import counter.
|
||||
if not key.startswith(b"Actor"):
|
||||
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
|
||||
worker.worker_import_counter += 1
|
||||
num_imported += 1
|
||||
|
||||
for msg in worker.import_pubsub_client.listen():
|
||||
@@ -1189,10 +1138,6 @@ def import_thread(worker):
|
||||
worker.fetch_and_register["Actor"](key, worker)
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
# Actors do not contribute to the import counter.
|
||||
if not key.startswith(b"Actor"):
|
||||
worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
|
||||
worker.worker_import_counter += 1
|
||||
num_imported += 1
|
||||
|
||||
def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, actor_id=NIL_ACTOR_ID):
|
||||
@@ -1522,13 +1467,12 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids]
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=global_worker):
|
||||
"""Wait until this worker has imported enough to execute the function.
|
||||
def wait_for_function(function_id, driver_id, timeout=5, worker=global_worker):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported enough of
|
||||
the exports to execute the function. If we spend too long in this loop, that
|
||||
may indicate a problem somewhere and we will push an error message to the
|
||||
user.
|
||||
This method will simply loop until the import thread has imported the relevant
|
||||
function. If we spend too long in this loop, that may indicate a problem
|
||||
somewhere and we will push an error message to the user.
|
||||
|
||||
If this worker is an actor, then this will wait until the actor has been
|
||||
defined.
|
||||
@@ -1545,17 +1489,14 @@ def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=glob
|
||||
num_warnings_sent = 0
|
||||
while True:
|
||||
with worker.lock:
|
||||
if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter):
|
||||
if worker.actor_id == NIL_ACTOR_ID and function_id.id() in worker.functions[driver_id]:
|
||||
break
|
||||
elif worker.actor_id != NIL_ACTOR_ID and worker.actor_id in worker.actors:
|
||||
break
|
||||
if time.time() - start_time > timeout * (num_warnings_sent + 1):
|
||||
if function_id.id() not in worker.functions:
|
||||
warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
|
||||
else:
|
||||
warning_message = "This worker's import counter is too small."
|
||||
warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
|
||||
if not warning_sent:
|
||||
worker.push_error_to_driver(driver_id, "import_counter",
|
||||
worker.push_error_to_driver(driver_id, "wait_for_function",
|
||||
warning_message)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
@@ -1614,18 +1555,18 @@ def main_loop(worker=global_worker):
|
||||
function_id = task.function_id()
|
||||
args = task.arguments()
|
||||
return_object_ids = task.returns()
|
||||
function_name = worker.function_names[function_id.id()]
|
||||
function_name, function_executor = worker.functions[worker.task_driver_id.id()][function_id.id()]
|
||||
|
||||
# Get task arguments from the object store.
|
||||
with log_span("ray:task:get_arguments", worker=worker):
|
||||
arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker)
|
||||
arguments = get_arguments_for_execution(function_name, args, worker)
|
||||
|
||||
# Execute the task.
|
||||
with log_span("ray:task:execute", worker=worker):
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
outputs = worker.functions[task.function_id().id()].executor(arguments)
|
||||
outputs = function_executor.executor(arguments)
|
||||
else:
|
||||
outputs = worker.functions[task.function_id().id()](worker.actors[task.actor_id().id()], *arguments)
|
||||
outputs = function_executor(worker.actors[task.actor_id().id()], *arguments)
|
||||
|
||||
# Store the outputs in the local object store.
|
||||
with log_span("ray:task:store_outputs", worker=worker):
|
||||
@@ -1680,11 +1621,11 @@ def main_loop(worker=global_worker):
|
||||
task = worker.photon_client.get_task()
|
||||
|
||||
function_id = task.function_id()
|
||||
# Check that the number of imports we have is at least as great as the
|
||||
# export counter for the task. If not, wait until we have imported enough.
|
||||
# We will push warnings to the user if we spend too long in this loop.
|
||||
with log_span("ray:wait_for_import_counter", worker=worker):
|
||||
wait_for_valid_import_counter(function_id, task.driver_id().id(), worker=worker)
|
||||
# Wait until the function to be executed has actually been registered on
|
||||
# this worker. We will push warnings to the user if we spend too long in
|
||||
# this loop.
|
||||
with log_span("ray:wait_for_function", worker=worker):
|
||||
wait_for_function(function_id, task.driver_id().id(), worker=worker)
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
@@ -1695,7 +1636,8 @@ def main_loop(worker=global_worker):
|
||||
with worker.lock:
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker)
|
||||
|
||||
contents = {"function_name": worker.function_names[function_id.id()],
|
||||
function_name, _ = worker.functions[task.driver_id().id()][function_id.id()]
|
||||
contents = {"function_name": function_name,
|
||||
"task_id": task.task_id().hex()}
|
||||
with log_span("ray:task", contents=contents, worker=worker):
|
||||
process_task(task)
|
||||
@@ -1703,7 +1645,7 @@ def main_loop(worker=global_worker):
|
||||
# Push all of the log events to the global state store.
|
||||
flush_log()
|
||||
|
||||
def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global_worker):
|
||||
def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call _submit_task
|
||||
@@ -1711,7 +1653,7 @@ def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global
|
||||
serialize remote functions, we don't attempt to serialize the worker object,
|
||||
which cannot be serialized.
|
||||
"""
|
||||
return worker.submit_task(function_id, func_name, args, num_cpus, num_gpus)
|
||||
return worker.submit_task(function_id, func_name, args)
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
@@ -1751,14 +1693,13 @@ def _export_environment_variable(name, environment_variable, worker=global_worke
|
||||
"initializer": pickling.dumps(environment_variable.initializer),
|
||||
"reinitializer": pickling.dumps(environment_variable.reinitializer)})
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a driver.")
|
||||
worker.function_properties[worker.task_driver_id.id()][function_id.id()] = (num_return_vals, num_cpus, num_gpus)
|
||||
key = "RemoteFunction:{}".format(function_id.id())
|
||||
worker.num_return_vals[function_id.id()] = num_return_vals
|
||||
pickled_func = pickling.dumps(func)
|
||||
worker.redis_client.hmset(key, {"driver_id": worker.task_driver_id.id(),
|
||||
"function_id": function_id.id(),
|
||||
@@ -1766,11 +1707,9 @@ def export_remote_function(function_id, func_name, func, num_return_vals, num_cp
|
||||
"module": func.__module__,
|
||||
"function": pickled_func,
|
||||
"num_return_vals": num_return_vals,
|
||||
"function_export_counter": worker.driver_export_counter,
|
||||
"num_cpus": num_cpus,
|
||||
"num_gpus": num_gpus})
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
def remote(*args, **kwargs):
|
||||
"""This decorator is used to create remote functions.
|
||||
@@ -1778,13 +1717,26 @@ def remote(*args, **kwargs):
|
||||
Args:
|
||||
num_return_vals (int): The number of object IDs that a call to this function
|
||||
should return.
|
||||
num_cpus (int): The number of CPUs needed to execute this function. This
|
||||
should only be passed in when defining the remote function on the driver.
|
||||
num_gpus (int): The number of GPUs needed to execute this function. This
|
||||
should only be passed in when defining the remote function on the driver.
|
||||
"""
|
||||
worker = global_worker
|
||||
def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None):
|
||||
def remote_decorator(func):
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
if func_id is None:
|
||||
function_id = FunctionID((hashlib.sha256(func_name.encode("ascii")).digest())[:20])
|
||||
# Compute the function ID as a hash of the function name as well as the
|
||||
# source code. We could in principle hash in the values in the closure
|
||||
# of the function, but that is likely to introduce non-determinism in
|
||||
# the computation of the function ID.
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(func_name.encode("ascii"))
|
||||
function_id_hash.update(inspect.getsource(func).encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == 20
|
||||
function_id = FunctionID(function_id)
|
||||
else:
|
||||
function_id = func_id
|
||||
|
||||
@@ -1807,7 +1759,7 @@ def remote(*args, **kwargs):
|
||||
_env()._reinitialize()
|
||||
_env()._running_remote_function_locally = False
|
||||
return result
|
||||
objectids = _submit_task(function_id, func_name, args, num_cpus, num_gpus)
|
||||
objectids = _submit_task(function_id, func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
@@ -1909,7 +1861,7 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
|
||||
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
|
||||
|
||||
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
|
||||
def get_arguments_for_execution(function_name, serialized_args, worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that were
|
||||
@@ -1917,8 +1869,8 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
|
||||
This is called by the worker that is executing the remote function.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose arguments are being
|
||||
retrieved.
|
||||
function_name (str): The name of the remote function whose arguments are
|
||||
being retrieved.
|
||||
serialized_args (List): The arguments to the function. These are either
|
||||
strings representing serialized objects passed by value or they are
|
||||
ObjectIDs.
|
||||
@@ -1939,7 +1891,7 @@ def get_arguments_for_execution(function, serialized_args, worker=global_worker)
|
||||
if isinstance(argument, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this
|
||||
# object failed, and we should propagate the error message here.
|
||||
raise RayGetArgumentError(function.__name__, i, arg, argument)
|
||||
raise RayGetArgumentError(function_name, i, arg, argument)
|
||||
else:
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
Reference in New Issue
Block a user