Make sure no Python modules mutually import each other. (#334)

This commit is contained in:
Robert Nishihara
2016-08-01 17:55:38 -07:00
committed by Philipp Moritz
parent 96a70e1316
commit c27e6c076c
8 changed files with 213 additions and 199 deletions
+1 -10
View File
@@ -1,12 +1,3 @@
# These five constants are used to define the mode that a worker is running in.
# Right now, this is only used for determining how to print information about
# task failures.
SCRIPT_MODE = 0
WORKER_MODE = 1
SHELL_MODE = 2
PYTHON_MODE = 3
SILENT_MODE = 4 # This is only used during testing.
# Ray version string
__version__ = "0.1"
@@ -19,9 +10,9 @@ if hasattr(ctypes, "windll"):
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32)
import config
import libraylib as lib
import serialization
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
from worker import Reusable, reusables
from worker import SCRIPT_MODE, WORKER_MODE, SHELL_MODE, PYTHON_MODE, SILENT_MODE
from libraylib import ObjectID
import internal
+6 -6
View File
@@ -1,7 +1,7 @@
import importlib
import numpy as np
import ray
import libraylib as raylib
# The following definitions are required because Python doesn't allow custom
# attributes for primitive types. We need custom attributes for (a) implementing
@@ -54,18 +54,18 @@ def is_arrow_serializable(value):
def serialize(worker_capsule, obj):
primitive_obj = to_primitive(obj)
obj_capsule, contained_objectids = ray.lib.serialize_object(worker_capsule, primitive_obj) # contained_objectids is a list of the objectids contained in obj
obj_capsule, contained_objectids = raylib.serialize_object(worker_capsule, primitive_obj) # contained_objectids is a list of the objectids contained in obj
return obj_capsule, contained_objectids
def deserialize(worker_capsule, capsule):
primitive_obj = ray.lib.deserialize_object(worker_capsule, capsule)
primitive_obj = raylib.deserialize_object(worker_capsule, capsule)
return from_primitive(primitive_obj)
def serialize_task(worker_capsule, func_name, args):
primitive_args = [(arg if isinstance(arg, ray.ObjectID) else to_primitive(arg)) for arg in args]
return ray.lib.serialize_task(worker_capsule, func_name, primitive_args)
primitive_args = [(arg if isinstance(arg, raylib.ObjectID) else to_primitive(arg)) for arg in args]
return raylib.serialize_task(worker_capsule, func_name, primitive_args)
def deserialize_task(worker_capsule, task):
func_name, primitive_args, return_objectids = task
args = [(arg if isinstance(arg, ray.ObjectID) else from_primitive(arg)) for arg in primitive_args]
args = [(arg if isinstance(arg, raylib.ObjectID) else from_primitive(arg)) for arg in primitive_args]
return func_name, args, return_objectids
+33 -54
View File
@@ -1,11 +1,10 @@
import os
import sys
import time
import atexit
import subprocess32 as subprocess
import ray
import worker
# Ray modules
import config
_services_env = os.environ.copy()
_services_env["PATH"] = os.pathsep.join([os.path.dirname(os.path.abspath(__file__)), _services_env["PATH"]])
@@ -14,9 +13,6 @@ _services_env["PATH"] = os.pathsep.join([os.path.dirname(os.path.abspath(__file_
# that have been started by this services module if Ray is being used in local
# mode.
all_processes = []
# drivers is a list of the worker objects corresponding to drivers if
# start_ray_local is run with return_drivers=True.
drivers = []
IP_ADDRESS = "127.0.0.1"
TIMEOUT_SECONDS = 5
@@ -36,6 +32,12 @@ def new_worker_port():
worker_port_counter += 1
return 40000 + worker_port_counter
driver_port_counter = 0
def new_driver_port():
  """Return a fresh port number for a driver.

  Every call bumps the module-level driver_port_counter by one and returns
  30000 plus the updated counter, so the first call yields 30001.
  """
  global driver_port_counter
  driver_port_counter = driver_port_counter + 1
  return driver_port_counter + 30000
objstore_port_counter = 0
def new_objstore_port():
global objstore_port_counter
@@ -47,22 +49,9 @@ def cleanup():
This method is used to shutdown processes that were started with
services.start_ray_local(). It kills all scheduler, object store, and worker
processes that were started by this services module. It disconnects driver
processes but does not kill them. This will automatically run at the end when
a Python process that imports services exits. It is ok to run this twice in a
row. Note that we manually call services.cleanup() in the tests because we
need to start and stop many clusters in the tests, but in the tests, services
is only imported and only exits once.
processes that were started by this services module. Driver processes are
started and disconnected by worker.py.
"""
global drivers
for driver in drivers:
ray.disconnect(driver)
driver.set_mode(None)
if len(drivers) == 0:
ray.disconnect()
ray.worker.global_worker.set_mode(None)
drivers = []
global all_processes
for p, address in all_processes:
if p.poll() is not None: # process has already terminated
@@ -83,8 +72,6 @@ def cleanup():
print "Termination attempt failed, giving up."
all_processes = []
atexit.register(cleanup)
def start_scheduler(scheduler_address, local):
"""This method starts a scheduler process.
@@ -94,7 +81,7 @@ def start_scheduler(scheduler_address, local):
process will be killed by services.cleanup() when the Python process that
imported services exits.
"""
p = subprocess.Popen(["scheduler", scheduler_address, "--log-file-name", ray.config.get_log_file_path("scheduler.log")], env=_services_env)
p = subprocess.Popen(["scheduler", scheduler_address, "--log-file-name", config.get_log_file_path("scheduler.log")], env=_services_env)
if local:
all_processes.append((p, scheduler_address))
@@ -109,7 +96,7 @@ def start_objstore(scheduler_address, objstore_address, local):
process will be killed by services.cleanup() when the Python process that
imported services exits.
"""
p = subprocess.Popen(["objstore", scheduler_address, objstore_address, "--log-file-name", ray.config.get_log_file_path("-".join(["objstore", objstore_address]) + ".log")], env=_services_env)
p = subprocess.Popen(["objstore", scheduler_address, objstore_address, "--log-file-name", config.get_log_file_path("-".join(["objstore", objstore_address]) + ".log")], env=_services_env)
if local:
all_processes.append((p, objstore_address))
@@ -189,54 +176,46 @@ def start_workers(scheduler_address, objstore_address, num_workers, worker_path)
for _ in range(num_workers):
start_worker(worker_path, scheduler_address, objstore_address, address(node_ip_address, new_worker_port()), local=False)
def start_ray_local(num_objstores=1, num_workers_per_objstore=0, worker_path=None, driver_mode=ray.SCRIPT_MODE, return_drivers=False):
def start_ray_local(num_objstores=1, num_workers=0, worker_path=None):
"""Start Ray in local mode.
This method starts Ray in local mode (as opposed to cluster mode, which is
handled by cluster.py).
Args:
num_objstores (int): The number of object stores to start.
num_workers_per_objstore (int): The number of workers to start per object
store.
num_objstores (int): The number of object stores to start. Aside from
testing, this should be one.
num_workers (int): The number of workers to start.
worker_path (str): The path of the source code that will be run by the
worker
driver_mode: The mode for the driver, this only affects the printing of
error messages. This should be ray.SCRIPT_MODE if the driver is being run
in a script. It should be ray.SHELL_MODE if it is being used interactively
in the shell. It should be ray.PYTHON_MODE to run things in a manner
equivalent to serial Python code. It should be ray.WORKER_MODE to suppress
the printing of error messages.
return_drivers (bool): This should only be True in special cases for tests.
worker.
Returns:
The address of the scheduler, the addresses of all of the object stores, and
the one new driver address for each object store.
"""
global drivers
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
if num_workers_per_objstore > 0 and num_objstores < 1:
if num_workers > 0 and num_objstores < 1:
  # The message has two placeholders, so both values must be supplied; passing
  # only one makes str.format raise IndexError instead of reporting the
  # misconfiguration. num_workers is the total number of workers requested.
  raise Exception("Attempting to start a cluster with {} workers, but `num_objstores` is {}.".format(num_workers, num_objstores))
scheduler_address = address(IP_ADDRESS, new_scheduler_port())
start_scheduler(scheduler_address, local=True)
time.sleep(0.1)
objstore_addresses = []
# create objstores
for _ in range(num_objstores):
for i in range(num_objstores):
objstore_address = address(IP_ADDRESS, new_objstore_port())
objstore_addresses.append(objstore_address)
start_objstore(scheduler_address, objstore_address, local=True)
time.sleep(0.2)
for _ in range(num_workers_per_objstore):
if i < num_objstores - 1:
num_workers_to_start = num_workers / num_objstores
else:
# In case num_workers is not divisible by num_objstores, start the correct
# remaining number of workers.
num_workers_to_start = num_workers - (num_objstores - 1) * (num_workers / num_objstores)
for _ in range(num_workers_to_start):
start_worker(worker_path, scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), local=True)
time.sleep(0.3)
# create drivers
if return_drivers:
driver_workers = []
for i in range(num_objstores):
driver_worker = worker.Worker()
ray.connect(scheduler_address, objstore_address, address(IP_ADDRESS, new_worker_port()), is_driver=True, worker=driver_worker)
driver_workers.append(driver_worker)
drivers.append(driver_worker)
time.sleep(0.5)
return driver_workers
else:
ray.connect(scheduler_address, objstore_addresses[0], address(IP_ADDRESS, new_worker_port()), is_driver=True, mode=driver_mode)
time.sleep(0.5)
driver_addresses = [address(IP_ADDRESS, new_driver_port()) for _ in range(num_objstores)]
return scheduler_address, objstore_addresses, driver_addresses
+134 -98
View File
@@ -8,14 +8,26 @@ import typing
import funcsigs
import numpy as np
import colorama
import atexit
import ray
# Ray modules
import config
import pickling
import serialization
import ray.internal.graph_pb2
import ray.graph
import internal.graph_pb2
import graph
import services
import libnumbuf
import libraylib as raylib
# These five constants are used to define the mode that a worker is running in.
# Right now, this is only used for determining how to print information about
# task failures.
SCRIPT_MODE = 0
WORKER_MODE = 1
SHELL_MODE = 2
PYTHON_MODE = 3
SILENT_MODE = 4 # This is only used during testing.
class RayFailedObject(object):
"""An object used internally to represent a task that threw an exception.
@@ -102,7 +114,7 @@ class RayDealloc(object):
def __del__(self):
"""Deallocate the relevant segment to avoid a memory leak."""
ray.lib.unmap_object(self.handle, self.segmentid)
raylib.unmap_object(self.handle, self.segmentid)
class Reusable(object):
"""A Python object that can be shared between tasks.
@@ -223,7 +235,7 @@ class RayReusables(object):
raise Exception("To set a reusable variable, you must pass in a Reusable object")
self._names.add(name)
self._reusables[name] = reusable
if _mode() in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
if _mode() in [SHELL_MODE, SCRIPT_MODE, SILENT_MODE]:
_export_reusable_variable(name, reusable)
elif _mode() is None:
self._cached_reusables.append((name, reusable))
@@ -249,8 +261,8 @@ class Worker(object):
function to the remote function itself. This is the set of remote
functions that can be executed by this worker.
handle (worker capsule): A Python object wrapping a C++ Worker object.
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.SHELL_MODE,
ray.PYTHON_MODE, ray.SILENT_MODE, and ray.WORKER_MODE.
mode: The mode of the worker. One of SCRIPT_MODE, SHELL_MODE, PYTHON_MODE,
SILENT_MODE, and WORKER_MODE.
cached_remote_functions (List[str]): A list of serialized remote functions
that were defined before the worker called connect. When the worker
eventually does call connect, if it is a driver, it will export these
@@ -268,28 +280,27 @@ class Worker(object):
def set_mode(self, mode):
"""Set the mode of the worker.
The mode ray.SCRIPT_MODE should be used if this Worker is a driver that is
being run as a Python script. It will print information about task failures.
The mode SCRIPT_MODE should be used if this Worker is a driver that is being
run as a Python script. It will print information about task failures.
The mode ray.SHELL_MODE should be used if this Worker is a driver that is
being run interactively in a Python shell. It will print information about
task failures and successes.
The mode SHELL_MODE should be used if this Worker is a driver that is being
run interactively in a Python shell. It will print information about task
failures and successes.
The mode ray.WORKER_MODE should be used if this Worker is not a driver. It
will not print information about tasks.
The mode WORKER_MODE should be used if this Worker is not a driver. It will
not print information about tasks.
The mode ray.PYTHON_MODE should be used if this Worker is a driver and if
you want to run the driver in a manner equivalent to serial Python for
debugging purposes. It will not send remote function calls to the scheduler
and will instead execute them in a blocking fashion.
The mode PYTHON_MODE should be used if this Worker is a driver and if you
want to run the driver in a manner equivalent to serial Python for debugging
purposes. It will not send remote function calls to the scheduler and will
instead execute them in a blocking fashion.
The mode ray.SILENT_MODE should be used only during testing. It does not
print any information about errors because some of the tests intentionally
fail.
The mode SILENT_MODE should be used only during testing. It does not print
any information about errors because some of the tests intentionally fail.
args:
mode: One of ray.SCRIPT_MODE, ray.WORKER_MODE, ray.SHELL_MODE,
ray.PYTHON_MODE, and ray.SILENT_MODE.
mode: One of SCRIPT_MODE, WORKER_MODE, SHELL_MODE, PYTHON_MODE, and
SILENT_MODE.
"""
self.mode = mode
colorama.init()
@@ -301,7 +312,7 @@ class Worker(object):
local object store.
Args:
objectid (ray.ObjectID): The object ID of the value to be put.
objectid (raylib.ObjectID): The object ID of the value to be put.
value (serializable object): The value to put in the object store.
"""
try:
@@ -313,7 +324,7 @@ class Worker(object):
# the len(schema) is for storing the metadata and the 4096 is for storing
# the metadata in the batch (see INITIAL_METADATA_SIZE in arrow)
size = size + 8 + len(schema) + 4096
buff, segmentid = ray.lib.allocate_buffer(self.handle, objectid, size)
buff, segmentid = raylib.allocate_buffer(self.handle, objectid, size)
# write the metadata length
np.frombuffer(buff, dtype="int64", count=1)[0] = len(schema)
# metadata buffer
@@ -322,12 +333,12 @@ class Worker(object):
metadata[:] = schema
data = np.frombuffer(buff, dtype="byte")[8 + len(schema):]
metadata_offset = libnumbuf.write_to_buffer(serialized, memoryview(data))
ray.lib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
raylib.finish_buffer(self.handle, objectid, segmentid, metadata_offset)
except:
# At the moment, custom object and objects that contain object IDs take this path
# TODO(pcm): Make sure that these are the only objects getting serialized to protobuf
object_capsule, contained_objectids = serialization.serialize(self.handle, value) # contained_objectids is a list of the objectids contained in object_capsule
ray.lib.put_object(self.handle, objectid, object_capsule, contained_objectids)
raylib.put_object(self.handle, objectid, object_capsule, contained_objectids)
def get_object(self, objectid):
"""Get the value in the local object store associated with objectid.
@@ -336,11 +347,11 @@ class Worker(object):
until the value for objectid has been written to the local object store.
Args:
objectid (ray.ObjectID): The object ID of the value to retrieve.
objectid (raylib.ObjectID): The object ID of the value to retrieve.
"""
if ray.lib.is_arrow(self.handle, objectid):
if raylib.is_arrow(self.handle, objectid):
## this is the new codepath
buff, segmentid, metadata_offset = ray.lib.get_buffer(self.handle, objectid)
buff, segmentid, metadata_offset = raylib.get_buffer(self.handle, objectid)
metadata_size = np.frombuffer(buff, dtype="int64", count=1)[0]
metadata = np.frombuffer(buff, dtype="byte", offset=8, count=metadata_size)
data = np.frombuffer(buff, dtype="byte")[8 + metadata_size:]
@@ -350,9 +361,9 @@ class Worker(object):
assert len(deserialized) == 1
result = deserialized[0]
## this is the old codepath
# result, segmentid = ray.lib.get_arrow(self.handle, objectid)
# result, segmentid = raylib.get_arrow(self.handle, objectid)
else:
object_capsule, segmentid = ray.lib.get_object(self.handle, objectid)
object_capsule, segmentid = raylib.get_object(self.handle, objectid)
result = serialization.deserialize(self.handle, object_capsule)
if isinstance(result, int):
@@ -362,7 +373,7 @@ class Worker(object):
elif isinstance(result, float):
result = serialization.Float(result)
elif isinstance(result, bool):
ray.lib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
return result # can't subclass bool, and don't need to because there is a global True/False
elif isinstance(result, list):
result = serialization.List(result)
@@ -378,7 +389,7 @@ class Worker(object):
return result
# TODO(pcm): close the associated memory segment; if we don't, this leaks memory (but very little, so it is ok for now)
elif result is None:
ray.lib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
raylib.unmap_object(self.handle, segmentid) # need to unmap here because result is passed back "by value" and we have no reference to unmap later
return None # can't subclass None and don't need to because there is a global None
result.ray_objectid = objectid # TODO(pcm): This could be done only for the "get" case in the future if we want to increase performance
result.ray_deallocator = RayDealloc(self.handle, segmentid)
@@ -386,7 +397,7 @@ class Worker(object):
def alias_objectids(self, alias_objectid, target_objectid):
"""Make two object IDs refer to the same object."""
ray.lib.alias_objectids(self.handle, alias_objectid, target_objectid)
raylib.alias_objectids(self.handle, alias_objectid, target_objectid)
def register_function(self, function):
"""Register a function with the scheduler.
@@ -399,7 +410,7 @@ class Worker(object):
function (Callable): The remote function that this worker can execute.
"""
_logger().info("Registering function {}.".format(function.func_name))
ray.lib.register_function(self.handle, function.func_name, len(function.return_types))
raylib.register_function(self.handle, function.func_name, len(function.return_types))
self.functions[function.func_name] = function
def submit_task(self, func_name, args):
@@ -416,9 +427,9 @@ class Worker(object):
must be serializable objects.
"""
task_capsule = serialization.serialize_task(self.handle, func_name, args)
objectids = ray.lib.submit_task(self.handle, task_capsule)
if self.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(self.handle), self.mode)
objectids = raylib.submit_task(self.handle, task_capsule)
if self.mode in [SHELL_MODE, SCRIPT_MODE]:
print_task_info(raylib.task_info(self.handle), self.mode)
return objectids
global_worker = Worker()
@@ -479,7 +490,7 @@ def print_task_info(task_data, mode):
for task_status in task_data["failed_tasks"]:
print_failed_task(task_status)
print "Error: {} task{} failed.".format(num_tasks_failed, "s" if num_tasks_failed > 1 else "")
if mode == ray.SHELL_MODE:
if mode == SHELL_MODE:
info_strings = []
if num_tasks_succeeded > 0:
info_strings.append("{}{} task{} succeeded{}".format(colorama.Fore.BLUE, num_tasks_succeeded, "s" if num_tasks_succeeded > 1 else "", colorama.Fore.RESET))
@@ -493,7 +504,7 @@ def print_task_info(task_data, mode):
def scheduler_info(worker=global_worker):
"""Return information about the state of the scheduler."""
check_connected(worker)
return ray.lib.scheduler_info(worker.handle)
return raylib.scheduler_info(worker.handle)
def visualize_computation_graph(file_path=None, view=False, worker=global_worker):
"""Write the computation graph to a pdf file.
@@ -516,17 +527,17 @@ def visualize_computation_graph(file_path=None, view=False, worker=global_worker
"""
check_connected(worker)
if file_path is None:
file_path = ray.config.get_log_file_path("computation-graph.pdf")
file_path = config.get_log_file_path("computation-graph.pdf")
base_path, extension = os.path.splitext(file_path)
if extension != ".pdf":
raise Exception("File path must be a .pdf file")
proto_path = base_path + ".binaryproto"
ray.lib.dump_computation_graph(worker.handle, proto_path)
graph = ray.internal.graph_pb2.CompGraph()
graph.ParseFromString(open(proto_path).read())
ray.graph.graph_to_graphviz(graph).render(base_path, view=view)
raylib.dump_computation_graph(worker.handle, proto_path)
g = internal.graph_pb2.CompGraph()
# The dumped computation graph is a binary protocol buffer, so read it in
# binary mode and close the file as soon as its contents have been parsed.
with open(proto_path, "rb") as proto_file:
  g.ParseFromString(proto_file.read())
graph.graph_to_graphviz(g).render(base_path, view=view)
print "Wrote graph dot description to file {}".format(base_path)
print "Wrote graph protocol buffer description to file {}".format(proto_path)
@@ -535,7 +546,7 @@ def visualize_computation_graph(file_path=None, view=False, worker=global_worker
def task_info(worker=global_worker):
"""Return information about failed tasks."""
check_connected(worker)
return ray.lib.task_info(worker.handle)
return raylib.task_info(worker.handle)
def register_module(module, worker=global_worker):
"""Register each remote function in a module with the scheduler.
@@ -554,7 +565,7 @@ def register_module(module, worker=global_worker):
_logger().info("registering {}.".format(val.func_name))
worker.register_function(val)
def init(start_ray_local=False, num_workers=None, scheduler_address=None, objstore_address=None, driver_address=None, driver_mode=ray.SCRIPT_MODE):
def init(start_ray_local=False, num_workers=None, num_objstores=1, scheduler_address=None, objstore_address=None, driver_address=None, driver_mode=SCRIPT_MODE):
"""Either connect to an existing Ray cluster or start one and connect to it.
This method handles two cases. Either a Ray cluster already exists and we
@@ -567,6 +578,8 @@ def init(start_ray_local=False, num_workers=None, scheduler_address=None, objsto
existing Ray cluster.
num_workers (Optional[int]): The number of workers to start if
start_ray_local is True.
num_objstores (Optional[int]): The number of object stores to start if
start_ray_local is True.
scheduler_address (Optional[str]): The address of the scheduler to connect
to if start_ray_local is False.
objstore_address (Optional[str]): The address of the object store to connect
@@ -574,8 +587,7 @@ def init(start_ray_local=False, num_workers=None, scheduler_address=None, objsto
driver_address (Optional[str]): The address of this driver if
start_ray_local is False.
driver_mode (Optional[int]): The mode in which to start the driver. This
should be one of ray.SCRIPT_MODE, ray.SHELL_MODE, ray.PYTHON_MODE, and
ray.SILENT_MODE.
should be one of SCRIPT_MODE, SHELL_MODE, PYTHON_MODE, and SILENT_MODE.
raises:
Exception: An exception is raised if an inappropriate combination of
@@ -586,17 +598,41 @@ def init(start_ray_local=False, num_workers=None, scheduler_address=None, objsto
# and we connect to them.
if (scheduler_address is not None) or (objstore_address is not None) or (driver_address is not None):
raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address, objstore_address, or worker_address.")
if driver_mode not in [ray.SCRIPT_MODE, ray.SHELL_MODE, ray.PYTHON_MODE, ray.SILENT_MODE]:
raise Exception("If start_ray_local=True, then driver_mode must be in [ray.SCRIPT_MODE, ray.SHELL_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")
if driver_mode not in [SCRIPT_MODE, SHELL_MODE, PYTHON_MODE, SILENT_MODE]:
raise Exception("If start_ray_local=True, then driver_mode must be in [SCRIPT_MODE, SHELL_MODE, PYTHON_MODE, SILENT_MODE].")
num_workers = 1 if num_workers is None else num_workers
ray.services.start_ray_local(num_objstores=1, num_workers_per_objstore=num_workers, worker_path=None, driver_mode=driver_mode)
# Start the scheduler, object store, and some workers. These will be killed
# by the call to cleanup(), which happens when the Python script exits.
scheduler_address, objstore_addresses, driver_addresses = services.start_ray_local(num_objstores=num_objstores, num_workers=num_workers, worker_path=None)
# It is possible for start_ray_local to return multiple object stores, but
# we will only connect the driver to one of them.
objstore_address = objstore_addresses[0]
driver_address = driver_addresses[0]
else:
# In this case, connect to an existing scheduler and object store.
if num_workers is not None:
raise Exception("The argument num_workers must not be provided unless start_ray_local=True.")
connect(scheduler_address, objstore_address, driver_address, is_driver=True, worker=global_worker, mode=driver_mode)
# In this case, there is an existing scheduler and object store, and we do
# not need to start any processes.
# num_objstores defaults to 1 in the signature of init, so comparing it against
# None alone would reject every call that connects to an existing cluster.
# Treat the default value as "not provided".
if (num_workers is not None) or (num_objstores not in (None, 1)):
  raise Exception("The arguments num_workers and num_objstores must not be provided unless start_ray_local=True.")
# Connect this driver to the scheduler and object store. The corresponding call
# to disconnect will happen in the call to cleanup() when the Python script
# exits.
connect(scheduler_address, objstore_address, driver_address, is_driver=True, worker=global_worker, mode=driver_mode)
def connect(scheduler_address, objstore_address, worker_address, is_driver=False, worker=global_worker, mode=ray.WORKER_MODE):
def cleanup(worker=global_worker):
  """Disconnect the driver, and terminate any processes started in init.

  This will automatically run at the end when a Python process that uses Ray
  exits. It is ok to run this twice in a row. Note that we manually call
  services.cleanup() in the tests because we need to start and stop many
  clusters in the tests, but the import and exit only happen once.

  Args:
    worker: The worker to disconnect and whose mode is reset to None. Defaults
      to the module-level global_worker.
  """
  # Disconnect the worker that was passed in; calling disconnect() with no
  # argument would always act on the default worker and ignore the parameter.
  disconnect(worker)
  worker.set_mode(None)
  # Terminate the processes tracked by the services module.
  services.cleanup()
atexit.register(cleanup)
def connect(scheduler_address, objstore_address, worker_address, is_driver=False, worker=global_worker, mode=WORKER_MODE):
"""Connect this worker to the scheduler and an object store.
Args:
@@ -605,31 +641,31 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
worker_address (str): The ip address and port of this worker. The port can
be chosen arbitrarily.
is_driver (bool): True if this worker is a driver and false otherwise.
mode: The mode of the worker. One of ray.SCRIPT_MODE, ray.WORKER_MODE,
ray.SHELL_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, SHELL_MODE,
PYTHON_MODE, and SILENT_MODE.
"""
if hasattr(worker, "handle"):
del worker.handle
worker.scheduler_address = scheduler_address
worker.objstore_address = objstore_address
worker.worker_address = worker_address
worker.handle = ray.lib.create_worker(worker.scheduler_address, worker.objstore_address, worker.worker_address, is_driver)
worker.handle = raylib.create_worker(worker.scheduler_address, worker.objstore_address, worker.worker_address, is_driver)
worker.set_mode(mode)
FORMAT = "%(asctime)-15s %(message)s"
# Configure the Python logging module. Note that if we do not provide our own
# logger, then our logging will interfere with other Python modules that also
# use the logging module.
log_handler = logging.FileHandler(ray.config.get_log_file_path("-".join(["worker", worker_address]) + ".log"))
log_handler = logging.FileHandler(config.get_log_file_path("-".join(["worker", worker_address]) + ".log"))
log_handler.setLevel(logging.DEBUG)
log_handler.setFormatter(logging.Formatter(FORMAT))
_logger().addHandler(log_handler)
_logger().setLevel(logging.DEBUG)
_logger().propagate = False
# Configure the logging from the worker C++ code.
ray.lib.set_log_config(ray.config.get_log_file_path("-".join(["worker", worker_address, "c++"]) + ".log"))
if mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker_address, "c++"]) + ".log"))
if mode in [SHELL_MODE, SCRIPT_MODE, SILENT_MODE]:
for function_to_export in worker.cached_remote_functions:
ray.lib.export_function(worker.handle, function_to_export)
raylib.export_function(worker.handle, function_to_export)
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
worker.cached_remote_functions = None
@@ -638,7 +674,7 @@ def connect(scheduler_address, objstore_address, worker_address, is_driver=False
def disconnect(worker=global_worker):
"""Disconnect this worker from the scheduler and object store."""
if worker.handle is not None:
ray.lib.disconnect(worker.handle)
raylib.disconnect(worker.handle)
# Reset the list of cached remote functions so that if more remote functions
# are defined and then connect is called again, the remote functions will be
# exported. This is mostly relevant for the tests.
@@ -655,17 +691,17 @@ def get(objectid, worker=global_worker):
created).
Args:
objectid (ray.ObjectID): Object ID to the object to get.
objectid (raylib.ObjectID): Object ID to the object to get.
Returns:
A Python object
"""
check_connected(worker)
if worker.mode == ray.PYTHON_MODE:
return objectid # In ray.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
ray.lib.request_object(worker.handle, objectid)
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
if worker.mode == PYTHON_MODE:
return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
raylib.request_object(worker.handle, objectid)
if worker.mode in [SHELL_MODE, SCRIPT_MODE]:
print_task_info(raylib.task_info(worker.handle), worker.mode)
value = worker.get_object(objectid)
if isinstance(value, RayFailedObject):
raise Exception("The task that created this object ID failed with error message:\n{}".format(value.error_message))
@@ -681,12 +717,12 @@ def put(value, worker=global_worker):
The object ID assigned to this value.
"""
check_connected(worker)
if worker.mode == ray.PYTHON_MODE:
return value # In ray.PYTHON_MODE, ray.put is the identity operation
objectid = ray.lib.get_objectid(worker.handle)
if worker.mode == PYTHON_MODE:
return value # In PYTHON_MODE, ray.put is the identity operation
objectid = raylib.get_objectid(worker.handle)
worker.put_object(objectid, value)
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE]:
print_task_info(ray.lib.task_info(worker.handle), worker.mode)
if worker.mode in [SHELL_MODE, SCRIPT_MODE]:
print_task_info(raylib.task_info(worker.handle), worker.mode)
return objectid
def kill_workers(worker=global_worker):
@@ -700,7 +736,7 @@ def kill_workers(worker=global_worker):
Returns:
True if workers were successfully killed. False otherwise.
"""
success = ray.lib.kill_workers(worker.handle)
success = raylib.kill_workers(worker.handle)
if not success:
print "Could not kill all workers. We currently do not support killing workers when tasks are running."
return success
@@ -760,9 +796,9 @@ def main_loop(worker=global_worker):
error messages in the object store in place of the actual outputs. These
objects are used to propagate the error messages.
"""
if not ray.lib.connected(worker.handle):
if not raylib.connected(worker.handle):
raise Exception("Worker is attempting to enter main_loop but has not been connected yet.")
ray.lib.start_worker_service(worker.handle)
raylib.start_worker_service(worker.handle)
def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts
func_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
try:
@@ -776,17 +812,17 @@ def main_loop(worker=global_worker):
# failure (this is only interpreted by the worker).
failure_objects = [RayFailedObject(exception_message) for _ in range(len(return_objectids))]
store_outputs_in_objstore(return_objectids, failure_objects, worker)
ray.lib.notify_task_completed(worker.handle, False, exception_message) # notify the scheduler that the task threw an exception
raylib.notify_task_completed(worker.handle, False, exception_message) # notify the scheduler that the task threw an exception
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(exception_message, func_name))
else:
store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
ray.lib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
finally:
# Reinitialize the values of reusable variables that were used in the task
# above so that changes made to their state do not affect other tasks.
ray.reusables._reinitialize()
reusables._reinitialize()
while True:
command, command_args = ray.lib.wait_for_next_message(worker.handle)
command, command_args = raylib.wait_for_next_message(worker.handle)
try:
if command == "die":
# We use this as a mechanism to allow the scheduler to kill workers.
@@ -846,9 +882,9 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
reusable (Reusable): The reusable object containing code for initializing
and reinitializing the variable.
"""
if _mode(worker) not in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
if _mode(worker) not in [SHELL_MODE, SCRIPT_MODE, SILENT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
ray.lib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
def remote(arg_types, return_types, worker=global_worker):
"""This decorator is used to create remote functions.
@@ -863,10 +899,10 @@ def remote(arg_types, return_types, worker=global_worker):
check_connected()
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
if _mode() == ray.PYTHON_MODE:
# In ray.PYTHON_MODE, remote calls simply execute the function. We copy
# the arguments to prevent the function call from mutating them and to
# match the usual behavior of immutable remote objects.
if _mode() == PYTHON_MODE:
# In PYTHON_MODE, remote calls simply execute the function. We copy the
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
objectids = _submit_task(func_name, args)
@@ -902,7 +938,7 @@ def remote(arg_types, return_types, worker=global_worker):
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
# Everything ready - export the function
if worker.mode in [None, ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
if worker.mode in [None, SHELL_MODE, SCRIPT_MODE, SILENT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
@@ -913,8 +949,8 @@ def remote(arg_types, return_types, worker=global_worker):
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if worker.mode in [ray.SHELL_MODE, ray.SCRIPT_MODE, ray.SILENT_MODE]:
ray.lib.export_function(worker.handle, to_export)
if worker.mode in [SHELL_MODE, SCRIPT_MODE, SILENT_MODE]:
raylib.export_function(worker.handle, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append(to_export)
return func_invoker
@@ -978,7 +1014,7 @@ def check_return_values(function, result):
# Here we do some limited type checking to make sure the return values have
# the right types.
for i in range(len(result)):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], ray.lib.ObjectID)):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)):
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i]))
def typecheck_arg(arg, expected_type, i, name):
@@ -1034,7 +1070,7 @@ def check_arguments(arg_types, has_vararg_param, name, args):
else:
assert False, "This code should be unreachable."
if isinstance(arg, ray.ObjectID):
if isinstance(arg, raylib.ObjectID):
# TODO(rkn): When we have type information in the ObjectID, do type checking here.
pass
else:
@@ -1076,7 +1112,7 @@ def get_arguments_for_execution(function, args, worker=global_worker):
else:
assert False, "This code should be unreachable."
if isinstance(arg, ray.ObjectID):
if isinstance(arg, raylib.ObjectID):
# get the object from the local object store
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
argument = worker.get_object(arg)
@@ -1101,7 +1137,7 @@ def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
The arguments objectids and outputs should have the same length.
Args:
objectids (List[ray.ObjectID]): The object IDs that were assigned to the
objectids (List[raylib.ObjectID]): The object IDs that were assigned to the
outputs of the remote function call.
outputs (Tuple): The value returned by the remote function. If the remote
function was supposed to only return one value, then its output was
@@ -1109,7 +1145,7 @@ def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
function.
"""
for i in range(len(objectids)):
if isinstance(outputs[i], ray.ObjectID):
if isinstance(outputs[i], raylib.ObjectID):
# An ObjectID is being returned, so we must alias objectids[i] so that it refers to the same object that outputs[i] refers to
_logger().info("Aliasing objectids {} and {}".format(objectids[i].id, outputs[i].id))
worker.alias_objectids(objectids[i], outputs[i])