mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 08:31:18 +08:00
API cleanups. Remove worker argument. Remove some deprecated arguments. (#4025)
* Remove worker argument from API methods. * Remove deprecated arguments and deprecate redirect_output and redirect_worker_output. * Fix
This commit is contained in:
committed by
Philipp Moritz
parent
042ad84573
commit
5f71751891
@@ -125,12 +125,6 @@ class ActorMethod(object):
|
||||
def remote(self, *args, **kwargs):
|
||||
return self._remote(args, kwargs)
|
||||
|
||||
def _submit(self, args, kwargs, num_return_vals=None):
|
||||
logger.warning(
|
||||
"WARNING: _submit() is being deprecated. Please use _remote().")
|
||||
return self._remote(
|
||||
args=args, kwargs=kwargs, num_return_vals=num_return_vals)
|
||||
|
||||
def _remote(self, args, kwargs, num_return_vals=None):
|
||||
if num_return_vals is None:
|
||||
num_return_vals = self._num_return_vals
|
||||
@@ -238,21 +232,6 @@ class ActorClass(object):
|
||||
"""
|
||||
return self._remote(args=args, kwargs=kwargs)
|
||||
|
||||
def _submit(self,
|
||||
args,
|
||||
kwargs,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
resources=None):
|
||||
logger.warning(
|
||||
"WARNING: _submit() is being deprecated. Please use _remote().")
|
||||
return self._remote(
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
resources=resources)
|
||||
|
||||
def _remote(self,
|
||||
args,
|
||||
kwargs,
|
||||
|
||||
@@ -6,7 +6,7 @@ import ray
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get(object_ids, worker=None):
|
||||
def get(object_ids):
|
||||
"""Get a single or a collection of remote objects from the object store.
|
||||
|
||||
This method is identical to `ray.get` except it adds support for tuples,
|
||||
@@ -19,11 +19,8 @@ def get(object_ids, worker=None):
|
||||
Returns:
|
||||
A Python object, a list of Python objects or a dict of {key: object}.
|
||||
"""
|
||||
# There is a dependency on ray.worker which prevents importing
|
||||
# global_worker at the top of this file
|
||||
worker = ray.worker.global_worker if worker is None else worker
|
||||
if isinstance(object_ids, (tuple, np.ndarray)):
|
||||
return ray.get(list(object_ids), worker)
|
||||
return ray.get(list(object_ids))
|
||||
elif isinstance(object_ids, dict):
|
||||
keys_to_get = [
|
||||
k for k, v in object_ids.items() if isinstance(v, ray.ObjectID)
|
||||
@@ -38,10 +35,10 @@ def get(object_ids, worker=None):
|
||||
result[key] = value
|
||||
return result
|
||||
else:
|
||||
return ray.get(object_ids, worker)
|
||||
return ray.get(object_ids)
|
||||
|
||||
|
||||
def wait(object_ids, num_returns=1, timeout=None, worker=None):
|
||||
def wait(object_ids, num_returns=1, timeout=None):
|
||||
"""Return a list of IDs that are ready and a list of IDs that are not.
|
||||
|
||||
This method is identical to `ray.wait` except it adds support for tuples
|
||||
@@ -59,13 +56,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=None):
|
||||
A list of object IDs that are ready and a list of the remaining object
|
||||
IDs.
|
||||
"""
|
||||
worker = ray.worker.global_worker if worker is None else worker
|
||||
if isinstance(object_ids, (tuple, np.ndarray)):
|
||||
return ray.wait(
|
||||
list(object_ids),
|
||||
num_returns=num_returns,
|
||||
timeout=timeout,
|
||||
worker=worker)
|
||||
list(object_ids), num_returns=num_returns, timeout=timeout)
|
||||
|
||||
return ray.wait(
|
||||
object_ids, num_returns=num_returns, timeout=timeout, worker=worker)
|
||||
return ray.wait(object_ids, num_returns=num_returns, timeout=timeout)
|
||||
|
||||
@@ -441,7 +441,7 @@ class FunctionActorManager(object):
|
||||
# we spend too long in this loop.
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function", worker=self._worker):
|
||||
with profiling.profile("wait_for_function"):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
try:
|
||||
info = self._function_execution_info[driver_id][function_id]
|
||||
|
||||
@@ -91,21 +91,18 @@ class ImportThread(object):
|
||||
# Handle the driver case first.
|
||||
if self.mode != ray.WORKER_MODE:
|
||||
if key.startswith(b"FunctionsToRun"):
|
||||
with profiling.profile(
|
||||
"fetch_and_run_function", worker=self.worker):
|
||||
with profiling.profile("fetch_and_run_function"):
|
||||
self.fetch_and_execute_function_to_run(key)
|
||||
# Return because FunctionsToRun are the only things that
|
||||
# the driver should import.
|
||||
return
|
||||
|
||||
if key.startswith(b"RemoteFunction"):
|
||||
with profiling.profile(
|
||||
"register_remote_function", worker=self.worker):
|
||||
with profiling.profile("register_remote_function"):
|
||||
(self.worker.function_actor_manager.
|
||||
fetch_and_register_remote_function(key))
|
||||
elif key.startswith(b"FunctionsToRun"):
|
||||
with profiling.profile(
|
||||
"fetch_and_run_function", worker=self.worker):
|
||||
with profiling.profile("fetch_and_run_function"):
|
||||
self.fetch_and_execute_function_to_run(key)
|
||||
elif key.startswith(b"ActorClass"):
|
||||
# Keep track of the fact that this actor class has been
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray import profiling
|
||||
__all__ = ["free"]
|
||||
|
||||
|
||||
def free(object_ids, local_only=False, worker=None):
|
||||
def free(object_ids, local_only=False):
|
||||
"""Free a list of IDs from object stores.
|
||||
|
||||
This function is a low-level API which should be used in restricted
|
||||
@@ -26,8 +26,7 @@ def free(object_ids, local_only=False, worker=None):
|
||||
local_only (bool): Whether only deleting the list of objects in local
|
||||
object store or all object stores.
|
||||
"""
|
||||
if worker is None:
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker = ray.worker.get_global_worker()
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
object_ids = [object_ids]
|
||||
@@ -37,7 +36,7 @@ def free(object_ids, local_only=False, worker=None):
|
||||
type(object_ids)))
|
||||
|
||||
worker.check_connected()
|
||||
with profiling.profile("ray.free", worker=worker):
|
||||
with profiling.profile("ray.free"):
|
||||
if len(object_ids) == 0:
|
||||
return
|
||||
|
||||
|
||||
+4
-4
@@ -197,7 +197,7 @@ class Node(object):
|
||||
raise FileExistsError(errno.EEXIST,
|
||||
"No usable temporary filename found")
|
||||
|
||||
def new_log_files(self, name, redirect_output=None):
|
||||
def new_log_files(self, name, redirect_output=True):
|
||||
"""Generate partially randomized filenames for log files.
|
||||
|
||||
Args:
|
||||
@@ -262,7 +262,7 @@ class Node(object):
|
||||
redis_shard_ports=self._ray_params.redis_shard_ports,
|
||||
num_redis_shards=self._ray_params.num_redis_shards,
|
||||
redis_max_clients=self._ray_params.redis_max_clients,
|
||||
redirect_worker_output=self._ray_params.redirect_worker_output,
|
||||
redirect_worker_output=True,
|
||||
password=self._ray_params.redis_password,
|
||||
redis_max_memory=self._ray_params.redis_max_memory)
|
||||
assert (
|
||||
@@ -272,7 +272,7 @@ class Node(object):
|
||||
|
||||
def start_log_monitor(self):
|
||||
"""Start the log monitor."""
|
||||
stdout_file, stderr_file = self.new_log_files("log_monitor", True)
|
||||
stdout_file, stderr_file = self.new_log_files("log_monitor")
|
||||
process_info = ray.services.start_log_monitor(
|
||||
self.redis_address,
|
||||
self._logs_dir,
|
||||
@@ -286,7 +286,7 @@ class Node(object):
|
||||
|
||||
def start_ui(self):
|
||||
"""Start the web UI."""
|
||||
stdout_file, stderr_file = self.new_log_files("webui", True)
|
||||
stdout_file, stderr_file = self.new_log_files("webui")
|
||||
notebook_name = self._make_inc_temp(
|
||||
suffix=".ipynb", prefix="ray_ui", directory_name=self._temp_dir)
|
||||
self._webui_url, process_info = ray.services.start_ui(
|
||||
|
||||
+11
-8
@@ -90,11 +90,10 @@ class RayParams(object):
|
||||
node_manager_port=None,
|
||||
node_ip_address=None,
|
||||
object_id_seed=None,
|
||||
num_workers=None,
|
||||
local_mode=False,
|
||||
driver_mode=None,
|
||||
redirect_worker_output=True,
|
||||
redirect_output=True,
|
||||
redirect_worker_output=None,
|
||||
redirect_output=None,
|
||||
num_redis_shards=None,
|
||||
redis_max_clients=None,
|
||||
redis_password=None,
|
||||
@@ -124,7 +123,6 @@ class RayParams(object):
|
||||
self.object_manager_port = object_manager_port
|
||||
self.node_manager_port = node_manager_port
|
||||
self.node_ip_address = node_ip_address
|
||||
self.num_workers = num_workers
|
||||
self.local_mode = local_mode
|
||||
self.driver_mode = driver_mode
|
||||
self.redirect_worker_output = redirect_worker_output
|
||||
@@ -186,10 +184,15 @@ class RayParams(object):
|
||||
"'GPU' should not be included in the resource dictionary. Use "
|
||||
"num_gpus instead.")
|
||||
|
||||
if self.num_workers is not None:
|
||||
raise ValueError(
|
||||
"The 'num_workers' argument is deprecated. Please use "
|
||||
"'num_cpus' instead.")
|
||||
if self.redirect_worker_output is not None:
|
||||
raise DeprecationWarning(
|
||||
"The redirect_worker_output argument is deprecated. To "
|
||||
"control logging to the driver, use the 'log_to_driver' "
|
||||
"argument to 'ray.init()'")
|
||||
|
||||
if self.redirect_output is not None:
|
||||
raise DeprecationWarning(
|
||||
"The redirect_output argument is deprecated.")
|
||||
|
||||
if self.include_java is None and self.java_worker_options is not None:
|
||||
raise ValueError("Should not specify `java-worker-options` "
|
||||
|
||||
@@ -27,7 +27,7 @@ class _NullLogSpan(object):
|
||||
NULL_LOG_SPAN = _NullLogSpan()
|
||||
|
||||
|
||||
def profile(event_type, extra_data=None, worker=None):
|
||||
def profile(event_type, extra_data=None):
|
||||
"""Profile a span of time so that it appears in the timeline visualization.
|
||||
|
||||
Note that this only works in the raylet code path.
|
||||
@@ -57,8 +57,7 @@ def profile(event_type, extra_data=None, worker=None):
|
||||
Returns:
|
||||
An object that can profile a span of time via a "with" statement.
|
||||
"""
|
||||
if worker is None:
|
||||
worker = ray.worker.global_worker
|
||||
worker = ray.worker.global_worker
|
||||
return RayLogSpanRaylet(worker.profiler, event_type, extra_data=extra_data)
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ env_config.update({
|
||||
})
|
||||
register_carla_model()
|
||||
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"carla": {
|
||||
"run": "PPO",
|
||||
|
||||
@@ -123,15 +123,6 @@ def cli(logging_level, logging_format):
|
||||
"limit is exceeded, redis will start LRU eviction of entries. This only "
|
||||
"applies to the sharded redis tables (task, object, and profile tables). "
|
||||
"By default this is capped at 10GB but can be set higher.")
|
||||
@click.option(
|
||||
"--num-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("The initial number of workers to start on this node, "
|
||||
"note that the local scheduler may start additional "
|
||||
"workers. If you wish to control the total number of "
|
||||
"concurent tasks, then use --resources instead and "
|
||||
"specify the CPU field."))
|
||||
@click.option(
|
||||
"--num-cpus",
|
||||
required=False,
|
||||
@@ -220,8 +211,8 @@ def cli(logging_level, logging_format):
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients, redis_password, redis_shard_ports,
|
||||
object_manager_port, node_manager_port, object_store_memory,
|
||||
redis_max_memory, num_workers, num_cpus, num_gpus, resources, head,
|
||||
no_ui, block, plasma_directory, huge_pages, autoscaling_config,
|
||||
redis_max_memory, num_cpus, num_gpus, resources, head, no_ui, block,
|
||||
plasma_directory, huge_pages, autoscaling_config,
|
||||
no_redirect_worker_output, no_redirect_output,
|
||||
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
||||
java_worker_options, internal_config):
|
||||
@@ -239,15 +230,16 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
" --resources='{\"CustomResource1\": 3, "
|
||||
"\"CustomReseource2\": 2}'")
|
||||
|
||||
redirect_worker_output = None if not no_redirect_worker_output else True
|
||||
redirect_output = None if not no_redirect_output else True
|
||||
ray_params = ray.parameter.RayParams(
|
||||
node_ip_address=node_ip_address,
|
||||
object_manager_port=object_manager_port,
|
||||
node_manager_port=node_manager_port,
|
||||
num_workers=num_workers,
|
||||
object_store_memory=object_store_memory,
|
||||
redis_password=redis_password,
|
||||
redirect_worker_output=not no_redirect_worker_output,
|
||||
redirect_output=not no_redirect_output,
|
||||
redirect_worker_output=redirect_worker_output,
|
||||
redirect_output=redirect_output,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
resources=resources,
|
||||
|
||||
@@ -30,7 +30,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", michalewicz_function)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
ray.init(redirect_output=True)
|
||||
ray.init()
|
||||
|
||||
register_trainable("exp", easy_objective)
|
||||
|
||||
|
||||
+28
-52
@@ -579,7 +579,7 @@ class Worker(object):
|
||||
Returns:
|
||||
The return object IDs for this task.
|
||||
"""
|
||||
with profiling.profile("submit_task", worker=self):
|
||||
with profiling.profile("submit_task"):
|
||||
if actor_id is None:
|
||||
assert actor_handle_id is None
|
||||
actor_id = ActorID.nil()
|
||||
@@ -828,7 +828,7 @@ class Worker(object):
|
||||
if function_name != "__ray_terminate__":
|
||||
self.reraise_actor_init_error()
|
||||
self.memory_monitor.raise_if_low_memory()
|
||||
with profiling.profile("task:deserialize_arguments", worker=self):
|
||||
with profiling.profile("task:deserialize_arguments"):
|
||||
arguments = self._get_arguments_for_execution(
|
||||
function_name, args)
|
||||
except RayTaskError as e:
|
||||
@@ -844,7 +844,7 @@ class Worker(object):
|
||||
|
||||
# Execute the task.
|
||||
try:
|
||||
with profiling.profile("task:execute", worker=self):
|
||||
with profiling.profile("task:execute"):
|
||||
if (task.actor_id().is_nil()
|
||||
and task.actor_creation_id().is_nil()):
|
||||
outputs = function_executor(*arguments)
|
||||
@@ -867,7 +867,7 @@ class Worker(object):
|
||||
|
||||
# Store the outputs in the local object store.
|
||||
try:
|
||||
with profiling.profile("task:store_outputs", worker=self):
|
||||
with profiling.profile("task:store_outputs"):
|
||||
# If this is an actor task, then the last object ID returned by
|
||||
# the task is a dummy output, not returned by the function
|
||||
# itself. Decrement to get the correct number of return values.
|
||||
@@ -952,7 +952,7 @@ class Worker(object):
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
with profiling.profile("task", extra_data=extra_data, worker=self):
|
||||
with profiling.profile("task", extra_data=extra_data):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
@@ -981,7 +981,7 @@ class Worker(object):
|
||||
Returns:
|
||||
A task from the local scheduler.
|
||||
"""
|
||||
with profiling.profile("worker_idle", worker=self):
|
||||
with profiling.profile("worker_idle"):
|
||||
task = self.raylet_client.get_task()
|
||||
|
||||
# Automatically restrict the GPUs available to this task.
|
||||
@@ -993,7 +993,7 @@ class Worker(object):
|
||||
"""The main loop a worker runs to receive and execute tasks."""
|
||||
|
||||
def exit(signum, frame):
|
||||
shutdown(worker=self)
|
||||
shutdown()
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, exit)
|
||||
@@ -1109,8 +1109,9 @@ def print_failed_task(task_status):
|
||||
task_status["error_message"]))
|
||||
|
||||
|
||||
def error_info(worker=global_worker):
|
||||
def error_info():
|
||||
"""Return information about failed tasks."""
|
||||
worker = global_worker
|
||||
worker.check_connected()
|
||||
return (global_state.error_messages(job_id=worker.task_driver_id) +
|
||||
global_state.error_messages(job_id=DriverID.nil()))
|
||||
@@ -1251,11 +1252,9 @@ def init(redis_address=None,
|
||||
log_to_driver=True,
|
||||
node_ip_address=None,
|
||||
object_id_seed=None,
|
||||
num_workers=None,
|
||||
local_mode=False,
|
||||
driver_mode=None,
|
||||
redirect_worker_output=True,
|
||||
redirect_output=True,
|
||||
redirect_worker_output=None,
|
||||
redirect_output=None,
|
||||
ignore_reinit_error=False,
|
||||
num_redis_shards=None,
|
||||
redis_max_clients=None,
|
||||
@@ -1270,8 +1269,7 @@ def init(redis_address=None,
|
||||
plasma_store_socket_name=None,
|
||||
raylet_socket_name=None,
|
||||
temp_dir=None,
|
||||
_internal_config=None,
|
||||
use_raylet=None):
|
||||
_internal_config=None):
|
||||
"""Connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -1320,10 +1318,6 @@ def init(redis_address=None,
|
||||
manner. However, the same ID should not be used for different jobs.
|
||||
local_mode (bool): True if the code should be executed serially
|
||||
without Ray. This is useful for debugging.
|
||||
redirect_worker_output: True if the stdout and stderr of worker
|
||||
processes should be redirected to files.
|
||||
redirect_output (bool): True if stdout and stderr for non-worker
|
||||
processes should be redirected to files and false otherwise.
|
||||
ignore_reinit_error: True if we should suppress errors from calling
|
||||
ray.init() a second time.
|
||||
num_redis_shards: The number of Redis shards to start in addition to
|
||||
@@ -1364,18 +1358,6 @@ def init(redis_address=None,
|
||||
if configure_logging:
|
||||
setup_logger(logging_level, logging_format)
|
||||
|
||||
# Add the use_raylet option for backwards compatibility.
|
||||
if use_raylet is not None:
|
||||
if use_raylet:
|
||||
logger.warning("WARNING: The use_raylet argument has been "
|
||||
"deprecated. Please remove it.")
|
||||
else:
|
||||
raise DeprecationWarning("The use_raylet argument is deprecated. "
|
||||
"Please remove it.")
|
||||
|
||||
if driver_mode is not None:
|
||||
raise Exception("The 'driver_mode' argument has been deprecated. "
|
||||
"To run Ray in local mode, pass in local_mode=True.")
|
||||
if local_mode:
|
||||
driver_mode = LOCAL_MODE
|
||||
else:
|
||||
@@ -1424,7 +1406,6 @@ def init(redis_address=None,
|
||||
ray_params = ray.parameter.RayParams(
|
||||
redis_address=redis_address,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
object_id_seed=object_id_seed,
|
||||
local_mode=local_mode,
|
||||
driver_mode=driver_mode,
|
||||
@@ -1458,9 +1439,6 @@ def init(redis_address=None,
|
||||
address_info["raylet_socket_name"] = _global_node.raylet_socket_name
|
||||
else:
|
||||
# In this case, we are connecting to an existing cluster.
|
||||
if num_workers is not None:
|
||||
raise Exception("When connecting to an existing cluster, "
|
||||
"num_workers must not be provided.")
|
||||
if num_cpus is not None or num_gpus is not None:
|
||||
raise Exception("When connecting to an existing cluster, num_cpus "
|
||||
"and num_gpus must not be provided.")
|
||||
@@ -1548,13 +1526,7 @@ def init(redis_address=None,
|
||||
_post_init_hooks = []
|
||||
|
||||
|
||||
def cleanup(worker=global_worker):
|
||||
raise DeprecationWarning(
|
||||
"The function ray.worker.cleanup() has been deprecated. Instead, "
|
||||
"please call ray.shutdown().")
|
||||
|
||||
|
||||
def shutdown(worker=global_worker):
|
||||
def shutdown():
|
||||
"""Disconnect the worker, and terminate processes started by ray.init().
|
||||
|
||||
This will automatically run at the end when a Python process that uses Ray
|
||||
@@ -1567,7 +1539,7 @@ def shutdown(worker=global_worker):
|
||||
need to redefine them. If they were defined in an imported module, then you
|
||||
will need to reload the module.
|
||||
"""
|
||||
disconnect(worker)
|
||||
disconnect()
|
||||
|
||||
# Shut down the Ray processes.
|
||||
global _global_node
|
||||
@@ -1575,7 +1547,7 @@ def shutdown(worker=global_worker):
|
||||
_global_node.kill_all_processes(check_alive=False, allow_graceful=True)
|
||||
_global_node = None
|
||||
|
||||
worker.set_mode(None)
|
||||
global_worker.set_mode(None)
|
||||
|
||||
|
||||
atexit.register(shutdown)
|
||||
@@ -2037,12 +2009,13 @@ def connect(info,
|
||||
worker.cached_functions_to_run = None
|
||||
|
||||
|
||||
def disconnect(worker=global_worker):
|
||||
def disconnect():
|
||||
"""Disconnect this worker from the scheduler and object store."""
|
||||
# Reset the list of cached remote functions and actors so that if more
|
||||
# remote functions or actors are defined and then connect is called again,
|
||||
# the remote functions will be exported. This is mostly relevant for the
|
||||
# tests.
|
||||
worker = global_worker
|
||||
if worker.connected:
|
||||
# Shutdown all of the threads that we've started. TODO(rkn): This
|
||||
# should be handled cleanly in the worker object's destructor and not
|
||||
@@ -2129,8 +2102,7 @@ def register_custom_serializer(cls,
|
||||
deserializer=None,
|
||||
local=False,
|
||||
driver_id=None,
|
||||
class_id=None,
|
||||
worker=global_worker):
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
This method runs the register_class function defined below on every worker,
|
||||
@@ -2159,6 +2131,7 @@ def register_custom_serializer(cls,
|
||||
be efficiently serialized by Ray. This can also raise an exception
|
||||
if use_dict is true and cls is not pickleable.
|
||||
"""
|
||||
worker = global_worker
|
||||
assert (serializer is None) == (deserializer is None), (
|
||||
"The serializer/deserializer arguments must both be provided or "
|
||||
"both not be provided.")
|
||||
@@ -2225,7 +2198,7 @@ def register_custom_serializer(cls,
|
||||
register_class_for_serialization({"worker": worker})
|
||||
|
||||
|
||||
def get(object_ids, worker=global_worker):
|
||||
def get(object_ids):
|
||||
"""Get a remote object or a list of remote objects from the object store.
|
||||
|
||||
This method blocks until the object corresponding to the object ID is
|
||||
@@ -2245,8 +2218,9 @@ def get(object_ids, worker=global_worker):
|
||||
Exception: An exception is raised if the task that created the object
|
||||
or that created one of the objects raised an exception.
|
||||
"""
|
||||
worker = global_worker
|
||||
worker.check_connected()
|
||||
with profiling.profile("ray.get", worker=worker):
|
||||
with profiling.profile("ray.get"):
|
||||
if worker.mode == LOCAL_MODE:
|
||||
# In LOCAL_MODE, ray.get is the identity operation (the input will
|
||||
# actually be a value not an objectid).
|
||||
@@ -2270,7 +2244,7 @@ def get(object_ids, worker=global_worker):
|
||||
return value
|
||||
|
||||
|
||||
def put(value, worker=global_worker):
|
||||
def put(value):
|
||||
"""Store an object in the object store.
|
||||
|
||||
Args:
|
||||
@@ -2279,8 +2253,9 @@ def put(value, worker=global_worker):
|
||||
Returns:
|
||||
The object ID assigned to this value.
|
||||
"""
|
||||
worker = global_worker
|
||||
worker.check_connected()
|
||||
with profiling.profile("ray.put", worker=worker):
|
||||
with profiling.profile("ray.put"):
|
||||
if worker.mode == LOCAL_MODE:
|
||||
# In LOCAL_MODE, ray.put is the identity operation.
|
||||
return value
|
||||
@@ -2293,7 +2268,7 @@ def put(value, worker=global_worker):
|
||||
return object_id
|
||||
|
||||
|
||||
def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
def wait(object_ids, num_returns=1, timeout=None):
|
||||
"""Return a list of IDs that are ready and a list of IDs that are not.
|
||||
|
||||
.. warning::
|
||||
@@ -2327,6 +2302,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
A list of object IDs that are ready and a list of the remaining object
|
||||
IDs.
|
||||
"""
|
||||
worker = global_worker
|
||||
|
||||
if isinstance(object_ids, ObjectID):
|
||||
raise TypeError(
|
||||
@@ -2356,7 +2332,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
|
||||
worker.check_connected()
|
||||
# TODO(swang): Check main thread.
|
||||
with profiling.profile("ray.wait", worker=worker):
|
||||
with profiling.profile("ray.wait"):
|
||||
# When Ray is run in LOCAL_MODE, all functions are run immediately,
|
||||
# so all objects in object_id are ready.
|
||||
if worker.mode == LOCAL_MODE:
|
||||
|
||||
Reference in New Issue
Block a user