mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 20:30:49 +08:00
General attribute-based heterogeneity support with hard and soft constraints (#248)
* attribute-based heterogeneity-awareness in global scheduler and photon
* minor post-rebase fix
* photon: enforce dynamic capacity constraint on task dispatch
* globalsched: cap the number of times we try to schedule a task in round robin
* propagating ability to specify resource capacity to ray.init
* adding resources to remote function export and fetch/register
* globalsched: remove unused functions; update cached photon resource capacity (until next photon heartbeat)
* Add some integration tests.
* globalsched: cleanup + factor out constraint checking
* lots of style
* task_spec_required_resource: global refactor
* clang format
* clang format + comment update in photon
* clang format photon comment
* valgrind
* reduce verbosity for Travis
* Add test for scheduler load balancing.
* addressing comments
* refactoring global scheduler algorithm
* Minor cleanups.
* Linting.
* Fix array_test.py and linting.
* valgrind fix for photon tests
* Attempt to fix stress tests.
* fix hashmap free
* fix hashmap free comment
* memset photon resource vectors to 0 in case they get used before the first heartbeat
* More whitespace changes.
* Undo whitespace error I introduced.
This commit is contained in:
committed by
Robert Nishihara
parent
1a7e1c47cb
commit
dfb6107b22
@@ -21,6 +21,7 @@ from plasma.utils import random_object_id, generate_metadata, write_to_data_buff
|
||||
USE_VALGRIND = False
|
||||
PLASMA_STORE_MEMORY = 1000000000
|
||||
ID_SIZE = 20
|
||||
NUM_CLUSTER_NODES = 2
|
||||
|
||||
# These constants must match the scheduling state enum in task.h.
|
||||
TASK_STATUS_WAITING = 1
|
||||
@@ -43,13 +44,16 @@ def random_task_id():
|
||||
def random_function_id():
|
||||
return photon.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
def random_object_id():
|
||||
return photon.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start a Redis server.
|
||||
# Start one Redis server and N pairs of (plasma, photon)
|
||||
redis_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../core/src/common/thirdparty/redis/src/redis-server")
|
||||
redis_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../core/src/common/redis_module/libray_redis_module.so")
|
||||
assert os.path.isfile(redis_path)
|
||||
@@ -61,29 +65,47 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
time.sleep(0.1)
|
||||
# Create a Redis client.
|
||||
self.redis_client = redis.StrictRedis(host=node_ip_address, port=redis_port)
|
||||
# Start the global scheduler.
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(redis_address, use_valgrind=USE_VALGRIND)
|
||||
# Start the Plasma store.
|
||||
plasma_store_name, self.p2 = plasma.start_plasma_store()
|
||||
# Start the Plasma manager.
|
||||
plasma_manager_name, self.p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
|
||||
self.plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
|
||||
self.plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
|
||||
# Start the local scheduler.
|
||||
local_scheduler_name, self.p4 = photon.start_local_scheduler(
|
||||
plasma_store_name,
|
||||
plasma_manager_name=plasma_manager_name,
|
||||
plasma_address=self.plasma_address,
|
||||
redis_address=redis_address)
|
||||
# Connect to the scheduler.
|
||||
self.photon_client = photon.PhotonClient(local_scheduler_name)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
self.plasma_clients = []
|
||||
self.photon_clients = []
|
||||
|
||||
for i in range(NUM_CLUSTER_NODES):
|
||||
# Start the Plasma store. Plasma store name is randomly generated.
|
||||
plasma_store_name, p2 = plasma.start_plasma_store()
|
||||
self.plasma_store_pids.append(p2)
|
||||
# Start the Plasma manager.
|
||||
# Assumption: Plasma manager name and port are randomly generated by the plasma module.
|
||||
plasma_manager_name, p3, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address)
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(node_ip_address, plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name, plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
# Start the local scheduler.
|
||||
local_scheduler_name, p4 = photon.start_local_scheduler(
|
||||
plasma_store_name,
|
||||
plasma_manager_name=plasma_manager_name,
|
||||
plasma_address=plasma_address,
|
||||
redis_address=redis_address,
|
||||
static_resource_list=[None, 0])
|
||||
# Connect to the scheduler.
|
||||
photon_client = photon.PhotonClient(local_scheduler_name)
|
||||
self.photon_clients.append(photon_client)
|
||||
self.local_scheduler_pids.append(p4)
|
||||
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
self.assertEqual(self.p1.poll(), None)
|
||||
self.assertEqual(self.p2.poll(), None)
|
||||
self.assertEqual(self.p3.poll(), None)
|
||||
self.assertEqual(self.p4.poll(), None)
|
||||
for p2 in self.plasma_store_pids:
|
||||
self.assertEqual(p2.poll(), None)
|
||||
for p3 in self.plasma_manager_pids:
|
||||
self.assertEqual(p3.poll(), None)
|
||||
for p4 in self.local_scheduler_pids:
|
||||
self.assertEqual(p4.poll(), None)
|
||||
|
||||
self.assertEqual(self.redis_process.poll(), None)
|
||||
|
||||
# Kill the global scheduler.
|
||||
@@ -94,9 +116,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
os._exit(-1)
|
||||
else:
|
||||
self.p1.kill()
|
||||
self.p2.kill()
|
||||
self.p3.kill()
|
||||
self.p4.kill()
|
||||
# Kill local schedulers, plasma managers, and plasma stores.
|
||||
map(subprocess.Popen.kill, self.local_scheduler_pids)
|
||||
map(subprocess.Popen.kill, self.plasma_manager_pids)
|
||||
map(subprocess.Popen.kill, self.plasma_store_pids)
|
||||
# Kill Redis. In the event that we are using valgrind, this needs to happen
|
||||
# after we kill the global scheduler.
|
||||
self.redis_process.kill()
|
||||
@@ -123,15 +146,21 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
return db_client_id
|
||||
|
||||
def test_task_default_resources(self):
|
||||
task1 = photon.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0)
|
||||
self.assertEqual(task1.required_resources(), [1.0, 0.0])
|
||||
task2 = photon.Task(random_driver_id(), random_function_id(), [random_object_id()], 0, random_task_id(), 0, [1.0, 2.0])
|
||||
self.assertEqual(task2.required_resources(), [1.0, 2.0])
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
"""
|
||||
Tests global scheduler functionality by interacting with Redis and checking
|
||||
task state transitions in Redis only. TODO(atumanov): implement.
|
||||
"""
|
||||
# Check precondition for this test:
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
|
||||
# There should be 2n+1 db clients: the global scheduler + one photon and one plasma per node.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id != None)
|
||||
assert(db_client_id.startswith(b"CL:"))
|
||||
@@ -140,21 +169,23 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
def test_integration_single_task(self):
|
||||
# There should be 2n+1 db clients: the global scheduler, plus one local
|
||||
# scheduler and one plasma manager per node.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
# There should not be anything else in Redis yet.
|
||||
self.assertEqual(len(self.redis_client.keys("*")), 3)
|
||||
self.assertEqual(len(self.redis_client.keys("*")), 2 * NUM_CLUSTER_NODES + 1)
|
||||
# Insert the object into Redis.
|
||||
data_size = 0xf1f0
|
||||
metadata_size = 0x40
|
||||
object_dep, memory_buffer, metadata = create_object(self.plasma_client, data_size, metadata_size, seal=True)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
|
||||
|
||||
# Sleep before submitting task to photon.
|
||||
time.sleep(0.1)
|
||||
# Submit a task to Redis.
|
||||
task = photon.Task(random_driver_id(), random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
|
||||
self.photon_client.submit(task)
|
||||
self.photon_clients[0].submit(task)
|
||||
time.sleep(0.1)
|
||||
# There should now be a task in Redis, and it should get assigned to the
|
||||
# local scheduler
|
||||
@@ -184,7 +215,8 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
def integration_many_tasks_helper(self, timesync=True):
|
||||
# There should be 2n+1 db clients: the global scheduler, plus one local
|
||||
# scheduler and one plasma manager per node.
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))), 3)
|
||||
self.assertEqual(len(self.redis_client.keys("{}*".format(DB_CLIENT_PREFIX))),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
|
||||
# Submit a bunch of tasks to Redis.
|
||||
@@ -193,12 +225,13 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
# Create a new object for each task.
|
||||
data_size = np.random.randint(1 << 20)
|
||||
metadata_size = np.random.randint(1 << 10)
|
||||
object_dep, memory_buffer, metadata = create_object(self.plasma_client, data_size, metadata_size, seal=True)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client, data_size, metadata_size, seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = photon.Task(random_driver_id(), random_function_id(), [photon.ObjectID(object_dep)], num_return_vals[0], random_task_id(), 0)
|
||||
self.photon_client.submit(task)
|
||||
self.photon_clients[0].submit(task)
|
||||
# Check that there are the correct number of tasks in Redis and that they
|
||||
# all get assigned to the local scheduler.
|
||||
num_retries = 10
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
@@ -14,7 +15,7 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
|
||||
worker_path=None, plasma_address=None,
|
||||
node_ip_address="127.0.0.1", redis_address=None,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
redirect_output=False):
|
||||
redirect_output=False, static_resource_list=None):
|
||||
"""Start a local scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -37,6 +38,9 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
|
||||
profiler. If this is True, use_valgrind must be False.
|
||||
redirect_output (bool): True if stdout and stderr should be redirected to
|
||||
/dev/null.
|
||||
static_resource_list (list): A list of numbers (int or float) specifying
|
||||
the local scheduler's resource capacities, in an order matching the order
|
||||
defined in task.h. The list, or any entry in it, may be None to use the default.
|
||||
|
||||
Return:
|
||||
A tuple of the name of the local scheduler socket and the process ID of the
|
||||
@@ -71,6 +75,24 @@ def start_local_scheduler(plasma_store_name, plasma_manager_name=None,
|
||||
command += ["-r", redis_address]
|
||||
if plasma_address is not None:
|
||||
command += ["-a", plasma_address]
|
||||
# We want to be able to support independently setting capacity for each of the
|
||||
# supported resource types. Thus, the list can be None or contain any number
|
||||
# of None values.
|
||||
if static_resource_list is None:
|
||||
static_resource_list = [None, None]
|
||||
if static_resource_list[0] is None:
|
||||
# By default, use the number of hardware execution threads for the number of
|
||||
# cores.
|
||||
static_resource_list[0] = multiprocessing.cpu_count()
|
||||
if static_resource_list[1] is None:
|
||||
# By default, do not configure any GPUs on this node.
|
||||
static_resource_list[1] = 0
|
||||
# Pass the resource capacity string to the photon scheduler in all cases.
|
||||
# Sanity check to make sure all resource capacities in the list are numeric
|
||||
# (int or float).
|
||||
assert(all([x == int or x == float for x in map(type, static_resource_list)]))
|
||||
command += ["-c", ",".join(map(str, static_resource_list))]
|
||||
|
||||
with open(os.devnull, "w") as FNULL:
|
||||
stdout = FNULL if redirect_output else None
|
||||
stderr = FNULL if redirect_output else None
|
||||
|
||||
@@ -100,6 +100,8 @@ class PlasmaClient(object):
|
||||
store_socket_name (str): Name of the socket the plasma store is listening at.
|
||||
manager_socket_name (str): Name of the socket the plasma manager is listening at.
|
||||
"""
|
||||
self.store_socket_name = store_socket_name
|
||||
self.manager_socket_name = manager_socket_name
|
||||
self.alive = True
|
||||
|
||||
if manager_socket_name is not None:
|
||||
|
||||
+36
-8
@@ -264,7 +264,8 @@ def start_global_scheduler(redis_address, cleanup=True, redirect_output=False):
|
||||
|
||||
def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
|
||||
plasma_manager_name, worker_path, plasma_address=None,
|
||||
cleanup=True, redirect_output=False):
|
||||
cleanup=True, redirect_output=False,
|
||||
static_resource_list=None):
|
||||
"""Start a local scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -281,6 +282,8 @@ def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
|
||||
that imported services exits.
|
||||
redirect_output (bool): True if stdout and stderr should be redirected to
|
||||
/dev/null.
|
||||
static_resource_list (list): An ordered list of the configured resource
|
||||
capacities for this local scheduler.
|
||||
|
||||
Return:
|
||||
The name of the local scheduler socket.
|
||||
@@ -292,7 +295,8 @@ def start_local_scheduler(redis_address, node_ip_address, plasma_store_name,
|
||||
redis_address=redis_address,
|
||||
plasma_address=plasma_address,
|
||||
use_profiler=RUN_PHOTON_PROFILER,
|
||||
redirect_output=redirect_output)
|
||||
redirect_output=redirect_output,
|
||||
static_resource_list=static_resource_list)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p)
|
||||
return local_scheduler_name
|
||||
@@ -386,7 +390,9 @@ def start_ray_processes(address_info=None,
|
||||
cleanup=True,
|
||||
redirect_output=False,
|
||||
include_global_scheduler=False,
|
||||
include_redis=False):
|
||||
include_redis=False,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
"""Helper method to start Ray processes.
|
||||
|
||||
Args:
|
||||
@@ -411,11 +417,22 @@ def start_ray_processes(address_info=None,
|
||||
start a global scheduler process.
|
||||
include_redis (bool): If include_redis is True, then start a Redis server
|
||||
process.
|
||||
num_cpus: A list of length num_local_schedulers containing the number of
|
||||
CPUs each local scheduler should be configured with.
|
||||
num_gpus: A list of length num_local_schedulers containing the number of
|
||||
GPUs each local scheduler should be configured with.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
if not isinstance(num_cpus, list):
|
||||
num_cpus = num_local_schedulers * [num_cpus]
|
||||
if not isinstance(num_gpus, list):
|
||||
num_gpus = num_local_schedulers * [num_gpus]
|
||||
assert len(num_cpus) == num_local_schedulers
|
||||
assert len(num_gpus) == num_local_schedulers
|
||||
|
||||
if address_info is None:
|
||||
address_info = {}
|
||||
address_info["node_ip_address"] = node_ip_address
|
||||
@@ -486,7 +503,8 @@ def start_ray_processes(address_info=None,
|
||||
worker_path,
|
||||
plasma_address=plasma_address,
|
||||
cleanup=cleanup,
|
||||
redirect_output=redirect_output)
|
||||
redirect_output=redirect_output,
|
||||
static_resource_list=[num_cpus[i], num_gpus[i]])
|
||||
local_scheduler_socket_names.append(local_scheduler_name)
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -517,7 +535,9 @@ def start_ray_node(node_ip_address,
|
||||
num_local_schedulers=1,
|
||||
worker_path=None,
|
||||
cleanup=True,
|
||||
redirect_output=False):
|
||||
redirect_output=False,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
"""Start the Ray processes for a single node.
|
||||
|
||||
This assumes that the Ray processes on some master node have already been
|
||||
@@ -550,7 +570,9 @@ def start_ray_node(node_ip_address,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
worker_path=worker_path,
|
||||
cleanup=cleanup,
|
||||
redirect_output=redirect_output)
|
||||
redirect_output=redirect_output,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
def start_ray_head(address_info=None,
|
||||
node_ip_address="127.0.0.1",
|
||||
@@ -558,7 +580,9 @@ def start_ray_head(address_info=None,
|
||||
num_local_schedulers=1,
|
||||
worker_path=None,
|
||||
cleanup=True,
|
||||
redirect_output=False):
|
||||
redirect_output=False,
|
||||
num_cpus=None,
|
||||
num_gpus=None):
|
||||
"""Start Ray in local mode.
|
||||
|
||||
Args:
|
||||
@@ -579,6 +603,8 @@ def start_ray_head(address_info=None,
|
||||
method exits.
|
||||
redirect_output (bool): True if stdout and stderr should be redirected to
|
||||
/dev/null.
|
||||
num_cpus (int): number of cpus to configure the local scheduler with.
|
||||
num_gpus (int): number of gpus to configure the local scheduler with.
|
||||
|
||||
Returns:
|
||||
A dictionary of the address information for the processes that were
|
||||
@@ -592,4 +618,6 @@ def start_ray_head(address_info=None,
|
||||
cleanup=cleanup,
|
||||
redirect_output=redirect_output,
|
||||
include_global_scheduler=True,
|
||||
include_redis=True)
|
||||
include_redis=True,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
|
||||
+67
-28
@@ -479,7 +479,7 @@ class Worker(object):
|
||||
assert final_results[i][0] == object_ids[i].id()
|
||||
return [result[1][0] for result in final_results]
|
||||
|
||||
def submit_task(self, function_id, func_name, args):
|
||||
def submit_task(self, function_id, func_name, args, num_cpus, num_gpus):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with name
|
||||
@@ -491,6 +491,8 @@ class Worker(object):
|
||||
args (List[Any]): The arguments to pass into the function. Arguments can
|
||||
be object IDs or they can be values. If they are values, they
|
||||
must be serializable objects.
|
||||
num_cpus (int): The number of cpu cores this task requires to run.
|
||||
num_gpus (int): The number of gpus this task requires to run.
|
||||
"""
|
||||
with log_span("ray:submit_task", worker=self):
|
||||
check_main_thread()
|
||||
@@ -511,7 +513,8 @@ class Worker(object):
|
||||
args_for_photon,
|
||||
self.num_return_vals[function_id.id()],
|
||||
self.current_task_id,
|
||||
self.task_index)
|
||||
self.task_index,
|
||||
[num_cpus, num_gpus])
|
||||
# Increment the worker's task index to track how many tasks have been
|
||||
# submitted by the current task so far.
|
||||
self.task_index += 1
|
||||
@@ -734,7 +737,7 @@ def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5):
|
||||
|
||||
def _init(address_info=None, start_ray_local=False, object_id_seed=None,
|
||||
num_workers=None, num_local_schedulers=None,
|
||||
driver_mode=SCRIPT_MODE):
|
||||
driver_mode=SCRIPT_MODE, num_cpus=None, num_gpus=None):
|
||||
"""Helper method to connect to an existing Ray cluster or start a new one.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -761,6 +764,10 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
|
||||
only provided if start_ray_local is True.
|
||||
driver_mode (bool): The mode in which to start the driver. This should be
|
||||
one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
|
||||
num_cpus: A list containing the number of CPUs the local schedulers should
|
||||
be configured with.
|
||||
num_gpus: A list containing the number of GPUs the local schedulers should
|
||||
be configured with.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -807,7 +814,8 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
|
||||
address_info = services.start_ray_head(address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers)
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
num_cpus=num_cpus, num_gpus=num_gpus)
|
||||
else:
|
||||
if redis_address is None:
|
||||
raise Exception("If start_ray_local=False, then redis_address must be provided.")
|
||||
@@ -815,6 +823,8 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
|
||||
raise Exception("If start_ray_local=False, then num_workers must not be provided.")
|
||||
if num_local_schedulers is not None:
|
||||
raise Exception("If start_ray_local=False, then num_local_schedulers must not be provided.")
|
||||
if num_cpus is not None or num_gpus is not None:
|
||||
raise Exception("If start_ray_local=False, then num_cpus and num_gpus must not be provided.")
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
@@ -839,7 +849,7 @@ def _init(address_info=None, start_ray_local=False, object_id_seed=None,
|
||||
return address_info
|
||||
|
||||
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
num_workers=None, driver_mode=SCRIPT_MODE):
|
||||
num_workers=None, driver_mode=SCRIPT_MODE, num_cpus=None, num_gpus=None):
|
||||
"""Either connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -860,6 +870,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
redis_address is not provided.
|
||||
driver_mode (bool): The mode in which to start the driver. This should be
|
||||
one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
|
||||
num_cpus (int): Number of cpus the user wishes all local schedulers to be configured with.
|
||||
num_gpus (int): Number of gpus the user wishes all local schedulers to be configured with.
|
||||
|
||||
Returns:
|
||||
Address information about the started processes.
|
||||
@@ -873,7 +885,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
|
||||
"redis_address": redis_address,
|
||||
}
|
||||
return _init(address_info=info, start_ray_local=(redis_address is None),
|
||||
num_workers=num_workers, driver_mode=driver_mode)
|
||||
num_workers=num_workers, driver_mode=driver_mode,
|
||||
num_cpus=num_cpus, num_gpus=num_gpus)
|
||||
|
||||
def cleanup(worker=global_worker):
|
||||
"""Disconnect the driver, and terminate any processes started in init.
|
||||
@@ -964,10 +977,21 @@ If this driver is hanging, start a new one with
|
||||
|
||||
def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
"""Import a remote function."""
|
||||
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter = worker.redis_client.hmget(key, ["driver_id", "function_id", "name", "function", "num_return_vals", "module", "function_export_counter"])
|
||||
driver_id, function_id_str, function_name, serialized_function, num_return_vals, module, function_export_counter, num_cpus, num_gpus = \
|
||||
worker.redis_client.hmget(key, ["driver_id",
|
||||
"function_id",
|
||||
"name",
|
||||
"function",
|
||||
"num_return_vals",
|
||||
"module",
|
||||
"function_export_counter",
|
||||
"num_cpus",
|
||||
"num_gpus"])
|
||||
function_id = photon.ObjectID(function_id_str)
|
||||
function_name = function_name.decode("ascii")
|
||||
num_return_vals = int(num_return_vals)
|
||||
num_cpus = int(num_cpus)
|
||||
num_gpus = int(num_gpus)
|
||||
module = module.decode("ascii")
|
||||
function_export_counter = int(function_export_counter)
|
||||
|
||||
@@ -978,7 +1002,10 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
# overwritten if the function is unpickled successfully.
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(lambda *xs: f())
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
|
||||
function_id=function_id,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)(lambda *xs: f())
|
||||
|
||||
try:
|
||||
function = pickling.loads(serialized_function)
|
||||
@@ -994,7 +1021,10 @@ def fetch_and_register_remote_function(key, worker=global_worker):
|
||||
else:
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals, function_id=function_id)(function)
|
||||
worker.functions[function_id.id()] = remote(num_return_vals=num_return_vals,
|
||||
function_id=function_id,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)(function)
|
||||
# Add the function to the function table.
|
||||
worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id)
|
||||
|
||||
@@ -1207,8 +1237,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
|
||||
for name, environment_variable in env._cached_environment_variables:
|
||||
env.__setattr__(name, environment_variable)
|
||||
# Export cached remote functions to the workers.
|
||||
for function_id, func_name, func, num_return_vals in worker.cached_remote_functions:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals, worker)
|
||||
for function_id, func_name, func, num_return_vals, num_cpus, num_gpus in worker.cached_remote_functions:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker)
|
||||
worker.cached_functions_to_run = None
|
||||
worker.cached_remote_functions = None
|
||||
env._cached_environment_variables = None
|
||||
@@ -1576,7 +1606,7 @@ def main_loop(worker=global_worker):
|
||||
# Push all of the log events to the global state store.
|
||||
flush_log()
|
||||
|
||||
def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
def _submit_task(function_id, func_name, args, num_cpus, num_gpus, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call _submit_task
|
||||
@@ -1584,7 +1614,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
serialize remote functions, we don't attempt to serialize the worker object,
|
||||
which cannot be serialized.
|
||||
"""
|
||||
return worker.submit_task(function_id, func_name, args)
|
||||
return worker.submit_task(function_id, func_name, args, num_cpus, num_gpus)
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
"""This is a wrapper around worker.mode.
|
||||
@@ -1626,7 +1656,7 @@ def _export_environment_variable(name, environment_variable, worker=global_worke
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
|
||||
def export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus, worker=global_worker):
|
||||
check_main_thread()
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raise Exception("export_remote_function can only be called on a driver.")
|
||||
@@ -1639,7 +1669,9 @@ def export_remote_function(function_id, func_name, func, num_return_vals, worker
|
||||
"module": func.__module__,
|
||||
"function": pickled_func,
|
||||
"num_return_vals": num_return_vals,
|
||||
"function_export_counter": worker.driver_export_counter})
|
||||
"function_export_counter": worker.driver_export_counter,
|
||||
"num_cpus": num_cpus,
|
||||
"num_gpus": num_gpus})
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
worker.driver_export_counter += 1
|
||||
|
||||
@@ -1651,7 +1683,7 @@ def remote(*args, **kwargs):
|
||||
should return.
|
||||
"""
|
||||
worker = global_worker
|
||||
def make_remote_decorator(num_return_vals, func_id=None):
|
||||
def make_remote_decorator(num_return_vals, num_cpus, num_gpus, func_id=None):
|
||||
def remote_decorator(func):
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
if func_id is None:
|
||||
@@ -1678,7 +1710,7 @@ def remote(*args, **kwargs):
|
||||
_env()._reinitialize()
|
||||
_env()._running_remote_function_locally = False
|
||||
return result
|
||||
objectids = _submit_task(function_id, func_name, args)
|
||||
objectids = _submit_task(function_id, func_name, args, num_cpus, num_gpus)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
@@ -1722,37 +1754,44 @@ def remote(*args, **kwargs):
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
export_remote_function(function_id, func_name, func, num_return_vals)
|
||||
export_remote_function(function_id, func_name, func, num_return_vals, num_cpus, num_gpus)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals))
|
||||
worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals, num_cpus, num_gpus))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
num_return_vals = kwargs["num_return_vals"] if "num_return_vals" in kwargs.keys() else 1
|
||||
num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs.keys() else 1
|
||||
num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs.keys() else 0
|
||||
|
||||
if _mode() == WORKER_MODE:
|
||||
if "function_id" in kwargs:
|
||||
num_return_vals = kwargs["num_return_vals"]
|
||||
function_id = kwargs["function_id"]
|
||||
return make_remote_decorator(num_return_vals, function_id)
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus, function_id)
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
num_return_vals = 1
|
||||
func = args[0]
|
||||
return make_remote_decorator(num_return_vals)(func)
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus)(args[0])
|
||||
else:
|
||||
# This is the case where the decorator is something like
|
||||
# @ray.remote(num_return_vals=2).
|
||||
assert len(args) == 0 and "num_return_vals" in kwargs, "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
|
||||
num_return_vals = kwargs["num_return_vals"]
|
||||
error_string = ("The @ray.remote decorator must be applied either with no "
|
||||
"arguments and no parentheses, for example '@ray.remote', "
|
||||
"or it must be applied using some of the arguments "
|
||||
"'num_return_vals', 'num_cpus', or 'num_gpus', like "
|
||||
"'@ray.remote(num_return_vals=2)'.")
|
||||
assert len(args) == 0 and ("num_return_vals" in kwargs or
|
||||
"num_cpus" in kwargs or
|
||||
"num_gpus" in kwargs), error_string
|
||||
assert not "function_id" in kwargs
|
||||
return make_remote_decorator(num_return_vals)
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus)
|
||||
|
||||
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
|
||||
"""Check if we support the signature of this function.
|
||||
|
||||
We currently do not allow remote functions to have **kwargs. We also do not
|
||||
support keyword argumens in conjunction with a *args argument.
|
||||
support keyword arguments in conjunction with a *args argument.
|
||||
|
||||
Args:
|
||||
has_kwargs_param (bool): True if the function being checked has a **kwargs
|
||||
|
||||
Reference in New Issue
Block a user