General attribute-based heterogeneity support with hard and soft constraints (#248)

* attribute-based heterogeneity-awareness in global scheduler and photon

* minor post-rebase fix

* photon: enforce dynamic capacity constraint on task dispatch

* globalsched: cap the number of times we try to schedule a task in round robin

* propagating ability to specify resource capacity to ray.init

* adding resources to remote function export and fetch/register

* globalsched: remove unused functions; update cached photon resource capacity (until next photon heartbeat)

* Add some integration tests.

* globalsched: cleanup + factor out constraint checking

* lots of style

* task_spec_required_resource: global refactor

* clang format

* clang format + comment update in photon

* clang format photon comment

* valgrind

* reduce verbosity for Travis

* Add test for scheduler load balancing.

* addressing comments

* refactoring global scheduler algorithm

* Minor cleanups.

* Linting.

* Fix array_test.py and linting.

* valgrind fix for photon tests

* Attempt to fix stress tests.

* fix hashmap free

* fix hashmap free comment

* memset photon resource vectors to 0 in case they get used before the first heartbeat

* More whitespace changes.

* Undo whitespace error I introduced.
This commit is contained in:
Alexey Tumanov
2017-02-09 01:34:14 -08:00
committed by Robert Nishihara
parent 1a7e1c47cb
commit dfb6107b22
22 changed files with 1037 additions and 226 deletions
+65 -32
View File
@@ -21,6 +21,7 @@ from plasma.utils import random_object_id, generate_metadata, write_to_data_buff
USE_VALGRIND = False
PLASMA_STORE_MEMORY = 1000000000
ID_SIZE = 20
NUM_CLUSTER_NODES = 2
# These constants must match the scheduling state enum in task.h.
TASK_STATUS_WAITING = 1
@@ -43,13 +44,16 @@ def random_task_id():
def random_function_id():
return photon.ObjectID(np.random.bytes(ID_SIZE))
def random_object_id():
return photon.ObjectID(np.random.bytes(ID_SIZE))
def new_port():
return random.randint(10000, 65535)
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start a Redis server.
# Start one Redis server and N pairs of (plasma, photon)
redis_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../core/src/common/thirdparty/redis/src/redis-server")
redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../core/src/common/redis_module/libray_redis_module.so")
assert os.path.isfile(redis_path)
@@ -61,29 +65,47 @@ class TestGlobalScheduler(unittest.TestCase):
time.sleep(0.1)
# Create a Redis client.
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
# Start the global scheduler.
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
# Start the Plasma store.
plasma_store_name, self.p2 = plasma.start_plasma_store()
# Start the Plasma manager.
plasma_manager_name, self.p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
self.plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
self.plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
# Start the local scheduler.
local_scheduler_name, self.p4 = photon.start_local_scheduler(
plasma_store_name,
plasma_manager_name=plasma_manager_name,
plasma_address=self.plasma_address,
redis_address=redis_address)
# Connect to the scheduler.
self.photon_client = photon.PhotonClient(local_scheduler_name)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
self.plasma_clients = []
self.photon_clients = []
for i in range(NUM_CLUSTER_NODES):
# Start the Plasma store. Plasma store name is randomly generated.
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
local_scheduler_name, p4 = photon.start_local_scheduler(
plasma_store_name,
plasma_manager_name=plasma_manager_name,
plasma_address=plasma_address,
redis_address=redis_address,
static_resource_list=[None, 0])
# Connect to the scheduler.
photon_client = photon.PhotonClient(local_scheduler_name)
self.photon_clients.append(photon_client)
self.local_scheduler_pids.append(p4)
def tearDown(self):
# Check that the processes are still alive.
self.assertEqual(self.p1.poll(), None)
self.assertEqual(self.p2.poll(), None)
self.assertEqual(self.p3.poll(), None)
self.assertEqual(self.p4.poll(), None)
for p2 in self.plasma_store_pids:
self.assertEqual(p2.poll(), None)
for p3 in self.plasma_manager_pids:
self.assertEqual(p3.poll(), None)
for p4 in self.local_scheduler_pids:
self.assertEqual(p4.poll(), None)
self.assertEqual(self.redis_process.poll(), None)
# Kill the global scheduler.
@@ -94,9 +116,10 @@ class TestGlobalScheduler(unittest.TestCase):
os._exit(-1)
else:
self.p1.kill()
self.p2.kill()
self.p3.kill()
self.p4.kill()
# Kill local schedulers, plasma managers, and plasma stores. Explicit loops
# are used instead of map(): on Python 3 map() returns a lazy iterator, so an
# unconsumed map(...) call never invokes kill() and the processes are leaked.
for process in self.local_scheduler_pids:
  process.kill()
for process in self.plasma_manager_pids:
  process.kill()
for process in self.plasma_store_pids:
  process.kill()
# Kill Redis. In the event that we are using valgrind, this needs to happen
# after we kill the global scheduler.
self.redis_process.kill()
@@ -123,15 +146,21 @@ class TestGlobalScheduler(unittest.TestCase):
return db_client_id
def test_task_default_resources(self):
task1 = photon.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0)
self.assertEqual(task1.required_resources(), [1.0, 0.0])
task2 = photon.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0, [1.0, 2.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0])
def test_redis_only_single_task(self):
"""
Tests global scheduler functionality by interacting with Redis and checking
task state transitions in Redis only. TODO(atumanov): implement.
"""
# Check precondition for this test:
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
# There should be 2n+1 db clients: the global scheduler + one photon and one plasma per node.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id != None)
assert(db_client_id.startswith(b"CL:"))
@@ -140,21 +169,23 @@ class TestGlobalScheduler(unittest.TestCase):
def test_integration_single_task(self):
# There should be 2n+1 db clients: the global scheduler + one photon and one
# plasma manager per node.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# There should not be anything else in Redis yet.
self.assertEqual(len(self.redis_client.keys("*")), 3)
self.assertEqual(len(self.redis_client.keys("*")), 2 * NUM_CLUSTER_NODES + 1)
# Insert the object into Redis.
data_size = 0xf1f0
metadata_size = 0x40
object_dep, memory_buffer, metadata = create_object(self.plasma_client, data_size, metadata_size, seal=True)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to photon.
time.sleep(0.1)
# Submit a task to Redis.
task = photon.Task(random_driver_id(), random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
self.photon_client.submit(task)
self.photon_clients[0].submit(task)
time.sleep(0.1)
# There should now be a task in Redis, and it should get assigned to the
# local scheduler
@@ -184,7 +215,8 @@ class TestGlobalScheduler(unittest.TestCase):
def integration_many_tasks_helper(self, timesync=True):
# There should be 2n+1 db clients: the global scheduler + one photon and one
# plasma manager per node.
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Submit a bunch of tasks to Redis.
@@ -193,12 +225,13 @@ class TestGlobalScheduler(unittest.TestCase):
# Create a new object for each task.
data_size = np.random.randint(1 << 20)
metadata_size = np.random.randint(1 << 10)
object_dep, memory_buffer, metadata = create_object(self.plasma_client, data_size, metadata_size, seal=True)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to yield CPU).
time.sleep(0.010)
task = photon.Task(random_driver_id(), random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
self.photon_client.submit(task)
self.photon_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that they
# all get assigned to the local scheduler.
num_retries = 10
+23 -1
View File
@@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import random
import subprocess
@@ -14,7 +15,7 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
worker_path=None, plasma_address=None,
node_ip_address="127.0.0.1", redis_address=None,
use_valgrind=False, use_profiler=False,
redirect_output=False):
redirect_output=False, static_resource_list=None):
"""Start a local scheduler process.
Args:
@@ -37,6 +38,9 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
profiler. If this is True, use_valgrind must be False.
redirect_output (bool): True if stdout and stderr should be redirected to
/dev/null.
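static_resource_list (list): A list of numbers (int or float) specifying the
local scheduler's resource capacities, or None. Any entry may be None, in
which case a default is used (the number of hardware execution threads for
the first entry, 0 for the second). The resources should appear in an order
matching the order defined in task.h.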
Return:
A tuple of the name of the local scheduler socket and the process ID of the
@@ -71,6 +75,24 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
command += ["-r", redis_address]
if plasma_address is not None:
command += ["-a", plasma_address]
# We want to be able to support independently setting capacity for each of the
# supported resource types. Thus, the list can be None or contain any number
# of None values.
if static_resource_list is None:
  static_resource_list = [None, None]
else:
  # Copy the list so that filling in the defaults below does not modify the
  # list object owned by the caller.
  static_resource_list = list(static_resource_list)
if static_resource_list[0] is None:
  # By default, use the number of hardware execution threads for the number of
  # cores.
  static_resource_list[0] = multiprocessing.cpu_count()
if static_resource_list[1] is None:
  # By default, do not configure any GPUs on this node.
  static_resource_list[1] = 0
# Sanity check to make sure all resource capacities in the list are numeric.
# An exact type check is used (rather than isinstance) so that bool, which is
# a subclass of int, is not accepted and rendered as "True"/"False" below.
assert all(type(x) in (int, float) for x in static_resource_list)
# Pass the resource capacity string to the photon scheduler in all cases.
command += ["-c", ",".join(map(str, static_resource_list))]
with open(os.devnull, "w") as FNULL:
stdout = FNULL if redirect_output else None
stderr = FNULL if redirect_output else None
+2
View File
@@ -100,6 +100,8 @@ class PlasmaClient(object):
store_socket_name (str): Name of the socket the plasma store is listening at.
manager_socket_name (str): Name of the socket the plasma manager is listening at.
"""
self.store_socket_name = store_socket_name
self.manager_socket_name = manager_socket_name
self.alive = True
if manager_socket_name is not None:
+36 -8
View File
@@ -264,7 +264,8 @@ def start_global_scheduler(redis_address, cleanup=True, redirect_output=False):
def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
plasma_manager_name, worker_path, plasma_address=None,
cleanup=True, redirect_output=False):
cleanup=True, redirect_output=False,
static_resource_list=None):
"""Start a local scheduler process.
Args:
@@ -281,6 +282,8 @@ def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
that imported services exits.
redirect_output (bool): True if stdout and stderr should be redirected to
/dev/null.
static_resource_list (list): An ordered list of the configured resource
capacities for this local scheduler.
Return:
The name of the local scheduler socket.
@@ -292,7 +295,8 @@ def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
redis_address=redis_address,
plasma_address=plasma_address,
use_profiler=RUN_PHOTON_PROFILER,
redirect_output=redirect_output)
redirect_output=redirect_output,
static_resource_list=static_resource_list)
if cleanup:
all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p)
return local_scheduler_name
@@ -386,7 +390,9 @@ def start_ray_processes(address_info=None,
cleanup=True,
redirect_output=False,
include_global_scheduler=False,
include_redis=False):
include_redis=False,
num_cpus=None,
num_gpus=None):
"""Helper method to start Ray processes.
Args:
@@ -411,11 +417,22 @@ def start_ray_processes(address_info=None,
start a global scheduler process.
include_redis (bool): If include_redis is True, then start a Redis server
process.
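num_cpus: Either a list of length num_local_schedulers containing the number
of CPUs each local scheduler should be configured with, or a single value
(possibly None) that is applied to every local scheduler.
num_gpus: Either a list of length num_local_schedulers containing the number
of GPUs each local scheduler should be configured with, or a single value
(possibly None) that is applied to every local scheduler.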
Returns:
A dictionary of the address information for the processes that were
started.
"""
if not isinstance(num_cpus, list):
num_cpus = num_local_schedulers * [num_cpus]
if not isinstance(num_gpus, list):
num_gpus = num_local_schedulers * [num_gpus]
assert len(num_cpus) == num_local_schedulers
assert len(num_gpus) == num_local_schedulers
if address_info is None:
address_info = {}
address_info["node_ip_address"] = node_ip_address
@@ -486,7 +503,8 @@ def start_ray_processes(address_info=None,
worker_path,
plasma_address=plasma_address,
cleanup=cleanup,
redirect_output=redirect_output)
redirect_output=redirect_output,
static_resource_list=[num_cpus[i], num_gpus[i]])
local_scheduler_socket_names.append(local_scheduler_name)
time.sleep(0.1)
@@ -517,7 +535,9 @@ def start_ray_node(node_ip_address,
num_local_schedulers=1,
worker_path=None,
cleanup=True,
redirect_output=False):
redirect_output=False,
num_cpus=None,
num_gpus=None):
"""Start the Ray processes for a single node.
This assumes that the Ray processes on some master node have already been
@@ -550,7 +570,9 @@ def start_ray_node(node_ip_address,
num_local_schedulers=num_local_schedulers,
worker_path=worker_path,
cleanup=cleanup,
redirect_output=redirect_output)
redirect_output=redirect_output,
num_cpus=num_cpus,
num_gpus=num_gpus)
def start_ray_head(address_info=None,
node_ip_address="127.0.0.1",
@@ -558,7 +580,9 @@ def start_ray_head(address_info=None,
num_local_schedulers=1,
worker_path=None,
cleanup=True,
redirect_output=False):
redirect_output=False,
num_cpus=None,
num_gpus=None):
"""Start Ray in local mode.
Args:
@@ -579,6 +603,8 @@ def start_ray_head(address_info=None,
method exits.
redirect_output (bool): True if stdout and stderr should be redirected to
/dev/null.
num_cpus (int): number of cpus to configure the local scheduler with.
num_gpus (int): number of gpus to configure the local scheduler with.
Returns:
A dictionary of the address information for the processes that were
@@ -592,4 +618,6 @@ def start_ray_head(address_info=None,
cleanup=cleanup,
redirect_output=redirect_output,
include_global_scheduler=True,
include_redis=True)
include_redis=True,
num_cpus=num_cpus,
num_gpus=num_gpus)
+67 -28
View File
@@ -479,7 +479,7 @@ class Worker(object):
assert final_results[i][0] == object_ids[i].id()
return [result[1][0] for result in final_results]
def submit_task(self, function_id, func_name, args):
def submit_task(self, function_id, func_name, args, num_cpus, num_gpus):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with name
@@ -491,6 +491,8 @@ class Worker(object):
args (List[Any]): The arguments to pass into the function. Arguments can
be object IDs or they can be values. If they are values, they
must be serializable objects.
num_cpus (int): The number of cpu cores this task requires to run.
num_gpus (int): The number of gpus this task requires to run.
"""
with log_span("ray:submit_task", worker=self):
check_main_thread()
@@ -511,7 +513,8 @@ class Worker(object):
args_for_photon,
self.num_return_vals[function_id.id()],
self.current_task_id,
self.task_index)
self.task_index,
[num_cpus, num_gpus])
# Increment the worker's task index to track how many tasks have been
# submitted by the current task so far.
self.task_index += 1
@@ -734,7 +737,7 @@ def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5):
def _init(address_info=None, start_ray_local=False, object_id_seed=None,
num_workers=None, num_local_schedulers=None,
driver_mode=SCRIPT_MODE):
driver_mode=SCRIPT_MODE, num_cpus=None, num_gpus=None):
"""Helper method to connect to an existing Ray cluster or start a new one.
This method handles two cases. Either a Ray cluster already exists and we
@@ -761,6 +764,10 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
only provided if start_ray_local is True.
driver_mode (bool): The mode in which to start the driver. This should be
one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
num_cpus: A list containing the number of CPUs the local schedulers should
be configured with.
num_gpus: A list containing the number of GPUs the local schedulers should
be configured with.
Returns:
Address information about the started processes.
@@ -807,7 +814,8 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
address_info = services.start_ray_head(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers)
num_local_schedulers=num_local_schedulers,
num_cpus=num_cpus, num_gpus=num_gpus)
else:
if redis_address is None:
raise Exception("If start_ray_local=False, then redis_address must be provided.")
@@ -815,6 +823,8 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
raise Exception("If start_ray_local=False, then num_workers must not be provided.")
if num_local_schedulers is not None:
raise Exception("If start_ray_local=False, then num_local_schedulers must not be provided.")
if num_cpus is not None or num_gpus is not None:
raise Exception("If start_ray_local=False, then num_cpus and num_gpus must not be provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
@@ -839,7 +849,7 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
return address_info
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
num_workers=None, driver_mode=SCRIPT_MODE):
num_workers=None, driver_mode=SCRIPT_MODE, num_cpus=None, num_gpus=None):
"""Either connect to an existing Ray cluster or start one and connect to it.
This method handles two cases. Either a Ray cluster already exists and we
@@ -860,6 +870,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
redis_address is not provided.
driver_mode (bool): The mode in which to start the driver. This should be
one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
num_cpus (int): Number of cpus the user wishes all local schedulers to be configured with.
num_gpus (int): Number of gpus the user wishes all local schedulers to be configured with.
Returns:
Address information about the started processes.
@@ -873,7 +885,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
"redis_address": redis_address,
}
return _init(address_info=info, start_ray_local=(redis_address is None),
num_workers=num_workers, driver_mode=driver_mode)
num_workers=num_workers, driver_mode=driver_mode,
num_cpus=num_cpus, num_gpus=num_gpus)
def cleanup(worker=global_worker):
"""Disconnect the driver, and terminate any processes started in init.
@@ -964,10 +977,21 @@ If this driver is hanging, start a new one with
def fetch_and_register_remote_function(key, worker=global_worker):
"""Import a remote function."""
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter = worker.redis_client.hmget(key, ["driver_id", "function_id", "name", "function", "num_return_vals", "module", "function_export_counter"])
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter, num_cpus, num_gpus = \
worker.redis_client.hmget(key, ["driver_id",
"function_id",
"name",
"function",
"num_return_vals",
"module",
"function_export_counter",
"num_cpus",
"num_gpus"])
function_id = photon.ObjectID(function_id_str)
function_name = function_name.decode("ascii")
num_return_vals = int(num_return_vals)
num_cpus = int(num_cpus)
num_gpus = int(num_gpus)
module = module.decode("ascii")
function_export_counter = int(function_export_counter)
@@ -978,7 +1002,10 @@ def fetch_and_register_remote_function(key, worker=global_worker):
# overwritten if the function is unpickled successfully.
def f():
raise Exception("This function was not imported properly.")
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(lambda *xs: f())
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
function_id=function_id,
num_cpus=num_cpus,
num_gpus=num_gpus)(lambda *xs: f())
try:
function = pickling.loads(serialized_function)
@@ -994,7 +1021,10 @@ def fetch_and_register_remote_function(key, worker=global_worker):
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(function)
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
function_id=function_id,
num_cpus=num_cpus,
num_gpus=num_gpus)(function)
# Add the function to the function table.
worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id)
@@ -1207,8 +1237,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
for name, environment_variable in env._cached_environment_variables:
env.__setattr__(name, environment_variable)
# Export cached remote functions to the workers.
for function_id, func_name, func, num_return_vals in worker.cached_remote_functions:
export_remote_function(function_id, func_name, func, num_return_vals, worker)
for function_id, func_name, func, num_return_vals, num_cpus, num_gpus in worker.cached_remote_functions:
export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker)
worker.cached_functions_to_run = None
worker.cached_remote_functions = None
env._cached_environment_variables = None
@@ -1576,7 +1606,7 @@ def main_loop(worker=global_worker):
# Push all of the log events to the global state store.
flush_log()
def _submit_task(function_id, func_name, args, worker=global_worker):
def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global_worker):
"""This is a wrapper around worker.submit_task.
We use this wrapper so that in the remote decorator, we can call _submit_task
@@ -1584,7 +1614,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker):
serialize remote functions, we don't attempt to serialize the worker object,
which cannot be serialized.
"""
return worker.submit_task(function_id, func_name, args)
return worker.submit_task(function_id, func_name, args, num_cpus, num_gpus)
def _mode(worker=global_worker):
"""This is a wrapper around worker.mode.
@@ -1626,7 +1656,7 @@ def _export_environment_variable(name, environment_variable, worker=global_worke
worker.redis_client.rpush("Exports", key)
worker.driver_export_counter += 1
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker):
check_main_thread()
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
raise Exception("export_remote_function can only be called on a driver.")
@@ -1639,7 +1669,9 @@ def export_remote_function(function_id, func_name, func, num_return_vals, worker
"module": func.__module__,
"function": pickled_func,
"num_return_vals": num_return_vals,
"function_export_counter": worker.driver_export_counter})
"function_export_counter": worker.driver_export_counter,
"num_cpus": num_cpus,
"num_gpus": num_gpus})
worker.redis_client.rpush("Exports", key)
worker.driver_export_counter += 1
@@ -1651,7 +1683,7 @@ def remote(*args, **kwargs):
should return.
"""
worker = global_worker
def make_remote_decorator(num_return_vals, func_id=None):
def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None):
def remote_decorator(func):
func_name = "{}.{}".format(func.__module__, func.__name__)
if func_id is None:
@@ -1678,7 +1710,7 @@ def remote(*args, **kwargs):
_env()._reinitialize()
_env()._running_remote_function_locally = False
return result
objectids = _submit_task(function_id, func_name, args)
objectids = _submit_task(function_id, func_name, args, num_cpus, num_gpus)
if len(objectids) == 1:
return objectids[0]
elif len(objectids) > 1:
@@ -1722,37 +1754,44 @@ def remote(*args, **kwargs):
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
export_remote_function(function_id, func_name, func, num_return_vals)
export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus)
elif worker.mode is None:
worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals))
worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals, num_cpus, num_gpus))
return func_invoker
return remote_decorator
num_return_vals = kwargs["num_return_vals"] if "num_return_vals" in kwargs.keys() else 1
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0
if _mode() == WORKER_MODE:
if "function_id" in kwargs:
num_return_vals = kwargs["num_return_vals"]
function_id = kwargs["function_id"]
return make_remote_decorator(num_return_vals, function_id)
return make_remote_decorator(num_return_vals, num_cpus, num_gpus, function_id)
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
num_return_vals = 1
func = args[0]
return make_remote_decorator(num_return_vals)(func)
return make_remote_decorator(num_return_vals, num_cpus, num_gpus)(args[0])
else:
# This is the case where the decorator is something like
# @ray.remote(num_return_vals=2).
assert len(args) == 0 and "num_return_vals" in kwargs, "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
num_return_vals = kwargs["num_return_vals"]
error_string = ("The @ray.remote decorator must be applied either with no "
"arguments and no parentheses, for example '@ray.remote', "
"or it must be applied using some of the arguments "
"'num_return_vals', 'num_cpus', or 'num_gpus', like "
"'@ray.remote(num_return_vals=2)'.")
assert len(args) == 0 and ("num_return_vals" in kwargs or
"num_cpus" in kwargs or
"num_gpus" in kwargs), error_string
assert not "function_id" in kwargs
return make_remote_decorator(num_return_vals)
return make_remote_decorator(num_return_vals, num_cpus, num_gpus)
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
"""Check if we support the signature of this function.
We currently do not allow remote functions to have **kwargs. We also do not
support keyword argumens in conjunction with a *args argument.
support keyword arguments in conjunction with a *args argument.
Args:
has_kwargs_param (bool): True if the function being checked has a **kwargs