mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:54:27 +08:00
Lint Python files with Yapf (#1872)
This commit is contained in:
committed by
Robert Nishihara
parent
a3ddde398c
commit
74162d1492
+21
-19
@@ -12,8 +12,8 @@ if "pyarrow" in sys.modules:
|
||||
|
||||
# Add the directory containing pyarrow to the Python path so that we find the
|
||||
# pyarrow version packaged with ray and not a pre-existing pyarrow.
|
||||
pyarrow_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
|
||||
"pyarrow_files")
|
||||
pyarrow_path = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)), "pyarrow_files")
|
||||
sys.path.insert(0, pyarrow_path)
|
||||
|
||||
# See https://github.com/ray-project/ray/issues/131.
|
||||
@@ -27,29 +27,29 @@ If you are using Anaconda, try fixing this problem by running:
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
except ImportError as e:
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
|
||||
("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str)
|
||||
and ("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str)
|
||||
and ("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple)
|
||||
and len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
e.args = (e.args[0] + helpful_message, )
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
e.args = (e.args, )
|
||||
e.args += (helpful_message, )
|
||||
raise
|
||||
|
||||
from ray.local_scheduler import _config # noqa: E402
|
||||
from ray.worker import (error_info, init, connect, disconnect,
|
||||
get, put, wait, remote, log_event, log_span,
|
||||
flush_log, get_gpu_ids, get_webui_url,
|
||||
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
|
||||
remote, log_event, log_span, flush_log, get_gpu_ids,
|
||||
get_webui_url,
|
||||
register_custom_serializer) # noqa: E402
|
||||
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
|
||||
SILENT_MODE) # noqa: E402
|
||||
@@ -63,11 +63,13 @@ from ray.actor import method # noqa: E402
|
||||
# Fix this.
|
||||
__version__ = "0.4.0"
|
||||
|
||||
__all__ = ["error_info", "init", "connect", "disconnect", "get", "put", "wait",
|
||||
"remote", "log_event", "log_span", "flush_log", "actor", "method",
|
||||
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
|
||||
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
|
||||
"global_state", "_config", "__version__"]
|
||||
__all__ = [
|
||||
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
|
||||
"remote", "log_event", "log_span", "flush_log", "actor", "method",
|
||||
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
|
||||
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state",
|
||||
"_config", "__version__"
|
||||
]
|
||||
|
||||
import ctypes # noqa: E402
|
||||
# Windows only
|
||||
|
||||
+114
-95
@@ -121,16 +121,17 @@ def save_and_log_checkpoint(worker, actor):
|
||||
try:
|
||||
actor.__ray_checkpoint__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker.redis_client,
|
||||
"checkpoint",
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id.id(),
|
||||
data={"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint__.__name__})
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint__.__name__
|
||||
})
|
||||
|
||||
|
||||
def restore_and_log_checkpoint(worker, actor):
|
||||
@@ -144,8 +145,7 @@ def restore_and_log_checkpoint(worker, actor):
|
||||
try:
|
||||
checkpoint_resumed = actor.__ray_checkpoint_restore__()
|
||||
except Exception:
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
traceback.format_exc())
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
ray.utils.push_error_to_driver(
|
||||
worker.redis_client,
|
||||
@@ -154,8 +154,8 @@ def restore_and_log_checkpoint(worker, actor):
|
||||
driver_id=worker.task_driver_id.id(),
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name":
|
||||
actor.__ray_checkpoint_restore__.__name__})
|
||||
"function_name": actor.__ray_checkpoint_restore__.__name__
|
||||
})
|
||||
return checkpoint_resumed
|
||||
|
||||
|
||||
@@ -197,15 +197,15 @@ def make_actor_method_executor(worker, method_name, method, actor_imported):
|
||||
return
|
||||
|
||||
# Determine whether we should checkpoint the actor.
|
||||
checkpointing_on = (actor_imported and
|
||||
worker.actor_checkpoint_interval > 0)
|
||||
checkpointing_on = (actor_imported
|
||||
and worker.actor_checkpoint_interval > 0)
|
||||
# We should checkpoint the actor if user checkpointing is on, we've
|
||||
# executed checkpoint_interval tasks since the last checkpoint, and the
|
||||
# method we're about to execute is not a checkpoint.
|
||||
save_checkpoint = (checkpointing_on and
|
||||
(worker.actor_task_counter %
|
||||
worker.actor_checkpoint_interval == 0 and
|
||||
method_name != "__ray_checkpoint__"))
|
||||
save_checkpoint = (
|
||||
checkpointing_on and
|
||||
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
|
||||
and method_name != "__ray_checkpoint__"))
|
||||
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
@@ -238,14 +238,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = worker.actor_id
|
||||
(driver_id, class_id, class_name,
|
||||
module, pickled_class, checkpoint_interval,
|
||||
actor_method_names,
|
||||
(driver_id, class_id, class_name, module, pickled_class,
|
||||
checkpoint_interval, actor_method_names,
|
||||
actor_method_num_return_vals) = worker.redis_client.hmget(
|
||||
actor_class_key, ["driver_id", "class_id", "class_name", "module",
|
||||
"class", "checkpoint_interval",
|
||||
"actor_method_names",
|
||||
"actor_method_num_return_vals"])
|
||||
actor_class_key, [
|
||||
"driver_id", "class_id", "class_name", "module", "class",
|
||||
"checkpoint_interval", "actor_method_names",
|
||||
"actor_method_num_return_vals"
|
||||
])
|
||||
|
||||
actor_name = class_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
@@ -259,12 +259,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
# error messages and to prevent the driver from hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(actor_name))
|
||||
|
||||
# Register the actor method signatures.
|
||||
register_actor_signatures(worker, driver_id, class_id, class_name,
|
||||
actor_method_names, actor_method_num_return_vals)
|
||||
@@ -272,10 +274,11 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = compute_actor_method_function_id(class_name,
|
||||
actor_method_name).id()
|
||||
temporary_executor = make_actor_method_executor(worker,
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
temporary_executor = make_actor_method_executor(
|
||||
worker,
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
temporary_executor)
|
||||
worker.num_task_executions[driver_id][function_id] = 0
|
||||
@@ -288,9 +291,12 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.utils.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(worker.redis_client, "register_actor_signatures",
|
||||
traceback_str, driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
push_error_to_driver(
|
||||
worker.redis_client,
|
||||
"register_actor_signatures",
|
||||
traceback_str,
|
||||
driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
# TODO(rkn): In the future, it might make sense to have the worker exit
|
||||
# here. However, currently that would lead to hanging if someone calls
|
||||
# ray.get on a method invoked on the actor.
|
||||
@@ -298,16 +304,17 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
actor_methods = inspect.getmembers(
|
||||
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x) or
|
||||
is_cython(x))))
|
||||
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_id = compute_actor_method_function_id(
|
||||
class_name, actor_method_name).id()
|
||||
executor = make_actor_method_executor(worker, actor_method_name,
|
||||
actor_method,
|
||||
actor_imported=True)
|
||||
executor = make_actor_method_executor(
|
||||
worker, actor_method_name, actor_method, actor_imported=True)
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
executor)
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
@@ -315,7 +322,10 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
|
||||
# for the actor.
|
||||
|
||||
|
||||
def register_actor_signatures(worker, driver_id, class_id, class_name,
|
||||
def register_actor_signatures(worker,
|
||||
driver_id,
|
||||
class_id,
|
||||
class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources=None,
|
||||
@@ -346,18 +356,20 @@ def register_actor_signatures(worker, driver_id, class_id, class_name,
|
||||
# The extra return value is an actor dummy object.
|
||||
# In the cases where actor_method_cpus is None, that value should
|
||||
# never be used.
|
||||
FunctionProperties(num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": actor_method_cpus},
|
||||
max_calls=0))
|
||||
FunctionProperties(
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": actor_method_cpus},
|
||||
max_calls=0))
|
||||
|
||||
if actor_creation_resources is not None:
|
||||
# Also register the actor creation task.
|
||||
function_id = compute_actor_creation_function_id(class_id)
|
||||
worker.function_properties[driver_id][function_id.id()] = (
|
||||
# The extra return value is an actor dummy object.
|
||||
FunctionProperties(num_return_vals=0 + 1,
|
||||
resources=actor_creation_resources,
|
||||
max_calls=0))
|
||||
FunctionProperties(
|
||||
num_return_vals=0 + 1,
|
||||
resources=actor_creation_resources,
|
||||
max_calls=0))
|
||||
|
||||
|
||||
def publish_actor_class_to_key(key, actor_class_info, worker):
|
||||
@@ -380,8 +392,8 @@ def publish_actor_class_to_key(key, actor_class_info, worker):
|
||||
|
||||
|
||||
def export_actor_class(class_id, Class, actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
checkpoint_interval, worker):
|
||||
actor_method_num_return_vals, checkpoint_interval,
|
||||
worker):
|
||||
key = b"ActorClass:" + class_id
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
@@ -389,8 +401,9 @@ def export_actor_class(class_id, Class, actor_method_names,
|
||||
"class": pickle.dumps(Class),
|
||||
"checkpoint_interval": checkpoint_interval,
|
||||
"actor_method_names": json.dumps(list(actor_method_names)),
|
||||
"actor_method_num_return_vals": json.dumps(
|
||||
actor_method_num_return_vals)}
|
||||
"actor_method_num_return_vals":
|
||||
json.dumps(actor_method_num_return_vals)
|
||||
}
|
||||
|
||||
if worker.mode is None:
|
||||
# This means that 'ray.init()' has not been called yet and so we must
|
||||
@@ -433,7 +446,11 @@ def export_actor(actor_id, class_id, class_name, actor_method_names,
|
||||
|
||||
driver_id = worker.task_driver_id.id()
|
||||
register_actor_signatures(
|
||||
worker, driver_id, class_id, class_name, actor_method_names,
|
||||
worker,
|
||||
driver_id,
|
||||
class_id,
|
||||
class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources=actor_creation_resources,
|
||||
actor_method_cpus=actor_method_cpus)
|
||||
@@ -466,12 +483,14 @@ class ActorMethod(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actor methods cannot be called directly. Instead "
|
||||
"of running 'object.{}()', try "
|
||||
"'object.{}.remote()'."
|
||||
.format(self._method_name, self._method_name))
|
||||
"'object.{}.remote()'.".format(self._method_name,
|
||||
self._method_name))
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self._actor._actor_method_call(
|
||||
self._method_name, args=args, kwargs=kwargs,
|
||||
self._method_name,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
dependency=self._actor._ray_actor_cursor)
|
||||
|
||||
|
||||
@@ -481,12 +500,13 @@ class ActorHandleWrapper(object):
|
||||
This is essentially just a dictionary, but it is used so that the recipient
|
||||
can tell that an argument is an ActorHandle.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_id, class_id, actor_handle_id, actor_cursor,
|
||||
actor_counter, actor_method_names,
|
||||
actor_method_num_return_vals, method_signatures,
|
||||
checkpoint_interval, class_name,
|
||||
actor_creation_dummy_object_id,
|
||||
actor_creation_resources, actor_method_cpus):
|
||||
actor_creation_dummy_object_id, actor_creation_resources,
|
||||
actor_method_cpus):
|
||||
# TODO(rkn): Some of these fields are probably not necessary. We should
|
||||
# strip out the unnecessary fields to keep actor handles lightweight.
|
||||
self.actor_id = actor_id
|
||||
@@ -545,27 +565,20 @@ def unwrap_actor_handle(worker, wrapper):
|
||||
The unwrapped ActorHandle instance.
|
||||
"""
|
||||
driver_id = worker.task_driver_id.id()
|
||||
register_actor_signatures(worker, driver_id, wrapper.class_id,
|
||||
wrapper.class_name, wrapper.actor_method_names,
|
||||
wrapper.actor_method_num_return_vals,
|
||||
wrapper.actor_creation_resources,
|
||||
wrapper.actor_method_cpus)
|
||||
register_actor_signatures(
|
||||
worker, driver_id, wrapper.class_id, wrapper.class_name,
|
||||
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
|
||||
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
|
||||
|
||||
actor_handle_class = make_actor_handle_class(wrapper.class_name)
|
||||
actor_object = actor_handle_class.__new__(actor_handle_class)
|
||||
actor_object._manual_init(
|
||||
wrapper.actor_id,
|
||||
wrapper.class_id,
|
||||
wrapper.actor_handle_id,
|
||||
wrapper.actor_cursor,
|
||||
wrapper.actor_counter,
|
||||
wrapper.actor_method_names,
|
||||
wrapper.actor_method_num_return_vals,
|
||||
wrapper.method_signatures,
|
||||
wrapper.checkpoint_interval,
|
||||
wrapper.actor_id, wrapper.class_id, wrapper.actor_handle_id,
|
||||
wrapper.actor_cursor, wrapper.actor_counter,
|
||||
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
|
||||
wrapper.method_signatures, wrapper.checkpoint_interval,
|
||||
wrapper.actor_creation_dummy_object_id,
|
||||
wrapper.actor_creation_resources,
|
||||
wrapper.actor_method_cpus)
|
||||
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
|
||||
return actor_object
|
||||
|
||||
|
||||
@@ -612,7 +625,10 @@ def make_actor_handle_class(class_name):
|
||||
self._ray_actor_creation_resources = actor_creation_resources
|
||||
self._ray_actor_method_cpus = actor_method_cpus
|
||||
|
||||
def _actor_method_call(self, method_name, args=None, kwargs=None,
|
||||
def _actor_method_call(self,
|
||||
method_name,
|
||||
args=None,
|
||||
kwargs=None,
|
||||
dependency=None):
|
||||
"""Method execution stub for an actor handle.
|
||||
|
||||
@@ -663,7 +679,9 @@ def make_actor_handle_class(class_name):
|
||||
function_id = compute_actor_method_function_id(
|
||||
self._ray_class_name, method_name)
|
||||
object_ids = ray.worker.global_worker.submit_task(
|
||||
function_id, args, actor_id=self._ray_actor_id,
|
||||
function_id,
|
||||
args,
|
||||
actor_id=self._ray_actor_id,
|
||||
actor_handle_id=self._ray_actor_handle_id,
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method,
|
||||
@@ -722,8 +740,8 @@ def make_actor_handle_class(class_name):
|
||||
self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID):
|
||||
# TODO(rkn): Should we be passing in the actor cursor as a
|
||||
# dependency here?
|
||||
self._actor_method_call("__ray_terminate__",
|
||||
args=[self._ray_actor_id.id()])
|
||||
self._actor_method_call(
|
||||
"__ray_terminate__", args=[self._ray_actor_id.id()])
|
||||
|
||||
return ActorHandle
|
||||
|
||||
@@ -735,7 +753,6 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
||||
exported = []
|
||||
|
||||
class ActorHandle(actor_handle_class):
|
||||
|
||||
@classmethod
|
||||
def remote(cls, *args, **kwargs):
|
||||
if ray.worker.global_worker.mode is None:
|
||||
@@ -754,11 +771,13 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
||||
actor_cursor = None
|
||||
# The number of actor method invocations that we've called so far.
|
||||
actor_counter = 0
|
||||
|
||||
# Get the actor methods of the given class.
|
||||
actor_methods = inspect.getmembers(
|
||||
Class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x) or
|
||||
is_cython(x))))
|
||||
def pred(x):
|
||||
return (inspect.isfunction(x) or inspect.ismethod(x)
|
||||
or is_cython(x))
|
||||
|
||||
actor_methods = inspect.getmembers(Class, predicate=pred)
|
||||
# Extract the signatures of each of the methods. This will be used
|
||||
# to catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
@@ -773,8 +792,9 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
||||
method_signatures[k] = signature.extract_signature(
|
||||
v, ignore_first=True)
|
||||
|
||||
actor_method_names = [method_name for method_name, _ in
|
||||
actor_methods]
|
||||
actor_method_names = [
|
||||
method_name for method_name, _ in actor_methods
|
||||
]
|
||||
actor_method_num_return_vals = []
|
||||
for _, method in actor_methods:
|
||||
if hasattr(method, "__ray_num_return_vals__"):
|
||||
@@ -796,30 +816,29 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
|
||||
checkpoint_interval,
|
||||
ray.worker.global_worker)
|
||||
exported.append(0)
|
||||
actor_cursor = export_actor(actor_id, class_id, class_name,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
actor_creation_resources,
|
||||
actor_method_cpus,
|
||||
ray.worker.global_worker)
|
||||
actor_cursor = export_actor(
|
||||
actor_id, class_id, class_name, actor_method_names,
|
||||
actor_method_num_return_vals, actor_creation_resources,
|
||||
actor_method_cpus, ray.worker.global_worker)
|
||||
# Increment the actor counter to account for the creation task.
|
||||
actor_counter += 1
|
||||
|
||||
# Instantiate the actor handle.
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object._manual_init(actor_id, class_id, actor_handle_id,
|
||||
actor_cursor, actor_counter,
|
||||
actor_method_names,
|
||||
actor_method_num_return_vals,
|
||||
method_signatures, checkpoint_interval,
|
||||
actor_cursor, actor_creation_resources,
|
||||
actor_method_cpus)
|
||||
actor_object._manual_init(
|
||||
actor_id, class_id, actor_handle_id, actor_cursor,
|
||||
actor_counter, actor_method_names,
|
||||
actor_method_num_return_vals, method_signatures,
|
||||
checkpoint_interval, actor_cursor, actor_creation_resources,
|
||||
actor_method_cpus)
|
||||
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in actor_object._ray_actor_method_names:
|
||||
actor_object._actor_method_call("__init__", args=args,
|
||||
kwargs=kwargs,
|
||||
dependency=actor_cursor)
|
||||
actor_object._actor_method_call(
|
||||
"__init__",
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
dependency=actor_cursor)
|
||||
else:
|
||||
print("WARNING: this object has no __init__ method.")
|
||||
|
||||
|
||||
@@ -51,25 +51,32 @@ CLUSTER_CONFIG_SCHEMA = {
|
||||
"idle_timeout_minutes": (int, OPTIONAL),
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
"provider": ({
|
||||
"type": (str, REQUIRED), # e.g. aws
|
||||
"region": (str, OPTIONAL), # e.g. us-east-1
|
||||
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
|
||||
"module": (str, OPTIONAL), # module, if using external node provider
|
||||
}, REQUIRED),
|
||||
"provider": (
|
||||
{
|
||||
"type": (str, REQUIRED), # e.g. aws
|
||||
"region": (str, OPTIONAL), # e.g. us-east-1
|
||||
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
|
||||
"module": (str,
|
||||
OPTIONAL), # module, if using external node provider
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
"auth": ({
|
||||
"ssh_user": (str, REQUIRED), # e.g. ubuntu
|
||||
"ssh_private_key": (str, OPTIONAL),
|
||||
}, REQUIRED),
|
||||
"auth": (
|
||||
{
|
||||
"ssh_user": (str, REQUIRED), # e.g. ubuntu
|
||||
"ssh_private_key": (str, OPTIONAL),
|
||||
},
|
||||
REQUIRED),
|
||||
|
||||
# Docker configuration. If this is specified, all setup and start commands
|
||||
# will be executed in the container.
|
||||
"docker": ({
|
||||
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
|
||||
"container_name": (str, OPTIONAL), # e.g., ray_docker
|
||||
}, OPTIONAL),
|
||||
"docker": (
|
||||
{
|
||||
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
|
||||
"container_name": (str, OPTIONAL), # e.g., ray_docker
|
||||
},
|
||||
OPTIONAL),
|
||||
|
||||
# Provider-specific config for the head node, e.g. instance type.
|
||||
"head_node": (dict, OPTIONAL),
|
||||
@@ -137,9 +144,9 @@ class LoadMetrics(object):
|
||||
for unwanted_key in unwanted:
|
||||
del mapping[unwanted_key]
|
||||
if unwanted:
|
||||
print(
|
||||
"Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
print("Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
|
||||
prune(self.last_used_time_by_ip)
|
||||
prune(self.static_resources_by_ip)
|
||||
prune(self.dynamic_resources_by_ip)
|
||||
@@ -148,10 +155,8 @@ class LoadMetrics(object):
|
||||
return self._info()["NumNodesUsed"]
|
||||
|
||||
def debug_string(self):
|
||||
return " - {}".format(
|
||||
"\n - ".join(
|
||||
["{}: {}".format(k, v)
|
||||
for k, v in sorted(self._info().items())]))
|
||||
return " - {}".format("\n - ".join(
|
||||
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
|
||||
|
||||
def _info(self):
|
||||
nodes_used = 0.0
|
||||
@@ -176,14 +181,19 @@ class LoadMetrics(object):
|
||||
nodes_used += max_frac
|
||||
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
|
||||
return {
|
||||
"ResourceUsage": ", ".join([
|
||||
"ResourceUsage":
|
||||
", ".join([
|
||||
"{}/{} {}".format(
|
||||
round(resources_used[rid], 2),
|
||||
round(resources_total[rid], 2), rid)
|
||||
for rid in sorted(resources_used)]),
|
||||
"NumNodesConnected": len(self.static_resources_by_ip),
|
||||
"NumNodesUsed": round(nodes_used, 2),
|
||||
"NodeIdleSeconds": "Min={} Mean={} Max={}".format(
|
||||
for rid in sorted(resources_used)
|
||||
]),
|
||||
"NumNodesConnected":
|
||||
len(self.static_resources_by_ip),
|
||||
"NumNodesUsed":
|
||||
round(nodes_used, 2),
|
||||
"NodeIdleSeconds":
|
||||
"Min={} Mean={} Max={}".format(
|
||||
int(np.min(idle_times)) if idle_times else -1,
|
||||
int(np.mean(idle_times)) if idle_times else -1,
|
||||
int(np.max(idle_times)) if idle_times else -1),
|
||||
@@ -208,18 +218,20 @@ class StandardAutoscaler(object):
|
||||
until the target cluster size is met).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config_path, load_metrics,
|
||||
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
|
||||
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
|
||||
process_runner=subprocess, verbose_updates=False,
|
||||
node_updater_cls=NodeUpdaterProcess,
|
||||
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
|
||||
def __init__(self,
|
||||
config_path,
|
||||
load_metrics,
|
||||
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
|
||||
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
|
||||
process_runner=subprocess,
|
||||
verbose_updates=False,
|
||||
node_updater_cls=NodeUpdaterProcess,
|
||||
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
|
||||
self.config_path = config_path
|
||||
self.reload_config(errors_fatal=True)
|
||||
self.load_metrics = load_metrics
|
||||
self.provider = get_node_provider(
|
||||
self.config["provider"], self.config["cluster_name"])
|
||||
self.provider = get_node_provider(self.config["provider"],
|
||||
self.config["cluster_name"])
|
||||
|
||||
self.max_failures = max_failures
|
||||
self.max_concurrent_launches = max_concurrent_launches
|
||||
@@ -245,9 +257,8 @@ class StandardAutoscaler(object):
|
||||
self.reload_config(errors_fatal=False)
|
||||
self._update()
|
||||
except Exception as e:
|
||||
print(
|
||||
"StandardAutoscaler: Error during autoscaling: {}",
|
||||
traceback.format_exc())
|
||||
print("StandardAutoscaler: Error during autoscaling: {}",
|
||||
traceback.format_exc())
|
||||
self.num_failures += 1
|
||||
if self.num_failures > self.max_failures:
|
||||
print("*** StandardAutoscaler: Too many errors, abort. ***")
|
||||
@@ -274,15 +285,13 @@ class StandardAutoscaler(object):
|
||||
if node_ip in last_used and last_used[node_ip] < horizon and \
|
||||
len(nodes) - num_terminated > self.config["min_workers"]:
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating idle node: "
|
||||
"{}".format(node_id))
|
||||
print("StandardAutoscaler: Terminating idle node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
elif not self.launch_config_ok(node_id):
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating outdated node: "
|
||||
"{}".format(node_id))
|
||||
print("StandardAutoscaler: Terminating outdated node: "
|
||||
"{}".format(node_id))
|
||||
self.provider.terminate_node(node_id)
|
||||
if num_terminated > 0:
|
||||
nodes = self.workers()
|
||||
@@ -292,9 +301,8 @@ class StandardAutoscaler(object):
|
||||
num_terminated = 0
|
||||
while len(nodes) > self.config["max_workers"]:
|
||||
num_terminated += 1
|
||||
print(
|
||||
"StandardAutoscaler: Terminating unneeded node: "
|
||||
"{}".format(nodes[-1]))
|
||||
print("StandardAutoscaler: Terminating unneeded node: "
|
||||
"{}".format(nodes[-1]))
|
||||
self.provider.terminate_node(nodes[-1])
|
||||
nodes = nodes[:-1]
|
||||
if num_terminated > 0:
|
||||
@@ -339,13 +347,13 @@ class StandardAutoscaler(object):
|
||||
with open(self.config_path) as f:
|
||||
new_config = yaml.load(f.read())
|
||||
validate_config(new_config)
|
||||
new_launch_hash = hash_launch_conf(
|
||||
new_config["worker_nodes"], new_config["auth"])
|
||||
new_runtime_hash = hash_runtime_conf(
|
||||
new_config["file_mounts"],
|
||||
[new_config["setup_commands"],
|
||||
new_config["worker_setup_commands"],
|
||||
new_config["worker_start_ray_commands"]])
|
||||
new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
|
||||
new_config["auth"])
|
||||
new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [
|
||||
new_config["setup_commands"],
|
||||
new_config["worker_setup_commands"],
|
||||
new_config["worker_start_ray_commands"]
|
||||
])
|
||||
self.config = new_config
|
||||
self.launch_hash = new_launch_hash
|
||||
self.runtime_hash = new_runtime_hash
|
||||
@@ -353,17 +361,15 @@ class StandardAutoscaler(object):
|
||||
if errors_fatal:
|
||||
raise e
|
||||
else:
|
||||
print(
|
||||
"StandardAutoscaler: Error parsing config: {}",
|
||||
traceback.format_exc())
|
||||
print("StandardAutoscaler: Error parsing config: {}",
|
||||
traceback.format_exc())
|
||||
|
||||
def target_num_workers(self):
|
||||
target_frac = self.config["target_utilization_fraction"]
|
||||
cur_used = self.load_metrics.approx_workers_used()
|
||||
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))
|
||||
return min(
|
||||
self.config["max_workers"],
|
||||
max(self.config["min_workers"], ideal_num_workers))
|
||||
return min(self.config["max_workers"],
|
||||
max(self.config["min_workers"], ideal_num_workers))
|
||||
|
||||
def launch_config_ok(self, node_id):
|
||||
launch_conf = self.provider.node_tags(node_id).get(
|
||||
@@ -393,8 +399,7 @@ class StandardAutoscaler(object):
|
||||
node_id,
|
||||
self.config["provider"],
|
||||
self.config["auth"],
|
||||
self.config["cluster_name"],
|
||||
{},
|
||||
self.config["cluster_name"], {},
|
||||
with_head_node_ip(self.config["worker_start_ray_commands"]),
|
||||
self.runtime_hash,
|
||||
redirect_output=not self.verbose_updates,
|
||||
@@ -409,14 +414,12 @@ class StandardAutoscaler(object):
|
||||
return
|
||||
if self.config.get("no_restart", False) and \
|
||||
self.num_successful_updates.get(node_id, 0) > 0:
|
||||
init_commands = (
|
||||
self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"])
|
||||
init_commands = (self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"])
|
||||
else:
|
||||
init_commands = (
|
||||
self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"] +
|
||||
self.config["worker_start_ray_commands"])
|
||||
init_commands = (self.config["setup_commands"] +
|
||||
self.config["worker_setup_commands"] +
|
||||
self.config["worker_start_ray_commands"])
|
||||
updater = self.node_updater_cls(
|
||||
node_id,
|
||||
self.config["provider"],
|
||||
@@ -445,14 +448,12 @@ class StandardAutoscaler(object):
|
||||
print("StandardAutoscaler: Launching {} new nodes".format(count))
|
||||
num_before = len(self.workers())
|
||||
self.provider.create_node(
|
||||
self.config["worker_nodes"],
|
||||
{
|
||||
self.config["worker_nodes"], {
|
||||
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
|
||||
TAG_RAY_NODE_TYPE: "Worker",
|
||||
TAG_RAY_NODE_STATUS: "Uninitialized",
|
||||
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
|
||||
},
|
||||
count)
|
||||
}, count)
|
||||
# TODO(ekl) be less conservative in this check
|
||||
assert len(self.workers()) > num_before, \
|
||||
"Num nodes failed to increase after creating a new node"
|
||||
@@ -472,8 +473,8 @@ class StandardAutoscaler(object):
|
||||
suffix += " ({} failed to update)".format(
|
||||
len(self.num_failed_updates))
|
||||
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
|
||||
datetime.now(), len(nodes), self.target_num_workers(),
|
||||
suffix, self.load_metrics.debug_string())
|
||||
datetime.now(), len(nodes), self.target_num_workers(), suffix,
|
||||
self.load_metrics.debug_string())
|
||||
|
||||
|
||||
def typename(v):
|
||||
@@ -507,9 +508,8 @@ def check_extraneous(config, schema):
|
||||
raise ValueError("Config {} is not a dictionary".format(config))
|
||||
for k in config:
|
||||
if k not in schema:
|
||||
raise ValueError(
|
||||
"Unexpected config key `{}` not in {}".format(
|
||||
k, list(schema.keys())))
|
||||
raise ValueError("Unexpected config key `{}` not in {}".format(
|
||||
k, list(schema.keys())))
|
||||
v, kreq = schema[k]
|
||||
if v is None:
|
||||
continue
|
||||
@@ -517,7 +517,8 @@ def check_extraneous(config, schema):
|
||||
if not isinstance(config[k], v):
|
||||
raise ValueError(
|
||||
"Config key `{}` has wrong type {}, expected {}".format(
|
||||
k, type(config[k]).__name__, v.__name__))
|
||||
k,
|
||||
type(config[k]).__name__, v.__name__))
|
||||
else:
|
||||
check_extraneous(config[k], v)
|
||||
|
||||
|
||||
@@ -25,12 +25,10 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
|
||||
def key_pair(i, region):
|
||||
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
|
||||
if i == 0:
|
||||
return (
|
||||
"{}_{}".format(RAY, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
|
||||
return (
|
||||
"{}_{}_{}".format(RAY, i, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
|
||||
return ("{}_{}".format(RAY, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
|
||||
return ("{}_{}_{}".format(RAY, i, region),
|
||||
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
|
||||
|
||||
|
||||
# Suppress excessive connection dropped logs from boto
|
||||
@@ -83,7 +81,9 @@ def _configure_iam_role(config):
|
||||
"Statement": [
|
||||
{
|
||||
"Effect": "Allow",
|
||||
"Principal": {"Service": "ec2.amazonaws.com"},
|
||||
"Principal": {
|
||||
"Service": "ec2.amazonaws.com"
|
||||
},
|
||||
"Action": "sts:AssumeRole",
|
||||
},
|
||||
],
|
||||
@@ -97,8 +97,7 @@ def _configure_iam_role(config):
|
||||
profile.add_role(RoleName=role.name)
|
||||
time.sleep(15) # wait for propagation
|
||||
|
||||
print("Role not specified for head node, using {}".format(
|
||||
profile.arn))
|
||||
print("Role not specified for head node, using {}".format(profile.arn))
|
||||
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
|
||||
|
||||
return config
|
||||
@@ -146,8 +145,10 @@ def _configure_key_pair(config):
|
||||
def _configure_subnet(config):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnets = sorted(
|
||||
[s for s in ec2.subnets.all()
|
||||
if s.state == "available" and s.map_public_ip_on_launch],
|
||||
[
|
||||
s for s in ec2.subnets.all()
|
||||
if s.state == "available" and s.map_public_ip_on_launch
|
||||
],
|
||||
reverse=True, # sort from Z-A
|
||||
key=lambda subnet: subnet.availability_zone)
|
||||
if not subnets:
|
||||
@@ -157,9 +158,9 @@ def _configure_subnet(config):
|
||||
"and trying this again. Note that the subnet must map public IPs "
|
||||
"on instance launch.")
|
||||
if "availability_zone" in config["provider"]:
|
||||
default_subnet = next((s for s in subnets
|
||||
if s.availability_zone ==
|
||||
config["provider"]["availability_zone"]),
|
||||
default_subnet = next((
|
||||
s for s in subnets
|
||||
if s.availability_zone == config["provider"]["availability_zone"]),
|
||||
None)
|
||||
if not default_subnet:
|
||||
raise Exception(
|
||||
@@ -209,11 +210,21 @@ def _configure_security_group(config):
|
||||
|
||||
if not security_group.ip_permissions:
|
||||
security_group.authorize_ingress(
|
||||
IpPermissions=[
|
||||
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
|
||||
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
|
||||
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
|
||||
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
|
||||
IpPermissions=[{
|
||||
"FromPort": -1,
|
||||
"ToPort": -1,
|
||||
"IpProtocol": "-1",
|
||||
"UserIdGroupPairs": [{
|
||||
"GroupId": security_group.id
|
||||
}]
|
||||
}, {
|
||||
"FromPort": 22,
|
||||
"ToPort": 22,
|
||||
"IpProtocol": "TCP",
|
||||
"IpRanges": [{
|
||||
"CidrIp": "0.0.0.0/0"
|
||||
}]
|
||||
}])
|
||||
|
||||
if "SecurityGroupIds" not in config["head_node"]:
|
||||
print("SecurityGroupIds not specified for head node, using {}".format(
|
||||
@@ -231,8 +242,10 @@ def _configure_security_group(config):
|
||||
def _get_subnet_or_die(config, subnet_id):
|
||||
ec2 = _resource("ec2", config)
|
||||
subnet = list(
|
||||
ec2.subnets.filter(Filters=[
|
||||
{"Name": "subnet-id", "Values": [subnet_id]}]))
|
||||
ec2.subnets.filter(Filters=[{
|
||||
"Name": "subnet-id",
|
||||
"Values": [subnet_id]
|
||||
}]))
|
||||
assert len(subnet) == 1, "Subnet not found"
|
||||
subnet = subnet[0]
|
||||
return subnet
|
||||
@@ -241,8 +254,10 @@ def _get_subnet_or_die(config, subnet_id):
|
||||
def _get_security_group(config, vpc_id, group_name):
|
||||
ec2 = _resource("ec2", config)
|
||||
existing_groups = list(
|
||||
ec2.security_groups.filter(Filters=[
|
||||
{"Name": "vpc-id", "Values": [vpc_id]}]))
|
||||
ec2.security_groups.filter(Filters=[{
|
||||
"Name": "vpc-id",
|
||||
"Values": [vpc_id]
|
||||
}]))
|
||||
for sg in existing_groups:
|
||||
if sg.group_name == group_name:
|
||||
return sg
|
||||
@@ -270,8 +285,10 @@ def _get_instance_profile(profile_name, config):
|
||||
|
||||
def _get_key(key_name, config):
|
||||
ec2 = _resource("ec2", config)
|
||||
for key in ec2.key_pairs.filter(
|
||||
Filters=[{"Name": "key-name", "Values": [key_name]}]):
|
||||
for key in ec2.key_pairs.filter(Filters=[{
|
||||
"Name": "key-name",
|
||||
"Values": [key_name]
|
||||
}]):
|
||||
if key.name == key_name:
|
||||
return key
|
||||
|
||||
|
||||
@@ -84,7 +84,8 @@ class AWSNodeProvider(NodeProvider):
|
||||
tag_pairs = []
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append({
|
||||
"Key": k, "Value": v,
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
node.create_tags(Tags=tag_pairs)
|
||||
|
||||
@@ -95,20 +96,20 @@ class AWSNodeProvider(NodeProvider):
|
||||
"Value": self.cluster_name,
|
||||
}]
|
||||
for k, v in tags.items():
|
||||
tag_pairs.append(
|
||||
{
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
tag_pairs.append({
|
||||
"Key": k,
|
||||
"Value": v,
|
||||
})
|
||||
conf.update({
|
||||
"MinCount": 1,
|
||||
"MaxCount": count,
|
||||
"TagSpecifications": conf.get("TagSpecifications", []) + [
|
||||
{
|
||||
"ResourceType": "instance",
|
||||
"Tags": tag_pairs,
|
||||
}
|
||||
]
|
||||
"MinCount":
|
||||
1,
|
||||
"MaxCount":
|
||||
count,
|
||||
"TagSpecifications":
|
||||
conf.get("TagSpecifications", []) + [{
|
||||
"ResourceType": "instance",
|
||||
"Tags": tag_pairs,
|
||||
}]
|
||||
})
|
||||
self.ec2.create_instances(**conf)
|
||||
|
||||
|
||||
@@ -23,9 +23,8 @@ from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
|
||||
from ray.autoscaler.updater import NodeUpdaterProcess
|
||||
|
||||
|
||||
def create_or_update_cluster(
|
||||
config_file, override_min_workers, override_max_workers,
|
||||
no_restart, yes):
|
||||
def create_or_update_cluster(config_file, override_min_workers,
|
||||
override_max_workers, no_restart, yes):
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
|
||||
config = yaml.load(open(config_file).read())
|
||||
@@ -39,8 +38,8 @@ def create_or_update_cluster(
|
||||
|
||||
importer = NODE_PROVIDERS.get(config["provider"]["type"])
|
||||
if not importer:
|
||||
raise NotImplementedError(
|
||||
"Unsupported provider {}".format(config["provider"]))
|
||||
raise NotImplementedError("Unsupported provider {}".format(
|
||||
config["provider"]))
|
||||
|
||||
bootstrap_config, _ = importer()
|
||||
config = bootstrap_config(config)
|
||||
@@ -129,8 +128,10 @@ def get_or_create_head_node(config, no_restart, yes):
|
||||
remote_config_file.write(json.dumps(remote_config))
|
||||
remote_config_file.flush()
|
||||
config["file_mounts"].update({
|
||||
remote_key_path: config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml": remote_config_file.name
|
||||
remote_key_path:
|
||||
config["auth"]["ssh_private_key"],
|
||||
"~/ray_bootstrap_config.yaml":
|
||||
remote_config_file.name
|
||||
})
|
||||
|
||||
if no_restart:
|
||||
@@ -160,30 +161,24 @@ def get_or_create_head_node(config, no_restart, yes):
|
||||
print("Error: updating {} failed".format(
|
||||
provider.external_ip(head_node)))
|
||||
sys.exit(1)
|
||||
print(
|
||||
"Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
print("Head node up-to-date, IP address is: {}".format(
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
monitor_str = "tail -f /tmp/raylogs/monitor-*"
|
||||
for s in init_commands:
|
||||
if ("ray start" in s and "docker exec" in s and
|
||||
"--autoscaling-config" in s):
|
||||
if ("ray start" in s and "docker exec" in s
|
||||
and "--autoscaling-config" in s):
|
||||
monitor_str = "docker exec {} /bin/sh -c {}".format(
|
||||
config["docker"]["container_name"],
|
||||
quote(monitor_str))
|
||||
print(
|
||||
"To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ssh -i {} {}@{} {}\n".format(
|
||||
config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node),
|
||||
quote(monitor_str)))
|
||||
print(
|
||||
"To login to the cluster, run:\n\n"
|
||||
" ssh -i {} {}@{}\n".format(
|
||||
config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node)))
|
||||
config["docker"]["container_name"], quote(monitor_str))
|
||||
print("To monitor auto-scaling activity, you can run:\n\n"
|
||||
" ssh -i {} {}@{} {}\n".format(config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node),
|
||||
quote(monitor_str)))
|
||||
print("To login to the cluster, run:\n\n"
|
||||
" ssh -i {} {}@{}\n".format(config["auth"]["ssh_private_key"],
|
||||
config["auth"]["ssh_user"],
|
||||
provider.external_ip(head_node)))
|
||||
|
||||
|
||||
def get_head_node_ip(config_file):
|
||||
|
||||
@@ -22,24 +22,21 @@ def dockerize_if_needed(config):
|
||||
assert cname, "Must provide container name!"
|
||||
docker_mounts = {dst: dst for dst in config["file_mounts"]}
|
||||
config["setup_commands"] = (
|
||||
docker_install_cmds() +
|
||||
docker_start_cmds(
|
||||
config["auth"]["ssh_user"], docker_image,
|
||||
docker_mounts, cname) +
|
||||
with_docker_exec(
|
||||
config["setup_commands"], container_name=cname))
|
||||
docker_install_cmds() + docker_start_cmds(
|
||||
config["auth"]["ssh_user"], docker_image, docker_mounts, cname) +
|
||||
with_docker_exec(config["setup_commands"], container_name=cname))
|
||||
|
||||
config["head_setup_commands"] = with_docker_exec(
|
||||
config["head_setup_commands"], container_name=cname)
|
||||
config["head_start_ray_commands"] = (
|
||||
docker_autoscaler_setup(cname) +
|
||||
with_docker_exec(
|
||||
docker_autoscaler_setup(cname) + with_docker_exec(
|
||||
config["head_start_ray_commands"], container_name=cname))
|
||||
|
||||
config["worker_setup_commands"] = with_docker_exec(
|
||||
config["worker_setup_commands"], container_name=cname)
|
||||
config["worker_start_ray_commands"] = with_docker_exec(
|
||||
config["worker_start_ray_commands"], container_name=cname,
|
||||
config["worker_start_ray_commands"],
|
||||
container_name=cname,
|
||||
env_vars=["RAY_HEAD_IP"])
|
||||
|
||||
return config
|
||||
@@ -50,21 +47,24 @@ def with_docker_exec(cmds, container_name, env_vars=None):
|
||||
if env_vars:
|
||||
env_str = " ".join(
|
||||
["-e {env}=${env}".format(env=env) for env in env_vars])
|
||||
return ["docker exec {} {} /bin/sh -c {} ".format(
|
||||
env_str, container_name, quote(cmd)) for cmd in cmds]
|
||||
return [
|
||||
"docker exec {} {} /bin/sh -c {} ".format(env_str, container_name,
|
||||
quote(cmd)) for cmd in cmds
|
||||
]
|
||||
|
||||
|
||||
def docker_install_cmds():
|
||||
return [aptwait_cmd() + " && sudo apt-get update",
|
||||
aptwait_cmd() + " && sudo apt-get install -y docker.io"]
|
||||
return [
|
||||
aptwait_cmd() + " && sudo apt-get update",
|
||||
aptwait_cmd() + " && sudo apt-get install -y docker.io"
|
||||
]
|
||||
|
||||
|
||||
def aptwait_cmd():
|
||||
return (
|
||||
"while sudo fuser"
|
||||
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
|
||||
" >/dev/null 2>&1; "
|
||||
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
|
||||
return ("while sudo fuser"
|
||||
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
|
||||
" >/dev/null 2>&1; "
|
||||
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
|
||||
|
||||
|
||||
def docker_start_cmds(user, image, mount, cname):
|
||||
@@ -77,10 +77,12 @@ def docker_start_cmds(user, image, mount, cname):
|
||||
|
||||
# create flags
|
||||
# ports for the redis, object manager, and tune client
|
||||
port_flags = " ".join(["-p {port}:{port}".format(port=port)
|
||||
for port in ["6379", "8076", "4321"]])
|
||||
mount_flags = " ".join(["-v {src}:{dest}".format(src=k, dest=v)
|
||||
for k, v in mount.items()])
|
||||
port_flags = " ".join([
|
||||
"-p {port}:{port}".format(port=port)
|
||||
for port in ["6379", "8076", "4321"]
|
||||
])
|
||||
mount_flags = " ".join(
|
||||
["-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items()])
|
||||
|
||||
# for click, used in ray cli
|
||||
env_vars = {"LC_ALL": "C.UTF-8", "LANG": "C.UTF-8"}
|
||||
@@ -88,9 +90,10 @@ def docker_start_cmds(user, image, mount, cname):
|
||||
["-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items()])
|
||||
|
||||
# docker run command
|
||||
docker_run = ["docker", "run", "--rm", "--name {}".format(cname),
|
||||
"-d", "-it", port_flags, mount_flags, env_flags,
|
||||
"--net=host", image, "bash"]
|
||||
docker_run = [
|
||||
"docker", "run", "--rm", "--name {}".format(cname), "-d", "-it",
|
||||
port_flags, mount_flags, env_flags, "--net=host", image, "bash"
|
||||
]
|
||||
cmds.append(" ".join(docker_run))
|
||||
docker_update = []
|
||||
docker_update.append("apt-get -y update")
|
||||
@@ -107,7 +110,8 @@ def docker_autoscaler_setup(cname):
|
||||
base_path = os.path.basename(path)
|
||||
cmds.append("docker cp {path} {cname}:{dpath}".format(
|
||||
path=path, dpath=base_path, cname=cname))
|
||||
cmds.extend(with_docker_exec(
|
||||
["cp {} {}".format("/" + base_path, path)],
|
||||
container_name=cname))
|
||||
cmds.extend(
|
||||
with_docker_exec(
|
||||
["cp {} {}".format("/" + base_path, path)],
|
||||
container_name=cname))
|
||||
return cmds
|
||||
|
||||
@@ -15,14 +15,15 @@ def import_aws():
|
||||
|
||||
def load_aws_config():
|
||||
import ray.autoscaler.aws as ray_aws
|
||||
return os.path.join(os.path.dirname(
|
||||
ray_aws.__file__), "example-full.yaml")
|
||||
return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")
|
||||
|
||||
|
||||
def import_external():
|
||||
"""Mock a normal provider importer."""
|
||||
|
||||
def return_it_back(config):
|
||||
return config
|
||||
|
||||
return return_it_back, None
|
||||
|
||||
|
||||
@@ -55,8 +56,7 @@ def load_class(path):
|
||||
class_data = path.split(".")
|
||||
if len(class_data) < 2:
|
||||
raise ValueError(
|
||||
"You need to pass a valid path like mymodule.provider_class"
|
||||
)
|
||||
"You need to pass a valid path like mymodule.provider_class")
|
||||
module_path = ".".join(class_data[:-1])
|
||||
class_str = class_data[-1]
|
||||
module = importlib.import_module(module_path)
|
||||
@@ -71,8 +71,8 @@ def get_node_provider(provider_config, cluster_name):
|
||||
importer = NODE_PROVIDERS.get(provider_config["type"])
|
||||
|
||||
if importer is None:
|
||||
raise NotImplementedError(
|
||||
"Unsupported node provider: {}".format(provider_config["type"]))
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
_, provider_cls = importer()
|
||||
return provider_cls(provider_config, cluster_name)
|
||||
|
||||
@@ -82,8 +82,8 @@ def get_default_config(provider_config):
|
||||
return {}
|
||||
load_config = DEFAULT_CONFIGS.get(provider_config["type"])
|
||||
if load_config is None:
|
||||
raise NotImplementedError(
|
||||
"Unsupported node provider: {}".format(provider_config["type"]))
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
provider_config["type"]))
|
||||
path_to_default = load_config()
|
||||
with open(path_to_default) as f:
|
||||
defaults = yaml.load(f)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""The Ray autoscaler uses tags to associate metadata with instances."""
|
||||
|
||||
# Tag for the name of the node
|
||||
|
||||
@@ -26,10 +26,16 @@ def pretty_cmd(cmd_str):
|
||||
class NodeUpdater(object):
|
||||
"""A process for syncing files and running init commands on a node."""
|
||||
|
||||
def __init__(
|
||||
self, node_id, provider_config, auth_config, cluster_name,
|
||||
file_mounts, setup_cmds, runtime_hash, redirect_output=True,
|
||||
process_runner=subprocess):
|
||||
def __init__(self,
|
||||
node_id,
|
||||
provider_config,
|
||||
auth_config,
|
||||
cluster_name,
|
||||
file_mounts,
|
||||
setup_cmds,
|
||||
runtime_hash,
|
||||
redirect_output=True,
|
||||
process_runner=subprocess):
|
||||
self.daemon = True
|
||||
self.process_runner = process_runner
|
||||
self.provider = get_node_provider(provider_config, cluster_name)
|
||||
@@ -66,13 +72,12 @@ class NodeUpdater(object):
|
||||
"NodeUpdater: Error updating {}"
|
||||
"See {} for remote logs.".format(error_str, self.output_name),
|
||||
file=self.stdout)
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "UpdateFailed"})
|
||||
if self.logfile is not None:
|
||||
print(
|
||||
"----- BEGIN REMOTE LOGS -----\n" +
|
||||
open(self.logfile.name).read() +
|
||||
"\n----- END REMOTE LOGS -----")
|
||||
print("----- BEGIN REMOTE LOGS -----\n" + open(
|
||||
self.logfile.name).read() + "\n----- END REMOTE LOGS -----"
|
||||
)
|
||||
raise e
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {
|
||||
@@ -85,8 +90,8 @@ class NodeUpdater(object):
|
||||
file=self.stdout)
|
||||
|
||||
def do_update(self):
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "WaitingForSSH"})
|
||||
deadline = time.time() + NODE_START_WAIT_S
|
||||
|
||||
# Wait for external IP
|
||||
@@ -114,7 +119,8 @@ class NodeUpdater(object):
|
||||
raise Exception("Node not running yet...")
|
||||
self.ssh_cmd(
|
||||
"uptime",
|
||||
connect_timeout=5, redirect=open("/dev/null", "w"))
|
||||
connect_timeout=5,
|
||||
redirect=open("/dev/null", "w"))
|
||||
ssh_ok = True
|
||||
except Exception as e:
|
||||
retry_str = str(e)
|
||||
@@ -130,8 +136,8 @@ class NodeUpdater(object):
|
||||
assert ssh_ok, "Unable to SSH to node"
|
||||
|
||||
# Rsync file mounts
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SyncingFiles"})
|
||||
for remote_path, local_path in self.file_mounts.items():
|
||||
print(
|
||||
"NodeUpdater: Syncing {} to {}...".format(
|
||||
@@ -143,18 +149,20 @@ class NodeUpdater(object):
|
||||
local_path += "/"
|
||||
if not remote_path.endswith("/"):
|
||||
remote_path += "/"
|
||||
self.ssh_cmd(
|
||||
"mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.process_runner.check_call([
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
|
||||
"--delete", "-avz", "{}".format(local_path),
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
|
||||
], stdout=self.stdout, stderr=self.stderr)
|
||||
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
|
||||
self.process_runner.check_call(
|
||||
[
|
||||
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
|
||||
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
|
||||
"--delete", "-avz", "{}".format(local_path),
|
||||
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
|
||||
],
|
||||
stdout=self.stdout,
|
||||
stderr=self.stderr)
|
||||
|
||||
# Run init commands
|
||||
self.provider.set_node_tags(
|
||||
self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"})
|
||||
self.provider.set_node_tags(self.node_id,
|
||||
{TAG_RAY_NODE_STATUS: "SettingUp"})
|
||||
for cmd in self.setup_cmds:
|
||||
self.ssh_cmd(cmd, verbose=True)
|
||||
|
||||
@@ -165,13 +173,16 @@ class NodeUpdater(object):
|
||||
pretty_cmd(cmd), self.ssh_ip),
|
||||
file=self.stdout)
|
||||
force_interactive = "set -i && source ~/.bashrc && "
|
||||
self.process_runner.check_call([
|
||||
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", self.ssh_private_key,
|
||||
"{}@{}".format(self.ssh_user, self.ssh_ip),
|
||||
"bash --login -c {}".format(pipes.quote(force_interactive + cmd))
|
||||
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
|
||||
self.process_runner.check_call(
|
||||
[
|
||||
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-i", self.ssh_private_key, "{}@{}".format(
|
||||
self.ssh_user, self.ssh_ip), "bash --login -c {}".format(
|
||||
pipes.quote(force_interactive + cmd))
|
||||
],
|
||||
stdout=redirect or self.stdout,
|
||||
stderr=redirect or self.stderr)
|
||||
|
||||
|
||||
class NodeUpdaterProcess(NodeUpdater, Process):
|
||||
|
||||
@@ -25,7 +25,7 @@ OBJECT_CHANNEL_PREFIX = "OC:"
|
||||
def integerToAsciiHex(num, numbytes):
|
||||
retstr = b""
|
||||
# Support 32 and 64 bit architecture.
|
||||
assert(numbytes == 4 or numbytes == 8)
|
||||
assert (numbytes == 4 or numbytes == 8)
|
||||
for i in range(numbytes):
|
||||
curbyte = num & 0xff
|
||||
if sys.version_info >= (3, 0):
|
||||
@@ -50,7 +50,6 @@ def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
|
||||
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
@@ -192,16 +191,16 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# notifications.
|
||||
def check_object_notification(notification_message, object_id,
|
||||
object_size, manager_ids):
|
||||
notification_object = (SubscribeToNotificationsReply
|
||||
.GetRootAsSubscribeToNotificationsReply(
|
||||
notification_object = (SubscribeToNotificationsReply.
|
||||
GetRootAsSubscribeToNotificationsReply(
|
||||
notification_message, 0))
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(),
|
||||
len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i),
|
||||
manager_ids[i])
|
||||
self.assertEqual(
|
||||
notification_object.ManagerIds(i), manager_ids[i])
|
||||
|
||||
data_size = 0xf1f0
|
||||
p = self.redis.pubsub()
|
||||
@@ -215,10 +214,9 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id1")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
[b"manager_id2"])
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id1", data_size,
|
||||
[b"manager_id2"])
|
||||
|
||||
# Request a notification for an object that isn't there. Then add the
|
||||
# object and receive the data. Only the first call to
|
||||
@@ -232,26 +230,22 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1"])
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id3", data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
[b"manager_id3"])
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id2", data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1", b"manager_id2",
|
||||
b"manager_id3"])
|
||||
check_object_notification(
|
||||
get_next_message(p)["data"], b"object_id3", data_size,
|
||||
[b"manager_id1", b"manager_id2", b"manager_id3"])
|
||||
|
||||
def testResultTableAddAndLookup(self):
|
||||
def check_result_table_entry(message, task_id, is_put):
|
||||
@@ -349,8 +343,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# update happens, and the response is still the same task.
|
||||
task_args = [task_args[0]] + task_args
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the task entry is still the same.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
@@ -362,8 +355,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# task.
|
||||
task_args[1] = TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the update happened.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
@@ -375,8 +367,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
new_task_args = task_args[:]
|
||||
new_task_args[1] = TASK_STATUS_WAITING
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
"task_id", *new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
@@ -388,8 +379,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
task_args = new_task_args
|
||||
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
"task_id", *task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
|
||||
# If the test value is a bitmask that does not match the current value,
|
||||
@@ -399,8 +389,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
new_task_args[0] = TASK_STATUS_SCHEDULED
|
||||
old_response = response
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
"task_id", *new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
@@ -409,8 +398,10 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"]
|
||||
task_args = [
|
||||
b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"
|
||||
]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
@@ -418,8 +409,7 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), task_args[0])
|
||||
self.assertEqual(notification_object.State(), task_args[1])
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
task_args[2])
|
||||
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
|
||||
self.assertEqual(notification_object.ExecutionDependencies(),
|
||||
task_args[3])
|
||||
self.assertEqual(notification_object.TaskSpec(), task_args[-1])
|
||||
|
||||
@@ -30,19 +30,23 @@ def random_task_id():
|
||||
|
||||
BASE_SIMPLE_OBJECTS = [
|
||||
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
|
||||
990 * u"h"]
|
||||
990 * u"h"
|
||||
]
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
|
||||
BASE_SIMPLE_OBJECTS += [
|
||||
long(0), # noqa: E501,F821
|
||||
long(1), # noqa: E501,F821
|
||||
long(100000), # noqa: E501,F821
|
||||
long(1 << 100) # noqa: E501,F821
|
||||
]
|
||||
|
||||
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS]
|
||||
DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS]
|
||||
|
||||
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
|
||||
LIST_SIMPLE_OBJECTS +
|
||||
TUPLE_SIMPLE_OBJECTS +
|
||||
DICT_SIMPLE_OBJECTS)
|
||||
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS +
|
||||
TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS)
|
||||
|
||||
# Create some complex objects that cannot be serialized by value in tasks.
|
||||
|
||||
@@ -55,21 +59,20 @@ class Foo(object):
|
||||
pass
|
||||
|
||||
|
||||
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", lst, Foo(),
|
||||
10 * [10 * [10 * [1]]]]
|
||||
BASE_COMPLEX_OBJECTS = [
|
||||
999 * "h", 999 * u"h", lst,
|
||||
Foo(), 10 * [10 * [10 * [1]]]
|
||||
]
|
||||
|
||||
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
|
||||
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
|
||||
TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS]
|
||||
DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS]
|
||||
|
||||
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
|
||||
LIST_COMPLEX_OBJECTS +
|
||||
TUPLE_COMPLEX_OBJECTS +
|
||||
DICT_COMPLEX_OBJECTS)
|
||||
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS +
|
||||
TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS)
|
||||
|
||||
|
||||
class TestSerialization(unittest.TestCase):
|
||||
|
||||
def test_serialize_by_value(self):
|
||||
|
||||
for val in SIMPLE_OBJECTS:
|
||||
@@ -79,7 +82,6 @@ class TestSerialization(unittest.TestCase):
|
||||
|
||||
|
||||
class TestObjectID(unittest.TestCase):
|
||||
|
||||
def test_create_object_id(self):
|
||||
random_object_id()
|
||||
|
||||
@@ -95,6 +97,7 @@ class TestObjectID(unittest.TestCase):
|
||||
def h():
|
||||
object_ids[0]
|
||||
return 1
|
||||
|
||||
# Make sure that object IDs cannot be pickled (including functions that
|
||||
# close over object IDs).
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
|
||||
@@ -113,10 +116,12 @@ class TestObjectID(unittest.TestCase):
|
||||
self.assertNotEqual(x1, y1)
|
||||
|
||||
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids1 = [
|
||||
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
|
||||
]
|
||||
object_ids2 = [
|
||||
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
|
||||
]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
@@ -129,7 +134,6 @@ class TestObjectID(unittest.TestCase):
|
||||
|
||||
|
||||
class TestTask(unittest.TestCase):
|
||||
|
||||
def check_task(self, task, function_id, num_return_vals, args):
|
||||
self.assertEqual(function_id.id(), task.function_id().id())
|
||||
retrieved_args = task.arguments()
|
||||
@@ -148,31 +152,17 @@ class TestTask(unittest.TestCase):
|
||||
parent_id = random_task_id()
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
args_list = [
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"],
|
||||
10 * ["a"], 100 * ["a"], 1000 * ["a"], [
|
||||
1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]
|
||||
], object_ids[:1], object_ids[:2], object_ids[:3],
|
||||
object_ids[:4], object_ids[:5], object_ids[:10],
|
||||
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
|
||||
object_ids[0], "a"
|
||||
], [1, object_ids[0], "a"], [
|
||||
object_ids[0], 1, object_ids[1], "a"
|
||||
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(driver_id, function_id, args,
|
||||
|
||||
@@ -5,5 +5,7 @@ from __future__ import print_function
|
||||
from .tfutils import TensorFlowVariables
|
||||
from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
|
||||
|
||||
__all__ = ["TensorFlowVariables", "flush_redis_unsafe",
|
||||
"flush_task_and_object_metadata_unsafe"]
|
||||
__all__ = [
|
||||
"TensorFlowVariables", "flush_redis_unsafe",
|
||||
"flush_task_and_object_metadata_unsafe"
|
||||
]
|
||||
|
||||
@@ -8,6 +8,8 @@ from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
|
||||
triu, tril, blockwise_dot, dot, transpose, add, subtract,
|
||||
numpy_to_dist, subblocks)
|
||||
|
||||
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
|
||||
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
|
||||
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
|
||||
__all__ = [
|
||||
"random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones",
|
||||
"copy", "eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add",
|
||||
"subtract", "numpy_to_dist", "subblocks"
|
||||
]
|
||||
|
||||
@@ -13,8 +13,9 @@ class DistArray(object):
|
||||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
|
||||
for a in self.shape]
|
||||
self.num_blocks = [
|
||||
int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape
|
||||
]
|
||||
if objectids is not None:
|
||||
self.objectids = objectids
|
||||
else:
|
||||
@@ -56,7 +57,7 @@ class DistArray(object):
|
||||
|
||||
def assemble(self):
|
||||
"""Assemble an array from a distributed array of object IDs."""
|
||||
first_block = ray.get(self.objectids[(0,) * self.ndim])
|
||||
first_block = ray.get(self.objectids[(0, ) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
@@ -85,8 +86,8 @@ def numpy_to_dist(a):
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
|
||||
in zip(lower, upper)]])
|
||||
result.objectids[index] = ray.put(
|
||||
a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
|
||||
@@ -126,12 +127,11 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0],
|
||||
block_shape[1],
|
||||
dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.eye.remote(
|
||||
block_shape[0], block_shape[1], dtype_name=dtype_name)
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape,
|
||||
dtype_name=dtype_name)
|
||||
result.objectids[i, j] = ra.zeros.remote(
|
||||
block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@@ -190,8 +190,8 @@ def dot(a, b):
|
||||
"b.ndim = {}.".format(b.ndim))
|
||||
if a.shape[1] != b.shape[0]:
|
||||
raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
|
||||
"a.shape = {} and b.shape = {}.".format(a.shape,
|
||||
b.shape))
|
||||
"a.shape = {} and b.shape = {}.".format(
|
||||
a.shape, b.shape))
|
||||
shape = [a.shape[0], b.shape[1]]
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
@@ -227,8 +227,8 @@ def subblocks(a, *ranges):
|
||||
"the {}th range is {}.".format(i, ranges[i]))
|
||||
if ranges[i][0] < 0:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must "
|
||||
"be at least 0, but the {}th range is {}."
|
||||
.format(i, ranges[i]))
|
||||
"be at least 0, but the {}th range is {}.".format(
|
||||
i, ranges[i]))
|
||||
if ranges[i][-1] >= a.num_blocks[i]:
|
||||
raise Exception("Values in the ranges passed to sub_blocks must "
|
||||
"be less than the relevant number of blocks, but "
|
||||
@@ -240,8 +240,8 @@ def subblocks(a, *ranges):
|
||||
for i in range(a.ndim)]
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
|
||||
for i in range(a.ndim)])]
|
||||
result.objectids[index] = a.objectids[tuple(
|
||||
[ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
|
||||
@@ -249,8 +249,8 @@ def subblocks(a, *ranges):
|
||||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, "
|
||||
"but a.ndim = {}, a.shape = {}.".format(a.ndim,
|
||||
a.shape))
|
||||
"but a.ndim = {}, a.shape = {}.".format(
|
||||
a.ndim, a.shape))
|
||||
result = DistArray([a.shape[1], a.shape[0]])
|
||||
for i in range(result.num_blocks[0]):
|
||||
for j in range(result.num_blocks[1]):
|
||||
@@ -263,8 +263,8 @@ def transpose(a):
|
||||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same "
|
||||
"shape, but x1.shape = {}, and x2.shape = {}."
|
||||
.format(x1.shape, x2.shape))
|
||||
"shape, but x1.shape = {}, and x2.shape = {}.".format(
|
||||
x1.shape, x2.shape))
|
||||
result = DistArray(x1.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.add.remote(x1.objectids[index],
|
||||
|
||||
@@ -76,9 +76,10 @@ def tsqr(a):
|
||||
lower = [a.shape[1], 0]
|
||||
upper = [2 * a.shape[1], core.BLOCK_SIZE]
|
||||
ith_index //= 2
|
||||
q_block_current = ra.dot.remote(
|
||||
q_block_current, ra.subarray.remote(q_tree[ith_index, j],
|
||||
lower, upper))
|
||||
q_block_current = ra.dot.remote(q_block_current,
|
||||
ra.subarray.remote(
|
||||
q_tree[ith_index, j], lower,
|
||||
upper))
|
||||
q_result.objectids[i] = q_block_current
|
||||
r = current_rs[0]
|
||||
return q_result, ray.get(r)
|
||||
@@ -196,8 +197,8 @@ def qr(a):
|
||||
if a.shape[0] > a.shape[1]:
|
||||
# in this case, R needs to be square
|
||||
R_shape = ray.get(ra.shape.remote(R))
|
||||
eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
|
||||
dtype_name=result_dtype)
|
||||
eye_temp = ra.eye.remote(
|
||||
R_shape[1], R_shape[0], dtype_name=result_dtype)
|
||||
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
|
||||
else:
|
||||
r_res.objectids[i, i] = R
|
||||
@@ -220,10 +221,11 @@ def qr(a):
|
||||
for i in range(len(Ts))[::-1]:
|
||||
y_col_block = core.subblocks.remote(y_res, [], [i])
|
||||
q = core.subtract.remote(
|
||||
q, core.dot.remote(
|
||||
y_col_block,
|
||||
core.dot.remote(
|
||||
Ts[i],
|
||||
core.dot.remote(core.transpose.remote(y_col_block), q))))
|
||||
q,
|
||||
core.dot.remote(y_col_block,
|
||||
core.dot.remote(
|
||||
Ts[i],
|
||||
core.dot.remote(
|
||||
core.transpose.remote(y_col_block), q))))
|
||||
|
||||
return ray.get(q), r_res
|
||||
|
||||
@@ -8,6 +8,8 @@ from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
|
||||
copy, tril, triu, diag, transpose, add, subtract, sum,
|
||||
shape, sum_list)
|
||||
|
||||
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
|
||||
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
|
||||
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
__all__ = [
|
||||
"random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", "vstack",
|
||||
"hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add",
|
||||
"subtract", "sum", "shape", "sum_list"
|
||||
]
|
||||
|
||||
@@ -5,10 +5,11 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"multi_dot"]
|
||||
__all__ = [
|
||||
"matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "cholesky",
|
||||
"eigvals", "eigvalsh", "pinv", "slogdet", "det", "svd", "eig", "eigh",
|
||||
"lstsq", "norm", "qr", "cond", "matrix_rank", "multi_dot"
|
||||
]
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
||||
@@ -69,14 +69,14 @@ def flush_task_and_object_metadata_unsafe():
|
||||
for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
|
||||
num_object_keys_deleted += redis_client.delete(key)
|
||||
print("Deleted {} object info keys from Redis.".format(
|
||||
num_object_keys_deleted))
|
||||
num_object_keys_deleted))
|
||||
|
||||
# Flush the object locations.
|
||||
num_object_location_keys_deleted = 0
|
||||
for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"):
|
||||
num_object_location_keys_deleted += redis_client.delete(key)
|
||||
print("Deleted {} object location keys from Redis.".format(
|
||||
num_object_location_keys_deleted))
|
||||
num_object_location_keys_deleted))
|
||||
|
||||
# Loop over the shards and flush all of them.
|
||||
for redis_client in ray.worker.global_state.redis_clients:
|
||||
|
||||
+279
-192
@@ -59,6 +59,7 @@ class GlobalState(object):
|
||||
Attributes:
|
||||
redis_client: The redis client used to query the redis server.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
# The redis server storing metadata, such as function table, client
|
||||
@@ -82,7 +83,9 @@ class GlobalState(object):
|
||||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port,
|
||||
def _initialize_global_state(self,
|
||||
redis_ip_address,
|
||||
redis_port,
|
||||
timeout=20):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
|
||||
@@ -97,8 +100,8 @@ class GlobalState(object):
|
||||
timeout: The maximum amount of time (in seconds) that we should
|
||||
wait for the keys in Redis to be populated.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -118,8 +121,8 @@ class GlobalState(object):
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
# Attempt to get all of the Redis shards.
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||
end=-1)
|
||||
ip_address_ports = self.redis_client.lrange(
|
||||
"RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
print("Waiting longer for RedisShards to be populated.")
|
||||
time.sleep(1)
|
||||
@@ -132,15 +135,15 @@ class GlobalState(object):
|
||||
if time.time() - start_time >= timeout:
|
||||
raise Exception("Timed out while attempting to initialize the "
|
||||
"global state. num_redis_shards = {}, "
|
||||
"ip_address_ports = {}"
|
||||
.format(num_redis_shards, ip_address_ports))
|
||||
"ip_address_ports = {}".format(
|
||||
num_redis_shards, ip_address_ports))
|
||||
|
||||
# Get the rest of the information.
|
||||
self.redis_clients = []
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
self.redis_clients.append(
|
||||
redis.StrictRedis(host=shard_address, port=shard_port))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
@@ -152,8 +155,8 @@ class GlobalState(object):
|
||||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
client = self.redis_clients[key.redis_shard_hash() % len(
|
||||
self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
|
||||
def _keys(self, pattern):
|
||||
@@ -189,8 +192,9 @@ class GlobalState(object):
|
||||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
manager_ids = [
|
||||
binary_to_hex(manager_id) for manager_id in object_locations
|
||||
]
|
||||
else:
|
||||
manager_ids = None
|
||||
|
||||
@@ -199,11 +203,13 @@ class GlobalState(object):
|
||||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
|
||||
result = {"ManagerIDs": manager_ids,
|
||||
"TaskID": binary_to_hex(result_table_message.TaskId()),
|
||||
"IsPut": bool(result_table_message.IsPut()),
|
||||
"DataSize": result_table_message.DataSize(),
|
||||
"Hash": binary_to_hex(result_table_message.Hash())}
|
||||
result = {
|
||||
"ManagerIDs": manager_ids,
|
||||
"TaskID": binary_to_hex(result_table_message.TaskId()),
|
||||
"IsPut": bool(result_table_message.IsPut()),
|
||||
"DataSize": result_table_message.DataSize(),
|
||||
"Hash": binary_to_hex(result_table_message.Hash())
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -227,9 +233,10 @@ class GlobalState(object):
|
||||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):]
|
||||
for key in object_location_keys])
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [
|
||||
key[len(OBJECT_LOCATION_PREFIX):]
|
||||
for key in object_location_keys
|
||||
])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = (
|
||||
@@ -254,26 +261,37 @@ class GlobalState(object):
|
||||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task "
|
||||
"table.".format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
|
||||
0)
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(
|
||||
task_table_response, 0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec = ray.local_scheduler.task_from_string(task_spec)
|
||||
|
||||
task_spec_info = {
|
||||
"DriverID": binary_to_hex(task_spec.driver_id().id()),
|
||||
"TaskID": binary_to_hex(task_spec.task_id().id()),
|
||||
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
|
||||
"ParentCounter": task_spec.parent_counter(),
|
||||
"ActorID": binary_to_hex(task_spec.actor_id().id()),
|
||||
"DriverID":
|
||||
binary_to_hex(task_spec.driver_id().id()),
|
||||
"TaskID":
|
||||
binary_to_hex(task_spec.task_id().id()),
|
||||
"ParentTaskID":
|
||||
binary_to_hex(task_spec.parent_task_id().id()),
|
||||
"ParentCounter":
|
||||
task_spec.parent_counter(),
|
||||
"ActorID":
|
||||
binary_to_hex(task_spec.actor_id().id()),
|
||||
"ActorCreationID":
|
||||
binary_to_hex(task_spec.actor_creation_id().id()),
|
||||
binary_to_hex(task_spec.actor_creation_id().id()),
|
||||
"ActorCreationDummyObjectID":
|
||||
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
|
||||
"ActorCounter": task_spec.actor_counter(),
|
||||
"FunctionID": binary_to_hex(task_spec.function_id().id()),
|
||||
"Args": task_spec.arguments(),
|
||||
"ReturnObjectIDs": task_spec.returns(),
|
||||
"RequiredResources": task_spec.required_resources()}
|
||||
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
|
||||
"ActorCounter":
|
||||
task_spec.actor_counter(),
|
||||
"FunctionID":
|
||||
binary_to_hex(task_spec.function_id().id()),
|
||||
"Args":
|
||||
task_spec.arguments(),
|
||||
"ReturnObjectIDs":
|
||||
task_spec.returns(),
|
||||
"RequiredResources":
|
||||
task_spec.required_resources()
|
||||
}
|
||||
|
||||
execution_dependencies_message = (
|
||||
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
|
||||
@@ -282,21 +300,27 @@ class GlobalState(object):
|
||||
ray.local_scheduler.ObjectID(
|
||||
execution_dependencies_message.ExecutionDependencies(i))
|
||||
for i in range(
|
||||
execution_dependencies_message.ExecutionDependenciesLength())]
|
||||
execution_dependencies_message.ExecutionDependenciesLength())
|
||||
]
|
||||
|
||||
# TODO(rkn): The return fields ExecutionDependenciesString and
|
||||
# ExecutionDependencies are redundant, so we should remove
|
||||
# ExecutionDependencies. However, it is currently used in monitor.py.
|
||||
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
"ExecutionDependenciesString":
|
||||
task_table_message.ExecutionDependencies(),
|
||||
"ExecutionDependencies": execution_dependencies,
|
||||
"SpillbackCount":
|
||||
task_table_message.SpillbackCount(),
|
||||
"TaskSpec": task_spec_info}
|
||||
return {
|
||||
"State":
|
||||
task_table_message.State(),
|
||||
"LocalSchedulerID":
|
||||
binary_to_hex(task_table_message.LocalSchedulerId()),
|
||||
"ExecutionDependenciesString":
|
||||
task_table_message.ExecutionDependencies(),
|
||||
"ExecutionDependencies":
|
||||
execution_dependencies,
|
||||
"SpillbackCount":
|
||||
task_table_message.SpillbackCount(),
|
||||
"TaskSpec":
|
||||
task_spec_info
|
||||
}
|
||||
|
||||
def task_table(self, task_id=None):
|
||||
"""Fetch and parse the task table information for one or more task IDs.
|
||||
@@ -337,7 +361,8 @@ class GlobalState(object):
|
||||
function_info_parsed = {
|
||||
"DriverID": binary_to_hex(info[b"driver_id"]),
|
||||
"Module": decode(info[b"module"]),
|
||||
"Name": decode(info[b"name"])}
|
||||
"Name": decode(info[b"name"])
|
||||
}
|
||||
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
|
||||
return results
|
||||
|
||||
@@ -469,21 +494,17 @@ class GlobalState(object):
|
||||
if start is None and end is None:
|
||||
if fwd:
|
||||
event_list = self.redis_client.zrange(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
event_list = self.redis_client.zrevrange(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
if fwd:
|
||||
event_list = self.redis_client.zrangebyscore(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
else:
|
||||
event_list = self.redis_client.zrevrangebyscore(
|
||||
event_log_set,
|
||||
**params)
|
||||
event_log_set, **params)
|
||||
|
||||
for (event, score) in event_list:
|
||||
event_dict = json.loads(event.decode())
|
||||
@@ -503,11 +524,11 @@ class GlobalState(object):
|
||||
task_info[task_id]["get_task_start"] = event[0]
|
||||
if event[1] == "ray:get_task" and event[2] == 2:
|
||||
task_info[task_id]["get_task_end"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 1):
|
||||
if (event[1] == "ray:import_remote_function"
|
||||
and event[2] == 1):
|
||||
task_info[task_id]["import_remote_start"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 2):
|
||||
if (event[1] == "ray:import_remote_function"
|
||||
and event[2] == 2):
|
||||
task_info[task_id]["import_remote_end"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 1:
|
||||
task_info[task_id]["acquire_lock_start"] = event[0]
|
||||
@@ -547,7 +568,6 @@ class GlobalState(object):
|
||||
breakdowns=True,
|
||||
task_dep=True,
|
||||
obj_dep=True):
|
||||
|
||||
"""Dump task profiling information to a file.
|
||||
|
||||
This information can be viewed as a timeline of profiling information
|
||||
@@ -604,72 +624,103 @@ class GlobalState(object):
|
||||
# modify it in place since we will use the original values later.
|
||||
total_info = copy.copy(task_table[task_id]["TaskSpec"])
|
||||
total_info["Args"] = [
|
||||
oid.hex() if isinstance(oid, ray.local_scheduler.ObjectID)
|
||||
else oid for oid in task_t_info["TaskSpec"]["Args"]]
|
||||
oid.hex()
|
||||
if isinstance(oid, ray.local_scheduler.ObjectID) else oid
|
||||
for oid in task_t_info["TaskSpec"]["Args"]
|
||||
]
|
||||
total_info["ReturnObjectIDs"] = [
|
||||
oid.hex() for oid
|
||||
in task_t_info["TaskSpec"]["ReturnObjectIDs"]]
|
||||
oid.hex() for oid in task_t_info["TaskSpec"]["ReturnObjectIDs"]
|
||||
]
|
||||
total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"]
|
||||
total_info["get_arguments"] = (info["get_arguments_end"] -
|
||||
info["get_arguments_start"])
|
||||
total_info["execute"] = (info["execute_end"] -
|
||||
info["execute_start"])
|
||||
total_info["store_outputs"] = (info["store_outputs_end"] -
|
||||
info["store_outputs_start"])
|
||||
total_info["get_arguments"] = (
|
||||
info["get_arguments_end"] - info["get_arguments_start"])
|
||||
total_info["execute"] = (
|
||||
info["execute_end"] - info["execute_start"])
|
||||
total_info["store_outputs"] = (
|
||||
info["store_outputs_end"] - info["store_outputs_start"])
|
||||
total_info["function_name"] = info["function_name"]
|
||||
total_info["worker_id"] = info["worker_id"]
|
||||
|
||||
parent_info = task_info.get(
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
worker = workers[info["worker_id"]]
|
||||
# The catapult trace format documentation can be found here:
|
||||
# https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview # noqa: E501
|
||||
if breakdowns:
|
||||
if "get_arguments_end" in info:
|
||||
get_args_trace = {
|
||||
"cat": "get_arguments",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":get_arguments",
|
||||
"args": total_info,
|
||||
"dur": micros(info["get_arguments_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname": "rail_idle"
|
||||
"cat":
|
||||
"get_arguments",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":get_arguments",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["get_arguments_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname":
|
||||
"rail_idle"
|
||||
}
|
||||
full_trace.append(get_args_trace)
|
||||
|
||||
if "store_outputs_end" in info:
|
||||
outputs_trace = {
|
||||
"cat": "store_outputs",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["store_outputs_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":store_outputs",
|
||||
"args": total_info,
|
||||
"dur": micros(info["store_outputs_end"] -
|
||||
info["store_outputs_start"]),
|
||||
"cname": "thread_state_runnable"
|
||||
"cat":
|
||||
"store_outputs",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["store_outputs_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":store_outputs",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["store_outputs_end"] -
|
||||
info["store_outputs_start"]),
|
||||
"cname":
|
||||
"thread_state_runnable"
|
||||
}
|
||||
full_trace.append(outputs_trace)
|
||||
|
||||
if "execute_end" in info:
|
||||
execute_trace = {
|
||||
"cat": "execute",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["execute_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"] + ":execute",
|
||||
"args": total_info,
|
||||
"dur": micros(info["execute_end"] -
|
||||
info["execute_start"]),
|
||||
"cname": "rail_animation"
|
||||
"cat":
|
||||
"execute",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["execute_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"] + ":execute",
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["execute_end"] - info["execute_start"]),
|
||||
"cname":
|
||||
"rail_animation"
|
||||
}
|
||||
full_trace.append(execute_trace)
|
||||
|
||||
@@ -680,15 +731,20 @@ class GlobalState(object):
|
||||
parent_profile = task_info.get(
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + parent_worker["node_ip_address"],
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
parent_profile and
|
||||
parent_profile["get_arguments_start"] or
|
||||
start_time),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + parent_worker["node_ip_address"],
|
||||
"tid":
|
||||
parent_info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(parent_profile
|
||||
and parent_profile["get_arguments_start"]
|
||||
or start_time),
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (parent_info["worker_id"] +
|
||||
str(micros(min(parent_times))))
|
||||
@@ -696,32 +752,50 @@ class GlobalState(object):
|
||||
full_trace.append(parent)
|
||||
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"f",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (info["worker_id"] +
|
||||
str(micros(min(parent_times)))),
|
||||
"bp": "e",
|
||||
"cname": "olive"
|
||||
"id":
|
||||
(info["worker_id"] + str(micros(min(parent_times)))),
|
||||
"bp":
|
||||
"e",
|
||||
"cname":
|
||||
"olive"
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
|
||||
task = {
|
||||
"cat": "task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"id": task_id,
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "X",
|
||||
"name": info["function_name"],
|
||||
"args": total_info,
|
||||
"dur": micros(info["store_outputs_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname": "thread_state_runnable"
|
||||
"cat":
|
||||
"task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"id":
|
||||
task_id,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"X",
|
||||
"name":
|
||||
info["function_name"],
|
||||
"args":
|
||||
total_info,
|
||||
"dur":
|
||||
micros(info["store_outputs_end"] -
|
||||
info["get_arguments_start"]),
|
||||
"cname":
|
||||
"thread_state_runnable"
|
||||
}
|
||||
full_trace.append(task)
|
||||
|
||||
@@ -732,15 +806,20 @@ class GlobalState(object):
|
||||
parent_profile = task_info.get(
|
||||
task_table[task_id]["TaskSpec"]["ParentTaskID"])
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + parent_worker["node_ip_address"],
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
parent_profile and
|
||||
parent_profile["get_arguments_start"] or
|
||||
start_time),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + parent_worker["node_ip_address"],
|
||||
"tid":
|
||||
parent_info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(parent_profile
|
||||
and parent_profile["get_arguments_start"]
|
||||
or start_time),
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (parent_info["worker_id"] +
|
||||
str(micros(min(parent_times))))
|
||||
@@ -748,16 +827,23 @@ class GlobalState(object):
|
||||
full_trace.append(parent)
|
||||
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(info["get_arguments_start"]),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"cat":
|
||||
"submit_task",
|
||||
"pid":
|
||||
"Node " + worker["node_ip_address"],
|
||||
"tid":
|
||||
info["worker_id"],
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]),
|
||||
"ph":
|
||||
"f",
|
||||
"name":
|
||||
"SubmitTask",
|
||||
"args": {},
|
||||
"id": (info["worker_id"] +
|
||||
str(micros(min(parent_times)))),
|
||||
"bp": "e"
|
||||
"id":
|
||||
(info["worker_id"] + str(micros(min(parent_times)))),
|
||||
"bp":
|
||||
"e"
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
|
||||
@@ -775,8 +861,8 @@ class GlobalState(object):
|
||||
seen_obj[arg] += 1
|
||||
owner_task = self._object_table(arg)["TaskID"]
|
||||
if owner_task in task_info:
|
||||
owner_worker = (workers[
|
||||
task_info[owner_task]["worker_id"]])
|
||||
owner_worker = (workers[task_info[owner_task][
|
||||
"worker_id"]])
|
||||
# Adding/subtracting 2 to the time associated
|
||||
# with the beginning/ending of the flow event
|
||||
# is necessary to make the flow events show up
|
||||
@@ -790,27 +876,35 @@ class GlobalState(object):
|
||||
# duration event that it's associated with, and
|
||||
# the flow event therefore always gets drawn.
|
||||
owner = {
|
||||
"cat": "obj_dependency",
|
||||
"cat":
|
||||
"obj_dependency",
|
||||
"pid": ("Node " +
|
||||
owner_worker["node_ip_address"]),
|
||||
"tid": task_info[owner_task]["worker_id"],
|
||||
"ts": micros_rel(task_info[
|
||||
owner_task]["store_outputs_end"]) - 2,
|
||||
"ph": "s",
|
||||
"name": "ObjectDependency",
|
||||
"tid":
|
||||
task_info[owner_task]["worker_id"],
|
||||
"ts":
|
||||
micros_rel(task_info[owner_task]
|
||||
["store_outputs_end"]) - 2,
|
||||
"ph":
|
||||
"s",
|
||||
"name":
|
||||
"ObjectDependency",
|
||||
"args": {},
|
||||
"bp": "e",
|
||||
"cname": "cq_build_attempt_failed",
|
||||
"id": "obj" + str(arg) + str(seen_obj[arg])
|
||||
"bp":
|
||||
"e",
|
||||
"cname":
|
||||
"cq_build_attempt_failed",
|
||||
"id":
|
||||
"obj" + str(arg) + str(seen_obj[arg])
|
||||
}
|
||||
full_trace.append(owner)
|
||||
|
||||
dependent = {
|
||||
"cat": "obj_dependency",
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"pid": "Node " + worker["node_ip_address"],
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros_rel(
|
||||
info["get_arguments_start"]) + 2,
|
||||
"ts":
|
||||
micros_rel(info["get_arguments_start"]) + 2,
|
||||
"ph": "f",
|
||||
"name": "ObjectDependency",
|
||||
"args": {},
|
||||
@@ -852,14 +946,10 @@ class GlobalState(object):
|
||||
"""
|
||||
|
||||
keys = [
|
||||
"acquire_lock_start",
|
||||
"acquire_lock_end",
|
||||
"get_arguments_start",
|
||||
"get_arguments_end",
|
||||
"execute_start",
|
||||
"execute_end",
|
||||
"store_outputs_start",
|
||||
"store_outputs_end"]
|
||||
"acquire_lock_start", "acquire_lock_end", "get_arguments_start",
|
||||
"get_arguments_end", "execute_start", "execute_end",
|
||||
"store_outputs_start", "store_outputs_end"
|
||||
]
|
||||
|
||||
latest_timestamp = 0
|
||||
for key in keys:
|
||||
@@ -877,8 +967,8 @@ class GlobalState(object):
|
||||
local_schedulers = []
|
||||
for ip_address, client_list in clients.items():
|
||||
for client in client_list:
|
||||
if (client["ClientType"] == "local_scheduler" and
|
||||
not client["Deleted"]):
|
||||
if (client["ClientType"] == "local_scheduler"
|
||||
and not client["Deleted"]):
|
||||
local_schedulers.append(client)
|
||||
return local_schedulers
|
||||
|
||||
@@ -893,8 +983,7 @@ class GlobalState(object):
|
||||
|
||||
workers_data[worker_id] = {
|
||||
"local_scheduler_socket":
|
||||
(worker_info[b"local_scheduler_socket"]
|
||||
.decode("ascii")),
|
||||
(worker_info[b"local_scheduler_socket"].decode("ascii")),
|
||||
"node_ip_address": (worker_info[b"node_ip_address"]
|
||||
.decode("ascii")),
|
||||
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
|
||||
@@ -921,9 +1010,10 @@ class GlobalState(object):
|
||||
"class_id": binary_to_hex(info[b"class_id"]),
|
||||
"driver_id": binary_to_hex(info[b"driver_id"]),
|
||||
"local_scheduler_id":
|
||||
binary_to_hex(info[b"local_scheduler_id"]),
|
||||
binary_to_hex(info[b"local_scheduler_id"]),
|
||||
"num_gpus": int(info[b"num_gpus"]),
|
||||
"removed": decode(info[b"removed"]) == "True"}
|
||||
"removed": decode(info[b"removed"]) == "True"
|
||||
}
|
||||
return actor_info
|
||||
|
||||
def _job_length(self):
|
||||
@@ -932,21 +1022,16 @@ class GlobalState(object):
|
||||
overall_largest = 0
|
||||
num_tasks = 0
|
||||
for event_log_set in event_log_sets:
|
||||
fwd_range = self.redis_client.zrange(event_log_set,
|
||||
start=0,
|
||||
end=0,
|
||||
withscores=True)
|
||||
fwd_range = self.redis_client.zrange(
|
||||
event_log_set, start=0, end=0, withscores=True)
|
||||
overall_smallest = min(overall_smallest, fwd_range[0][1])
|
||||
|
||||
rev_range = self.redis_client.zrevrange(event_log_set,
|
||||
start=0,
|
||||
end=0,
|
||||
withscores=True)
|
||||
rev_range = self.redis_client.zrevrange(
|
||||
event_log_set, start=0, end=0, withscores=True)
|
||||
overall_largest = max(overall_largest, rev_range[0][1])
|
||||
|
||||
num_tasks += self.redis_client.zcount(event_log_set,
|
||||
min=0,
|
||||
max=time.time())
|
||||
num_tasks += self.redis_client.zcount(
|
||||
event_log_set, min=0, max=time.time())
|
||||
if num_tasks is 0:
|
||||
return 0, 0, 0
|
||||
return overall_smallest, overall_largest, num_tasks
|
||||
@@ -966,8 +1051,10 @@ class GlobalState(object):
|
||||
|
||||
for local_scheduler in local_schedulers:
|
||||
for key, value in local_scheduler.items():
|
||||
if key not in ["ClientType", "Deleted", "DBClientID",
|
||||
"AuxAddress", "LocalSchedulerSocketName"]:
|
||||
if key not in [
|
||||
"ClientType", "Deleted", "DBClientID", "AuxAddress",
|
||||
"LocalSchedulerSocketName"
|
||||
]:
|
||||
resources[key] += value
|
||||
|
||||
return dict(resources)
|
||||
|
||||
@@ -27,6 +27,7 @@ class TensorFlowVariables(object):
|
||||
placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
|
||||
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
|
||||
"""
|
||||
|
||||
def __init__(self, loss, sess=None, input_variables=None):
|
||||
"""Creates TensorFlowVariables containing extracted variables.
|
||||
|
||||
@@ -74,8 +75,10 @@ class TensorFlowVariables(object):
|
||||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
variable_list = [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]
|
||||
variable_list = [
|
||||
v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names
|
||||
]
|
||||
if input_variables is not None:
|
||||
variable_list += input_variables
|
||||
for v in variable_list:
|
||||
@@ -86,9 +89,10 @@ class TensorFlowVariables(object):
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.placeholders[k] = tf.placeholder(
|
||||
var.value().dtype,
|
||||
var.get_shape().as_list(),
|
||||
name="Placeholder_" + k)
|
||||
self.assignment_nodes[k] = var.assign(self.placeholders[k])
|
||||
|
||||
def set_session(self, sess):
|
||||
@@ -105,8 +109,9 @@ class TensorFlowVariables(object):
|
||||
Returns:
|
||||
The length of all flattened variables concatenated.
|
||||
"""
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
return sum([
|
||||
np.prod(v.get_shape().as_list()) for v in self.variables.values()
|
||||
])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
@@ -122,8 +127,10 @@ class TensorFlowVariables(object):
|
||||
1D Array containing the flattened weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
return np.concatenate([
|
||||
v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()
|
||||
])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array.
|
||||
@@ -138,10 +145,12 @@ class TensorFlowVariables(object):
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k] for k, v
|
||||
in self.variables.items()]
|
||||
self.sess.run(list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
placeholders = [
|
||||
self.placeholders[k] for k, v in self.variables.items()
|
||||
]
|
||||
self.sess.run(
|
||||
list(self.assignment_nodes.values()),
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns a dictionary containing the weights of the network.
|
||||
@@ -150,8 +159,10 @@ class TensorFlowVariables(object):
|
||||
Dictionary mapping variable names to their weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
return {k: v.eval(session=self.sess) for k, v
|
||||
in self.variables.items()}
|
||||
return {
|
||||
k: v.eval(session=self.sess)
|
||||
for k, v in self.variables.items()
|
||||
}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights.
|
||||
@@ -165,15 +176,19 @@ class TensorFlowVariables(object):
|
||||
weights.
|
||||
"""
|
||||
self._check_sess()
|
||||
assign_list = [self.assignment_nodes[name]
|
||||
for name in new_weights.keys()
|
||||
if name in self.assignment_nodes]
|
||||
assign_list = [
|
||||
self.assignment_nodes[name] for name in new_weights.keys()
|
||||
if name in self.assignment_nodes
|
||||
]
|
||||
assert assign_list, ("No variables in the input matched those in the "
|
||||
"network. Possible cause: Two networks were "
|
||||
"defined in the same TensorFlow graph. To fix "
|
||||
"this, place each network definition in its own "
|
||||
"tf.Graph.")
|
||||
self.sess.run(assign_list,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
self.sess.run(
|
||||
assign_list,
|
||||
feed_dict={
|
||||
self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders
|
||||
})
|
||||
|
||||
+200
-215
@@ -29,9 +29,9 @@ class _EventRecursionContextManager(object):
|
||||
total_time_value = "% total time"
|
||||
total_tasks_value = "% total tasks"
|
||||
|
||||
|
||||
# Function that returns instances of sliders and handles associated events.
|
||||
|
||||
|
||||
def get_sliders(update):
|
||||
# Start_box value indicates the desired start point of queried window.
|
||||
start_box = widgets.FloatText(
|
||||
@@ -60,18 +60,14 @@ def get_sliders(update):
|
||||
|
||||
# Indicates the number of tasks that the user wants to be returned. Is
|
||||
# disabled when the breakdown_opt value is set to total_time_value.
|
||||
num_tasks_box = widgets.IntText(
|
||||
description="Num Tasks:",
|
||||
disabled=False
|
||||
)
|
||||
num_tasks_box = widgets.IntText(description="Num Tasks:", disabled=False)
|
||||
|
||||
# Dropdown bar that lets the user choose between modifying % of total
|
||||
# time or total number of tasks.
|
||||
breakdown_opt = widgets.Dropdown(
|
||||
options=[total_time_value, total_tasks_value],
|
||||
value=total_tasks_value,
|
||||
description="Selection Options:"
|
||||
)
|
||||
description="Selection Options:")
|
||||
|
||||
# Display box for layout.
|
||||
total_time_box = widgets.VBox([start_box, end_box])
|
||||
@@ -105,9 +101,9 @@ def get_sliders(update):
|
||||
if event == INIT_EVENT:
|
||||
if breakdown_opt.value == total_tasks_value:
|
||||
num_tasks_box.value = -min(10000, num_tasks)
|
||||
range_slider.value = (int(100 -
|
||||
(100. * -num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (int(
|
||||
100 - (100. * -num_tasks_box.value) / num_tasks),
|
||||
100)
|
||||
else:
|
||||
low, high = map(lambda x: x / 100., range_slider.value)
|
||||
start_box.value = round(diff * low, 2)
|
||||
@@ -120,8 +116,8 @@ def get_sliders(update):
|
||||
elif start_box.value < 0:
|
||||
start_box.value = 0
|
||||
low, high = range_slider.value
|
||||
range_slider.value = (int((start_box.value * 100.) /
|
||||
diff), high)
|
||||
range_slider.value = (int((start_box.value * 100.) / diff),
|
||||
high)
|
||||
|
||||
# Event was triggered by a change in the end_box value.
|
||||
elif event["owner"] == end_box:
|
||||
@@ -130,8 +126,8 @@ def get_sliders(update):
|
||||
elif end_box.value > diff:
|
||||
end_box.value = diff
|
||||
low, high = range_slider.value
|
||||
range_slider.value = (low, int((end_box.value * 100.) /
|
||||
diff))
|
||||
range_slider.value = (low,
|
||||
int((end_box.value * 100.) / diff))
|
||||
|
||||
# Event was triggered by a change in the breakdown options
|
||||
# toggle.
|
||||
@@ -145,9 +141,9 @@ def get_sliders(update):
|
||||
# Make CSS display go back to the default settings.
|
||||
num_tasks_box.layout.display = None
|
||||
num_tasks_box.value = min(10000, num_tasks)
|
||||
range_slider.value = (int(100 -
|
||||
(100. * num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (int(
|
||||
100 - (100. * num_tasks_box.value) / num_tasks),
|
||||
100)
|
||||
else:
|
||||
start_box.disabled = False
|
||||
end_box.disabled = False
|
||||
@@ -156,10 +152,9 @@ def get_sliders(update):
|
||||
# Make CSS display go back to the default settings.
|
||||
total_time_box.layout.display = None
|
||||
num_tasks_box.layout.display = 'none'
|
||||
range_slider.value = (int((start_box.value * 100.) /
|
||||
diff),
|
||||
int((end_box.value * 100.) /
|
||||
diff))
|
||||
range_slider.value = (
|
||||
int((start_box.value * 100.) / diff),
|
||||
int((end_box.value * 100.) / diff))
|
||||
|
||||
# Event was triggered by a change in the range_slider
|
||||
# value.
|
||||
@@ -170,8 +165,8 @@ def get_sliders(update):
|
||||
new_low, new_high = event["new"]
|
||||
if old_low != new_low:
|
||||
range_slider.value = (new_low, 100)
|
||||
num_tasks_box.value = (-(100. - new_low) /
|
||||
100. * num_tasks)
|
||||
num_tasks_box.value = (
|
||||
-(100. - new_low) / 100. * num_tasks)
|
||||
else:
|
||||
range_slider.value = (0, new_high)
|
||||
num_tasks_box.value = new_high / 100. * num_tasks
|
||||
@@ -183,14 +178,12 @@ def get_sliders(update):
|
||||
# value.
|
||||
elif event["owner"] == num_tasks_box:
|
||||
if num_tasks_box.value > 0:
|
||||
range_slider.value = (0, int(100 *
|
||||
float(num_tasks_box.value) /
|
||||
num_tasks))
|
||||
range_slider.value = (
|
||||
0, int(
|
||||
100 * float(num_tasks_box.value) / num_tasks))
|
||||
elif num_tasks_box.value < 0:
|
||||
range_slider.value = (100 +
|
||||
int(100 *
|
||||
float(num_tasks_box.value) /
|
||||
num_tasks), 100)
|
||||
range_slider.value = (100 + int(
|
||||
100 * float(num_tasks_box.value) / num_tasks), 100)
|
||||
|
||||
if not update:
|
||||
return
|
||||
@@ -205,23 +198,20 @@ def get_sliders(update):
|
||||
# box values.
|
||||
# (Querying based on the % total amount of time.)
|
||||
if breakdown_opt.value == total_time_value:
|
||||
tasks = _truncated_task_profiles(start=(smallest +
|
||||
diff * low),
|
||||
end=(smallest +
|
||||
diff * high))
|
||||
tasks = _truncated_task_profiles(
|
||||
start=(smallest + diff * low),
|
||||
end=(smallest + diff * high))
|
||||
|
||||
# (Querying based on % of total number of tasks that were
|
||||
# run.)
|
||||
elif breakdown_opt.value == total_tasks_value:
|
||||
if range_slider.value[0] == 0:
|
||||
tasks = _truncated_task_profiles(num_tasks=(int(
|
||||
num_tasks * high)),
|
||||
fwd=True)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=(int(num_tasks * high)), fwd=True)
|
||||
else:
|
||||
tasks = _truncated_task_profiles(num_tasks=(int(
|
||||
num_tasks *
|
||||
(high - low))),
|
||||
fwd=False)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=(int(num_tasks * (high - low))),
|
||||
fwd=False)
|
||||
|
||||
update(smallest, largest, num_tasks, tasks)
|
||||
|
||||
@@ -237,8 +227,8 @@ def get_sliders(update):
|
||||
update_wrapper(INIT_EVENT)
|
||||
|
||||
# Display sliders and search boxes
|
||||
display(breakdown_opt, widgets.HBox([range_slider, total_time_box,
|
||||
num_tasks_box]))
|
||||
display(breakdown_opt,
|
||||
widgets.HBox([range_slider, total_time_box, num_tasks_box]))
|
||||
|
||||
# Return the sliders and text boxes
|
||||
return start_box, end_box, range_slider, breakdown_opt
|
||||
@@ -249,8 +239,7 @@ def object_search_bar():
|
||||
value="",
|
||||
placeholder="Object ID",
|
||||
description="Search for an object:",
|
||||
disabled=False
|
||||
)
|
||||
disabled=False)
|
||||
display(object_search)
|
||||
|
||||
def handle_submit(sender):
|
||||
@@ -265,8 +254,7 @@ def task_search_bar():
|
||||
value="",
|
||||
placeholder="Task ID",
|
||||
description="Search for a task:",
|
||||
disabled=False
|
||||
)
|
||||
disabled=False)
|
||||
display(task_search)
|
||||
|
||||
def handle_submit(sender):
|
||||
@@ -284,14 +272,12 @@ MAX_TASKS_TO_VISUALIZE = 10000
|
||||
def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
|
||||
if num_tasks is None:
|
||||
num_tasks = MAX_TASKS_TO_VISUALIZE
|
||||
print(
|
||||
"Warning: at most {} tasks will be fetched within this "
|
||||
"time range.".format(MAX_TASKS_TO_VISUALIZE))
|
||||
print("Warning: at most {} tasks will be fetched within this "
|
||||
"time range.".format(MAX_TASKS_TO_VISUALIZE))
|
||||
elif num_tasks > MAX_TASKS_TO_VISUALIZE:
|
||||
print(
|
||||
"Warning: too many tasks to visualize, "
|
||||
"fetching only the first {} of {}.".format(
|
||||
MAX_TASKS_TO_VISUALIZE, num_tasks))
|
||||
print("Warning: too many tasks to visualize, "
|
||||
"fetching only the first {} of {}.".format(
|
||||
MAX_TASKS_TO_VISUALIZE, num_tasks))
|
||||
num_tasks = MAX_TASKS_TO_VISUALIZE
|
||||
return ray.global_state.task_profiles(num_tasks, start, end, fwd)
|
||||
|
||||
@@ -299,9 +285,8 @@ def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
|
||||
# Helper function that guarantees unique and writeable temp files.
|
||||
# Prevents clashes in task trace files when multiple notebooks are running.
|
||||
def _get_temp_file_path(**kwargs):
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False,
|
||||
dir=os.getcwd(),
|
||||
**kwargs)
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
delete=False, dir=os.getcwd(), **kwargs)
|
||||
temp_file_path = temp_file.name
|
||||
temp_file.close()
|
||||
return os.path.relpath(temp_file_path)
|
||||
@@ -319,22 +304,16 @@ def task_timeline():
|
||||
disabled=False,
|
||||
)
|
||||
obj_dep = widgets.Checkbox(
|
||||
value=True,
|
||||
disabled=False,
|
||||
layout=widgets.Layout(width='20px')
|
||||
)
|
||||
value=True, disabled=False, layout=widgets.Layout(width='20px'))
|
||||
task_dep = widgets.Checkbox(
|
||||
value=True,
|
||||
disabled=False,
|
||||
layout=widgets.Layout(width='20px')
|
||||
)
|
||||
value=True, disabled=False, layout=widgets.Layout(width='20px'))
|
||||
# Labels to bypass width limitation for descriptions.
|
||||
label_tasks = widgets.Label(value='Task submissions',
|
||||
layout=widgets.Layout(width='110px'))
|
||||
label_objects = widgets.Label(value='Object dependencies',
|
||||
layout=widgets.Layout(width='130px'))
|
||||
label_options = widgets.Label(value='View options:',
|
||||
layout=widgets.Layout(width='100px'))
|
||||
label_tasks = widgets.Label(
|
||||
value='Task submissions', layout=widgets.Layout(width='110px'))
|
||||
label_objects = widgets.Label(
|
||||
value='Object dependencies', layout=widgets.Layout(width='130px'))
|
||||
label_options = widgets.Label(
|
||||
value='View options:', layout=widgets.Layout(width='100px'))
|
||||
start_box, end_box, range_slider, time_opt = get_sliders(False)
|
||||
display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects]))
|
||||
display(widgets.HBox([label_options, breakdown_opt]))
|
||||
@@ -344,8 +323,9 @@ def task_timeline():
|
||||
# current working directory if it is not present.
|
||||
if not os.path.exists("trace_viewer_full.html"):
|
||||
shutil.copy(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"../core/src/catapult_files/trace_viewer_full.html"),
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../core/src/catapult_files/trace_viewer_full.html"),
|
||||
"trace_viewer_full.html")
|
||||
|
||||
def handle_submit(sender):
|
||||
@@ -357,8 +337,8 @@ def task_timeline():
|
||||
elif breakdown_opt.value == breakdown_task:
|
||||
breakdown = True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unexpected breakdown value '{}'".format(breakdown_opt.value))
|
||||
raise ValueError("Unexpected breakdown value '{}'".format(
|
||||
breakdown_opt.value))
|
||||
|
||||
low, high = map(lambda x: x / 100., range_slider.value)
|
||||
|
||||
@@ -366,30 +346,28 @@ def task_timeline():
|
||||
diff = largest - smallest
|
||||
|
||||
if time_opt.value == total_time_value:
|
||||
tasks = _truncated_task_profiles(start=smallest + diff * low,
|
||||
end=smallest + diff * high)
|
||||
tasks = _truncated_task_profiles(
|
||||
start=smallest + diff * low, end=smallest + diff * high)
|
||||
elif time_opt.value == total_tasks_value:
|
||||
if range_slider.value[0] == 0:
|
||||
tasks = _truncated_task_profiles(num_tasks=int(
|
||||
num_tasks * high),
|
||||
fwd=True)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=int(num_tasks * high), fwd=True)
|
||||
else:
|
||||
tasks = _truncated_task_profiles(num_tasks=int(
|
||||
num_tasks * (high - low)),
|
||||
fwd=False)
|
||||
tasks = _truncated_task_profiles(
|
||||
num_tasks=int(num_tasks * (high - low)), fwd=False)
|
||||
else:
|
||||
raise ValueError("Unexpected time value '{}'".format(
|
||||
time_opt.value))
|
||||
time_opt.value))
|
||||
# Write trace to a JSON file
|
||||
print("Collected profiles for {} tasks.".format(len(tasks)))
|
||||
print(
|
||||
"Dumping task profile data to {}, "
|
||||
"this might take a while...".format(json_tmp))
|
||||
ray.global_state.dump_catapult_trace(json_tmp,
|
||||
tasks,
|
||||
breakdowns=breakdown,
|
||||
obj_dep=obj_dep.value,
|
||||
task_dep=task_dep.value)
|
||||
print("Dumping task profile data to {}, "
|
||||
"this might take a while...".format(json_tmp))
|
||||
ray.global_state.dump_catapult_trace(
|
||||
json_tmp,
|
||||
tasks,
|
||||
breakdowns=breakdown,
|
||||
obj_dep=obj_dep.value,
|
||||
task_dep=task_dep.value)
|
||||
print("Opening html file in browser...")
|
||||
|
||||
trace_viewer_path = os.path.join(
|
||||
@@ -415,9 +393,8 @@ def task_timeline():
|
||||
|
||||
# Display the task trace within the Jupyter notebook
|
||||
clear_output(wait=True)
|
||||
print(
|
||||
"To view fullscreen, open chrome://tracing in Google Chrome "
|
||||
"and load `{}`".format(json_tmp))
|
||||
print("To view fullscreen, open chrome://tracing in Google Chrome "
|
||||
"and load `{}`".format(json_tmp))
|
||||
display(IFrame(html_file_path, 900, 800))
|
||||
|
||||
path_input.on_click(handle_submit)
|
||||
@@ -432,36 +409,41 @@ def task_completion_time_distribution():
|
||||
output_notebook(resources=CDN)
|
||||
|
||||
# Create the Bokeh plot
|
||||
p = figure(title="Task Completion Time Distribution",
|
||||
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=(0, 1),
|
||||
y_range=(0, 1))
|
||||
p = figure(
|
||||
title="Task Completion Time Distribution",
|
||||
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=(0, 1),
|
||||
y_range=(0, 1))
|
||||
|
||||
# Create the data source that the plot pulls from
|
||||
source = ColumnDataSource(data={
|
||||
"top": [],
|
||||
"left": [],
|
||||
"right": []
|
||||
})
|
||||
source = ColumnDataSource(data={"top": [], "left": [], "right": []})
|
||||
|
||||
# Plot the histogram rectangles
|
||||
p.quad(top="top", bottom=0, left="left", right="right", source=source,
|
||||
fill_color="#B3B3B3", line_color="#033649")
|
||||
p.quad(
|
||||
top="top",
|
||||
bottom=0,
|
||||
left="left",
|
||||
right="right",
|
||||
source=source,
|
||||
fill_color="#B3B3B3",
|
||||
line_color="#033649")
|
||||
|
||||
# Label the plot axes
|
||||
p.xaxis.axis_label = "Duration in seconds"
|
||||
p.yaxis.axis_label = "Number of tasks"
|
||||
|
||||
handle = show(gridplot(p, ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"), notebook_handle=True)
|
||||
handle = show(
|
||||
gridplot(
|
||||
p,
|
||||
ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"),
|
||||
notebook_handle=True)
|
||||
|
||||
# Function to update the plot
|
||||
def task_completion_time_update(abs_earliest,
|
||||
abs_latest,
|
||||
abs_num_tasks,
|
||||
def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks,
|
||||
tasks):
|
||||
if len(tasks) == 0:
|
||||
return
|
||||
@@ -469,8 +451,8 @@ def task_completion_time_distribution():
|
||||
# Create the distribution to plot
|
||||
distr = []
|
||||
for task_id, data in tasks.items():
|
||||
distr.append(data["store_outputs_end"] -
|
||||
data["get_arguments_start"])
|
||||
distr.append(
|
||||
data["store_outputs_end"] - data["get_arguments_start"])
|
||||
|
||||
# Create a histogram from the distribution
|
||||
top, bin_edges = np.histogram(distr, bins="auto")
|
||||
@@ -480,8 +462,8 @@ def task_completion_time_distribution():
|
||||
source.data = {"top": top, "left": left, "right": right}
|
||||
|
||||
# Set the x and y ranges
|
||||
x_range = (min(left) if len(left) else 0,
|
||||
max(right) if len(right) else 1)
|
||||
x_range = (min(left) if len(left) else 0, max(right)
|
||||
if len(right) else 1)
|
||||
y_range = (0, max(top) + 1 if len(top) else 1)
|
||||
|
||||
x_range = helpers._get_range(x_range)
|
||||
@@ -517,8 +499,7 @@ def compute_utilizations(abs_earliest,
|
||||
latest_time = 0
|
||||
for task_id, data in tasks.items():
|
||||
latest_time = max((latest_time, data["store_outputs_end"]))
|
||||
earliest_time = min((earliest_time,
|
||||
data["get_arguments_start"]))
|
||||
earliest_time = min((earliest_time, data["get_arguments_start"]))
|
||||
|
||||
# Add some epsilon to latest_time to ensure that the end time of the
|
||||
# last task falls __within__ a bucket, and not on the edge
|
||||
@@ -533,37 +514,37 @@ def compute_utilizations(abs_earliest,
|
||||
task_start_time = data["get_arguments_start"]
|
||||
task_end_time = data["store_outputs_end"]
|
||||
|
||||
start_bucket = int((task_start_time - earliest_time) /
|
||||
bucket_time_length)
|
||||
end_bucket = int((task_end_time - earliest_time) /
|
||||
bucket_time_length)
|
||||
start_bucket = int(
|
||||
(task_start_time - earliest_time) / bucket_time_length)
|
||||
end_bucket = int((task_end_time - earliest_time) / bucket_time_length)
|
||||
# Walk over each time bucket that this task intersects, adding the
|
||||
# amount of time that the task intersects within each bucket
|
||||
for bucket_idx in range(start_bucket, end_bucket + 1):
|
||||
bucket_start_time = ((earliest_time + bucket_idx) *
|
||||
bucket_time_length)
|
||||
bucket_end_time = ((earliest_time + (bucket_idx + 1)) *
|
||||
bucket_time_length)
|
||||
bucket_start_time = ((
|
||||
earliest_time + bucket_idx) * bucket_time_length)
|
||||
bucket_end_time = ((earliest_time +
|
||||
(bucket_idx + 1)) * bucket_time_length)
|
||||
|
||||
task_start_time_within_bucket = max(task_start_time,
|
||||
bucket_start_time)
|
||||
task_end_time_within_bucket = min(task_end_time,
|
||||
bucket_end_time)
|
||||
task_cpu_time_within_bucket = (task_end_time_within_bucket -
|
||||
task_start_time_within_bucket)
|
||||
task_end_time_within_bucket = min(task_end_time, bucket_end_time)
|
||||
task_cpu_time_within_bucket = (
|
||||
task_end_time_within_bucket - task_start_time_within_bucket)
|
||||
|
||||
if bucket_idx > -1 and bucket_idx < num_buckets:
|
||||
cpu_time[bucket_idx] += task_cpu_time_within_bucket
|
||||
|
||||
# Cpu_utilization is the average cpu utilization of the bucket, which
|
||||
# is just cpu_time divided by bucket_time_length.
|
||||
cpu_utilization = list(map(lambda x: x / float(bucket_time_length),
|
||||
cpu_time))
|
||||
cpu_utilization = list(
|
||||
map(lambda x: x / float(bucket_time_length), cpu_time))
|
||||
|
||||
# Generate histogram bucket edges. Subtract out abs_earliest to get
|
||||
# relative time.
|
||||
all_edges = [earliest_time - abs_earliest + i * bucket_time_length
|
||||
for i in range(num_buckets + 1)]
|
||||
all_edges = [
|
||||
earliest_time - abs_earliest + i * bucket_time_length
|
||||
for i in range(num_buckets + 1)
|
||||
]
|
||||
# Left edges are all but the rightmost edge, right edges are all but
|
||||
# the leftmost edge.
|
||||
left_edges = all_edges[:-1]
|
||||
@@ -591,54 +572,53 @@ def cpu_usage():
|
||||
# Update the plot based on the sliders
|
||||
def plot_utilization():
|
||||
# Create the Bokeh plot
|
||||
time_series_fig = figure(title="CPU Utilization",
|
||||
tools=["save", "hover", "wheel_zoom",
|
||||
"box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=[0, 1],
|
||||
y_range=[0, 1])
|
||||
time_series_fig = figure(
|
||||
title="CPU Utilization",
|
||||
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
|
||||
background_fill_color="#FFFFFF",
|
||||
x_range=[0, 1],
|
||||
y_range=[0, 1])
|
||||
|
||||
# Create the data source that the plot will pull from
|
||||
time_series_source = ColumnDataSource(data=dict(
|
||||
left=[],
|
||||
right=[],
|
||||
top=[]
|
||||
))
|
||||
time_series_source = ColumnDataSource(
|
||||
data=dict(left=[], right=[], top=[]))
|
||||
|
||||
# Plot the rectangles representing the distribution
|
||||
time_series_fig.quad(left="left",
|
||||
right="right",
|
||||
top="top",
|
||||
bottom=0,
|
||||
source=time_series_source,
|
||||
fill_color="#B3B3B3",
|
||||
line_color="#033649")
|
||||
time_series_fig.quad(
|
||||
left="left",
|
||||
right="right",
|
||||
top="top",
|
||||
bottom=0,
|
||||
source=time_series_source,
|
||||
fill_color="#B3B3B3",
|
||||
line_color="#033649")
|
||||
|
||||
# Label the plot axes
|
||||
time_series_fig.xaxis.axis_label = "Time in seconds"
|
||||
time_series_fig.yaxis.axis_label = "Number of CPUs used"
|
||||
|
||||
handle = show(gridplot(time_series_fig,
|
||||
ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"), notebook_handle=True)
|
||||
handle = show(
|
||||
gridplot(
|
||||
time_series_fig,
|
||||
ncols=1,
|
||||
plot_width=500,
|
||||
plot_height=500,
|
||||
toolbar_location="below"),
|
||||
notebook_handle=True)
|
||||
|
||||
def update_plot(abs_earliest, abs_latest, abs_num_tasks, tasks):
|
||||
num_buckets = 100
|
||||
left, right, top = compute_utilizations(abs_earliest,
|
||||
abs_latest,
|
||||
abs_num_tasks,
|
||||
tasks,
|
||||
num_buckets)
|
||||
left, right, top = compute_utilizations(
|
||||
abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets)
|
||||
|
||||
time_series_source.data = {"left": left,
|
||||
"right": right,
|
||||
"top": top}
|
||||
time_series_source.data = {
|
||||
"left": left,
|
||||
"right": right,
|
||||
"top": top
|
||||
}
|
||||
|
||||
x_range = (max(0, min(left))
|
||||
if len(left) else 0,
|
||||
max(right) if len(right) else 1)
|
||||
x_range = (max(0, min(left)) if len(left) else 0, max(right)
|
||||
if len(right) else 1)
|
||||
y_range = (0, max(top) + 1 if len(top) else 1)
|
||||
|
||||
# Define the axis ranges
|
||||
@@ -654,6 +634,7 @@ def cpu_usage():
|
||||
push_notebook(handle=handle)
|
||||
|
||||
get_sliders(update_plot)
|
||||
|
||||
plot_utilization()
|
||||
|
||||
|
||||
@@ -672,33 +653,32 @@ def cluster_usage():
|
||||
output_notebook(resources=CDN)
|
||||
|
||||
# Initial values
|
||||
source = ColumnDataSource(data={"node_ip_address": ['127.0.0.1'],
|
||||
"time": ['0.5'],
|
||||
"num_tasks": ['1'],
|
||||
"length": [1]})
|
||||
source = ColumnDataSource(
|
||||
data={
|
||||
"node_ip_address": ['127.0.0.1'],
|
||||
"time": ['0.5'],
|
||||
"num_tasks": ['1'],
|
||||
"length": [1]
|
||||
})
|
||||
|
||||
# Define the color schema
|
||||
colors = ["#75968f",
|
||||
"#a5bab7",
|
||||
"#c9d9d3",
|
||||
"#e2e2e2",
|
||||
"#dfccce",
|
||||
"#ddb7b1",
|
||||
"#cc7878",
|
||||
"#933b41",
|
||||
"#550b1d"]
|
||||
colors = [
|
||||
"#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1",
|
||||
"#cc7878", "#933b41", "#550b1d"
|
||||
]
|
||||
mapper = LinearColorMapper(palette=colors, low=0, high=2)
|
||||
|
||||
TOOLS = "hover, save, xpan, box_zoom, reset, xwheel_zoom"
|
||||
|
||||
# Create the plot
|
||||
p = figure(title="Cluster Usage",
|
||||
y_range=list(set(source.data['node_ip_address'])),
|
||||
x_axis_location="above",
|
||||
plot_width=900,
|
||||
plot_height=500,
|
||||
tools=TOOLS,
|
||||
toolbar_location='below')
|
||||
p = figure(
|
||||
title="Cluster Usage",
|
||||
y_range=list(set(source.data['node_ip_address'])),
|
||||
x_axis_location="above",
|
||||
plot_width=900,
|
||||
plot_height=500,
|
||||
tools=TOOLS,
|
||||
toolbar_location='below')
|
||||
|
||||
# Format the plot axes
|
||||
p.grid.grid_line_color = None
|
||||
@@ -709,26 +689,33 @@ def cluster_usage():
|
||||
p.xaxis.major_label_orientation = np.pi / 3
|
||||
|
||||
# Plot rectangles
|
||||
p.rect(x="time", y="node_ip_address", width="length", height=1,
|
||||
source=source,
|
||||
fill_color={"field": "num_tasks", "transform": mapper},
|
||||
line_color=None)
|
||||
p.rect(
|
||||
x="time",
|
||||
y="node_ip_address",
|
||||
width="length",
|
||||
height=1,
|
||||
source=source,
|
||||
fill_color={
|
||||
"field": "num_tasks",
|
||||
"transform": mapper
|
||||
},
|
||||
line_color=None)
|
||||
|
||||
# Add legend to the side of the plot
|
||||
color_bar = ColorBar(color_mapper=mapper,
|
||||
major_label_text_font_size="8pt",
|
||||
ticker=BasicTicker(desired_num_ticks=len(colors)),
|
||||
label_standoff=6,
|
||||
border_line_color=None,
|
||||
location=(0, 0))
|
||||
color_bar = ColorBar(
|
||||
color_mapper=mapper,
|
||||
major_label_text_font_size="8pt",
|
||||
ticker=BasicTicker(desired_num_ticks=len(colors)),
|
||||
label_standoff=6,
|
||||
border_line_color=None,
|
||||
location=(0, 0))
|
||||
p.add_layout(color_bar, "right")
|
||||
|
||||
# Define hover tool
|
||||
p.select_one(HoverTool).tooltips = [
|
||||
("Node IP Address", "@node_ip_address"),
|
||||
("Number of tasks running", "@num_tasks"),
|
||||
("Time", "@time")
|
||||
]
|
||||
p.select_one(HoverTool).tooltips = [("Node IP Address",
|
||||
"@node_ip_address"),
|
||||
("Number of tasks running",
|
||||
"@num_tasks"), ("Time", "@time")]
|
||||
|
||||
# Define the axis labels
|
||||
p.xaxis.axis_label = "Time in seconds"
|
||||
@@ -764,12 +751,8 @@ def cluster_usage():
|
||||
num_tasks = []
|
||||
|
||||
for node_ip, task_dict in node_to_tasks.items():
|
||||
left, right, top = compute_utilizations(earliest,
|
||||
latest,
|
||||
abs_num_tasks,
|
||||
task_dict,
|
||||
100,
|
||||
True)
|
||||
left, right, top = compute_utilizations(
|
||||
earliest, latest, abs_num_tasks, task_dict, 100, True)
|
||||
for (l, r, t) in zip(left, right, top):
|
||||
nodes.append(node_ip)
|
||||
times.append((l + r) / 2)
|
||||
@@ -783,10 +766,12 @@ def cluster_usage():
|
||||
mapper.high = max(max(num_tasks), 1)
|
||||
|
||||
# Update plot with new data based on slider and text box values
|
||||
source.data = {"node_ip_address": nodes,
|
||||
"time": times,
|
||||
"num_tasks": num_tasks,
|
||||
"length": lengths}
|
||||
source.data = {
|
||||
"node_ip_address": nodes,
|
||||
"time": times,
|
||||
"num_tasks": num_tasks,
|
||||
"length": lengths
|
||||
}
|
||||
|
||||
push_notebook(handle=handle)
|
||||
|
||||
|
||||
@@ -7,9 +7,12 @@ import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
def start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -33,21 +36,24 @@ def start_global_scheduler(redis_address, node_ip_address,
|
||||
global_scheduler_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/global_scheduler/global_scheduler")
|
||||
command = [global_scheduler_executable,
|
||||
"-r", redis_address,
|
||||
"-h", node_ip_address]
|
||||
command = [
|
||||
global_scheduler_executable, "-r", redis_address, "-h", node_ip_address
|
||||
]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
||||
@@ -56,7 +56,6 @@ def new_port():
|
||||
|
||||
|
||||
class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
@@ -164,17 +163,19 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
return db_client_id
|
||||
|
||||
def test_task_default_resources(self):
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0)
|
||||
task1 = local_scheduler.Task(
|
||||
random_driver_id(), random_function_id(), [random_object_id()], 0,
|
||||
random_task_id(), 0)
|
||||
self.assertEqual(task1.required_resources(), {"CPU": 1})
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_OBJECT_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
0, 0, [], {"CPU": 1, "GPU": 2})
|
||||
task2 = local_scheduler.Task(
|
||||
random_driver_id(), random_function_id(), [random_object_id()], 0,
|
||||
random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_OBJECT_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], {
|
||||
"CPU": 1,
|
||||
"GPU": 2
|
||||
})
|
||||
self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2})
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
@@ -189,7 +190,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
assert (db_client_id is not None)
|
||||
|
||||
@unittest.skipIf(
|
||||
os.environ.get('RAY_USE_NEW_GCS', False),
|
||||
@@ -227,9 +228,10 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
if len(task_entries) == 1:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
self.assertTrue(task_status in [
|
||||
state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED
|
||||
])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
else:
|
||||
@@ -258,17 +260,14 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
data_size = np.random.randint(1 << 12)
|
||||
metadata_size = np.random.randint(1 << 9)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True)
|
||||
object_dep, memory_buffer, metadata = create_object(
|
||||
plasma_client, data_size, metadata_size, seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to
|
||||
# yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = local_scheduler.Task(
|
||||
random_driver_id(),
|
||||
random_function_id(),
|
||||
random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep.binary())],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
@@ -281,12 +280,18 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
task_statuses = [
|
||||
task_entry["State"]
|
||||
for task_entry in task_entries.values()
|
||||
]
|
||||
self.assertTrue(
|
||||
all([
|
||||
status in [
|
||||
state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED
|
||||
] for status in task_statuses
|
||||
]))
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(
|
||||
state.TASK_STATUS_SCHEDULED)
|
||||
@@ -294,12 +299,13 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, "
|
||||
"tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done,
|
||||
num_retries))
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
"tasks queued = {}, retries left = {}".format(
|
||||
len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([
|
||||
status == state.TASK_STATUS_QUEUED
|
||||
for status in task_statuses
|
||||
]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
|
||||
@@ -7,6 +7,8 @@ from ray.core.src.local_scheduler.liblocal_scheduler_library import (
|
||||
task_to_string, _config, common_error)
|
||||
from .local_scheduler_services import start_local_scheduler
|
||||
|
||||
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
|
||||
"task_from_string", "task_to_string", "start_local_scheduler",
|
||||
"_config", "common_error"]
|
||||
__all__ = [
|
||||
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
|
||||
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
|
||||
"common_error"
|
||||
]
|
||||
|
||||
@@ -68,15 +68,15 @@ def start_local_scheduler(plasma_store_name,
|
||||
"provided.")
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
local_scheduler_executable = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
|
||||
command = [local_scheduler_executable,
|
||||
"-s", local_scheduler_name,
|
||||
"-p", plasma_store_name,
|
||||
"-h", node_ip_address,
|
||||
"-n", str(num_workers)]
|
||||
command = [
|
||||
local_scheduler_executable, "-s", local_scheduler_name, "-p",
|
||||
plasma_store_name, "-h", node_ip_address, "-n",
|
||||
str(num_workers)
|
||||
]
|
||||
if plasma_manager_name is not None:
|
||||
command += ["-m", plasma_manager_name]
|
||||
if worker_path is not None:
|
||||
@@ -88,14 +88,11 @@ def start_local_scheduler(plasma_store_name,
|
||||
"--object-store-name={} "
|
||||
"--object-store-manager-name={} "
|
||||
"--local-scheduler-name={} "
|
||||
"--redis-address={}"
|
||||
.format(sys.executable,
|
||||
worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
plasma_manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address))
|
||||
"--redis-address={}".format(
|
||||
sys.executable, worker_path,
|
||||
node_ip_address, plasma_store_name,
|
||||
plasma_manager_name, local_scheduler_name,
|
||||
redis_address))
|
||||
command += ["-w", start_worker_command]
|
||||
if redis_address is not None:
|
||||
command += ["-r", redis_address]
|
||||
@@ -104,27 +101,31 @@ def start_local_scheduler(plasma_store_name,
|
||||
if static_resources is not None:
|
||||
resource_argument = ""
|
||||
for resource_name, resource_quantity in static_resources.items():
|
||||
assert (isinstance(resource_quantity, int) or
|
||||
isinstance(resource_quantity, float))
|
||||
resource_argument = ",".join(
|
||||
[resource_name + "," + str(resource_quantity)
|
||||
for resource_name, resource_quantity in static_resources.items()])
|
||||
assert (isinstance(resource_quantity, int)
|
||||
or isinstance(resource_quantity, float))
|
||||
resource_argument = ",".join([
|
||||
resource_name + "," + str(resource_quantity)
|
||||
for resource_name, resource_quantity in static_resources.items()
|
||||
])
|
||||
else:
|
||||
resource_argument = "CPU,{}".format(psutil.cpu_count())
|
||||
command += ["-c", resource_argument]
|
||||
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
||||
@@ -37,7 +37,6 @@ def random_function_id():
|
||||
|
||||
|
||||
class TestLocalSchedulerClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start Plasma store.
|
||||
plasma_store_name, self.p1 = plasma.start_plasma_store()
|
||||
@@ -74,34 +73,17 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
|
||||
self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
|
||||
# Define some arguments to use for the tasks.
|
||||
args_list = [
|
||||
[],
|
||||
[{}],
|
||||
[()],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1],
|
||||
1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [
|
||||
1, 1.3, 1 << 100, "hi", u"hi", [1, 2]
|
||||
], object_ids[:1], object_ids[:2], object_ids[:3],
|
||||
object_ids[:4], object_ids[:5], object_ids[:10],
|
||||
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
|
||||
object_ids[0], "a"
|
||||
], [1, object_ids[0], "a"], [
|
||||
object_ids[0], 1, object_ids[1], "a"
|
||||
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
@@ -146,6 +128,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
# Sleep to give the thread time to call get_task.
|
||||
@@ -170,6 +153,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
|
||||
|
||||
+21
-14
@@ -26,11 +26,12 @@ class LogMonitor(object):
|
||||
log_file_handles: A dictionary mapping the name of a log file to a file
|
||||
handle for that file.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_ip_address, redis_port, node_ip_address):
|
||||
"""Initialize the log monitor object."""
|
||||
self.node_ip_address = node_ip_address
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
self.log_files = {}
|
||||
self.log_file_handles = {}
|
||||
self.files_to_ignore = set()
|
||||
@@ -38,9 +39,8 @@ class LogMonitor(object):
|
||||
def update_log_filenames(self):
|
||||
"""Get the most up-to-date list of log files to monitor from Redis."""
|
||||
num_current_log_files = len(self.log_files)
|
||||
new_log_filenames = self.redis_client.lrange(
|
||||
"LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(
|
||||
self.node_ip_address), num_current_log_files, -1)
|
||||
for log_filename in new_log_filenames:
|
||||
print("Beginning to track file {}".format(log_filename))
|
||||
assert log_filename not in self.log_files
|
||||
@@ -78,8 +78,8 @@ class LogMonitor(object):
|
||||
# Try to open this file for the first time.
|
||||
else:
|
||||
try:
|
||||
self.log_file_handles[log_filename] = open(log_filename,
|
||||
"r")
|
||||
self.log_file_handles[log_filename] = open(
|
||||
log_filename, "r")
|
||||
except IOError as e:
|
||||
if e.errno == os.errno.EMFILE:
|
||||
print("Warning: Ignoring {} because there are too "
|
||||
@@ -106,13 +106,20 @@ class LogMonitor(object):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"log monitor to connect "
|
||||
"to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="The address to use for Redis.")
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
help="The IP address of the node this process is on.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
"log monitor to connect "
|
||||
"to."))
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The address to use for Redis.")
|
||||
parser.add_argument(
|
||||
"--node-ip-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The IP address of the node this process is on.")
|
||||
args = parser.parse_args()
|
||||
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
|
||||
+9
-10
@@ -100,8 +100,8 @@ class Monitor(object):
|
||||
self.local_scheduler_id_to_ip_map = dict()
|
||||
self.load_metrics = LoadMetrics()
|
||||
if autoscaling_config:
|
||||
self.autoscaler = StandardAutoscaler(
|
||||
autoscaling_config, self.load_metrics)
|
||||
self.autoscaler = StandardAutoscaler(autoscaling_config,
|
||||
self.load_metrics)
|
||||
else:
|
||||
self.autoscaler = None
|
||||
|
||||
@@ -160,11 +160,9 @@ class Monitor(object):
|
||||
# task as lost.
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE",
|
||||
hex_to_binary(task_id),
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
|
||||
task["ExecutionDependenciesString"],
|
||||
task["SpillbackCount"])
|
||||
task["ExecutionDependenciesString"], task["SpillbackCount"])
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
@@ -428,8 +426,8 @@ class Monitor(object):
|
||||
"""
|
||||
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
|
||||
driver_id = message.DriverId()
|
||||
log.info(
|
||||
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
|
||||
log.info("Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
|
||||
self._clean_up_entries_for_driver(driver_id)
|
||||
|
||||
@@ -571,8 +569,9 @@ class Monitor(object):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
|
||||
@@ -5,5 +5,6 @@ from __future__ import print_function
|
||||
from ray.plasma.plasma import (start_plasma_store, start_plasma_manager,
|
||||
DEFAULT_PLASMA_STORE_MEMORY)
|
||||
|
||||
__all__ = ["start_plasma_store", "start_plasma_manager",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
__all__ = [
|
||||
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
|
||||
]
|
||||
|
||||
+65
-46
@@ -8,13 +8,13 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
__all__ = [
|
||||
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
|
||||
]
|
||||
|
||||
__all__ = ["start_plasma_store", "start_plasma_manager",
|
||||
"DEFAULT_PLASMA_STORE_MEMORY"]
|
||||
PLASMA_WAIT_TIMEOUT = 2**30
|
||||
|
||||
PLASMA_WAIT_TIMEOUT = 2 ** 30
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10**9
|
||||
|
||||
|
||||
def random_name():
|
||||
@@ -22,9 +22,12 @@ def random_name():
|
||||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None,
|
||||
plasma_directory=None, huge_pages=False):
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False):
|
||||
"""Start a plasma store process.
|
||||
|
||||
Args:
|
||||
@@ -48,8 +51,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
|
||||
if huge_pages and not (sys.platform == "linux" or
|
||||
sys.platform == "linux2"):
|
||||
if huge_pages and not (sys.platform == "linux"
|
||||
or sys.platform == "linux2"):
|
||||
raise Exception("The huge_pages argument is only supported on "
|
||||
"Linux.")
|
||||
|
||||
@@ -57,29 +60,33 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
raise Exception("If huge_pages is True, then the "
|
||||
"plasma_directory argument must be provided.")
|
||||
|
||||
plasma_store_executable = os.path.join(os.path.abspath(
|
||||
os.path.dirname(__file__)),
|
||||
plasma_store_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_store")
|
||||
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
|
||||
command = [plasma_store_executable,
|
||||
"-s", plasma_store_name,
|
||||
"-m", str(plasma_store_memory)]
|
||||
command = [
|
||||
plasma_store_executable, "-s", plasma_store_name, "-m",
|
||||
str(plasma_store_memory)
|
||||
]
|
||||
if plasma_directory is not None:
|
||||
command += ["-d", plasma_directory]
|
||||
if huge_pages:
|
||||
command += ["-h"]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
@@ -91,10 +98,14 @@ def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
|
||||
def start_plasma_manager(store_name, redis_address,
|
||||
node_ip_address="127.0.0.1", plasma_manager_port=None,
|
||||
num_retries=20, use_valgrind=False,
|
||||
run_profiler=False, stdout_file=None,
|
||||
def start_plasma_manager(store_name,
|
||||
redis_address,
|
||||
node_ip_address="127.0.0.1",
|
||||
plasma_manager_port=None,
|
||||
num_retries=20,
|
||||
use_valgrind=False,
|
||||
run_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
|
||||
@@ -133,27 +144,35 @@ def start_plasma_manager(store_name, redis_address,
|
||||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Plasma manager failed to start, retrying now.")
|
||||
command = [plasma_manager_executable,
|
||||
"-s", store_name,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address,
|
||||
]
|
||||
command = [
|
||||
plasma_manager_executable,
|
||||
"-s",
|
||||
store_name,
|
||||
"-m",
|
||||
plasma_manager_name,
|
||||
"-h",
|
||||
node_ip_address,
|
||||
"-p",
|
||||
str(plasma_manager_port),
|
||||
"-r",
|
||||
redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
process = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
|
||||
command),
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
process = subprocess.Popen(
|
||||
(["valgrind", "--tool=callgrind"] + command),
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
else:
|
||||
process = subprocess.Popen(command, stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
process = subprocess.Popen(
|
||||
command, stdout=stdout_file, stderr=stderr_file)
|
||||
# This sleep is critical. If the plasma_manager fails to start because
|
||||
# the port is already in use, then we need it to fail within 0.1
|
||||
# seconds.
|
||||
|
||||
+143
-107
@@ -16,8 +16,8 @@ import unittest
|
||||
# The ray import must come before the pyarrow import because ray modifies the
|
||||
# python path so that the right version of pyarrow is found.
|
||||
import ray
|
||||
from ray.plasma.utils import (random_object_id,
|
||||
create_object_with_id, create_object)
|
||||
from ray.plasma.utils import (random_object_id, create_object_with_id,
|
||||
create_object)
|
||||
from ray import services
|
||||
import pyarrow as pa
|
||||
import pyarrow.plasma as plasma
|
||||
@@ -30,8 +30,12 @@ def random_name():
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def assert_get_object_equal(unit_test, client1, client2, object_id,
|
||||
memory_buffer=None, metadata=None):
|
||||
def assert_get_object_equal(unit_test,
|
||||
client1,
|
||||
client2,
|
||||
object_id,
|
||||
memory_buffer=None,
|
||||
metadata=None):
|
||||
client1_buff = client1.get_buffers([object_id])[0]
|
||||
client2_buff = client2.get_buffers([object_id])[0]
|
||||
client1_metadata = client1.get_metadata([object_id])[0]
|
||||
@@ -39,27 +43,33 @@ def assert_get_object_equal(unit_test, client1, client2, object_id,
|
||||
unit_test.assertEqual(len(client1_buff), len(client2_buff))
|
||||
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
|
||||
# Check that the buffers from the two clients are the same.
|
||||
assert_equal(np.frombuffer(client1_buff, dtype="uint8"),
|
||||
np.frombuffer(client2_buff, dtype="uint8"))
|
||||
assert_equal(
|
||||
np.frombuffer(client1_buff, dtype="uint8"),
|
||||
np.frombuffer(client2_buff, dtype="uint8"))
|
||||
# Check that the metadata buffers from the two clients are the same.
|
||||
assert_equal(np.frombuffer(client1_metadata, dtype="uint8"),
|
||||
np.frombuffer(client2_metadata, dtype="uint8"))
|
||||
assert_equal(
|
||||
np.frombuffer(client1_metadata, dtype="uint8"),
|
||||
np.frombuffer(client2_metadata, dtype="uint8"))
|
||||
# If a reference buffer was provided, check that it is the same as well.
|
||||
if memory_buffer is not None:
|
||||
assert_equal(np.frombuffer(memory_buffer, dtype="uint8"),
|
||||
np.frombuffer(client1_buff, dtype="uint8"))
|
||||
assert_equal(
|
||||
np.frombuffer(memory_buffer, dtype="uint8"),
|
||||
np.frombuffer(client1_buff, dtype="uint8"))
|
||||
# If reference metadata was provided, check that it is the same as well.
|
||||
if metadata is not None:
|
||||
assert_equal(np.frombuffer(metadata, dtype="uint8"),
|
||||
np.frombuffer(client1_metadata, dtype="uint8"))
|
||||
assert_equal(
|
||||
np.frombuffer(metadata, dtype="uint8"),
|
||||
np.frombuffer(client1_metadata, dtype="uint8"))
|
||||
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10**9
|
||||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
use_valgrind=False,
|
||||
use_profiler=False,
|
||||
stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma store process.
|
||||
Args:
|
||||
use_valgrind (bool): True if the plasma store should be started inside
|
||||
@@ -78,21 +88,25 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
plasma_store_executable = os.path.join(pa.__path__[0], "plasma_store")
|
||||
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
|
||||
command = [plasma_store_executable,
|
||||
"-s", plasma_store_name,
|
||||
"-m", str(plasma_store_memory)]
|
||||
command = [
|
||||
plasma_store_executable, "-s", plasma_store_name, "-m",
|
||||
str(plasma_store_memory)
|
||||
]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
[
|
||||
"valgrind", "--track-origins=yes", "--leak-check=full",
|
||||
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
|
||||
"--error-exitcode=1"
|
||||
] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
pid = subprocess.Popen(
|
||||
["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
@@ -104,13 +118,10 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
|
||||
|
||||
|
||||
class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start two PlasmaStores.
|
||||
store_name1, self.p2 = start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = start_plasma_store(
|
||||
use_valgrind=USE_VALGRIND)
|
||||
store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND)
|
||||
# Start a Redis server.
|
||||
redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start two PlasmaManagers.
|
||||
@@ -152,9 +163,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
def test_fetch(self):
|
||||
for _ in range(10):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
self.client1.fetch([object_id1])
|
||||
self.assertEqual(self.client1.contains(object_id1), True)
|
||||
self.assertEqual(self.client2.contains(object_id1), False)
|
||||
@@ -164,18 +174,20 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
while not self.client2.contains(object_id1):
|
||||
self.client2.fetch([object_id1])
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
|
||||
# Test that we can call fetch on object IDs that don't exist yet.
|
||||
object_id2 = random_object_id()
|
||||
self.client1.fetch([object_id2])
|
||||
self.assertEqual(self.client1.contains(object_id2), False)
|
||||
memory_buffer2, metadata2 = create_object_with_id(self.client2,
|
||||
object_id2,
|
||||
2000, 2000)
|
||||
memory_buffer2, metadata2 = create_object_with_id(
|
||||
self.client2, object_id2, 2000, 2000)
|
||||
# # Check that the object has been fetched.
|
||||
# self.assertEqual(self.client1.contains(object_id2), True)
|
||||
# Compare the two buffers.
|
||||
@@ -190,68 +202,88 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
memory_buffer3, metadata3 = create_object_with_id(self.client1,
|
||||
object_id3,
|
||||
2000, 2000)
|
||||
memory_buffer3, metadata3 = create_object_with_id(
|
||||
self.client1, object_id3, 2000, 2000)
|
||||
for _ in range(10):
|
||||
self.client1.fetch([object_id3])
|
||||
self.client2.fetch([object_id3])
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id3):
|
||||
self.client2.fetch([object_id3])
|
||||
assert_get_object_equal(self, self.client1, self.client2, object_id3,
|
||||
memory_buffer=memory_buffer3,
|
||||
metadata=metadata3)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id3,
|
||||
memory_buffer=memory_buffer3,
|
||||
metadata=metadata3)
|
||||
|
||||
def test_fetch_multiple(self):
|
||||
for _ in range(20):
|
||||
# Create two objects and a third fake one that doesn't exist.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
missing_object_id = random_object_id()
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
object_ids = [object_id1, missing_object_id, object_id2]
|
||||
# Fetch the objects from the other plasma store. The second object
|
||||
# ID should timeout since it does not exist.
|
||||
# TODO(rkn): Right now we must wait for the object table to be
|
||||
# updated.
|
||||
while ((not self.client2.contains(object_id1)) or
|
||||
(not self.client2.contains(object_id2))):
|
||||
while ((not self.client2.contains(object_id1))
|
||||
or (not self.client2.contains(object_id2))):
|
||||
self.client2.fetch(object_ids)
|
||||
# Compare the buffers of the objects that do exist.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
# Fetch in the other direction. The fake object still does not
|
||||
# exist.
|
||||
self.client1.fetch(object_ids)
|
||||
assert_get_object_equal(self, self.client2, self.client1,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(self, self.client2, self.client1,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
|
||||
# Check that we can call fetch with duplicated object IDs.
|
||||
object_id3 = random_object_id()
|
||||
self.client1.fetch([object_id3, object_id3])
|
||||
object_id4, memory_buffer4, metadata4 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id4, memory_buffer4, metadata4 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
time.sleep(0.1)
|
||||
# TODO(rkn): Right now we must wait for the object table to be updated.
|
||||
while not self.client2.contains(object_id4):
|
||||
self.client2.fetch([object_id3, object_id3, object_id4,
|
||||
object_id4])
|
||||
assert_get_object_equal(self, self.client2, self.client1, object_id4,
|
||||
memory_buffer=memory_buffer4,
|
||||
metadata=metadata4)
|
||||
self.client2.fetch(
|
||||
[object_id3, object_id3, object_id4, object_id4])
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client2,
|
||||
self.client1,
|
||||
object_id4,
|
||||
memory_buffer=memory_buffer4,
|
||||
metadata=metadata4)
|
||||
|
||||
def test_wait(self):
|
||||
# Test timeout.
|
||||
@@ -263,8 +295,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
obj_id1 = random_object_id()
|
||||
self.client1.create(obj_id1, 1000)
|
||||
self.client1.seal(obj_id1)
|
||||
ready, waiting = self.client1.wait([obj_id1], timeout=100,
|
||||
num_returns=1)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id1], timeout=100, num_returns=1)
|
||||
self.assertEqual(set(ready), set([obj_id1]))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
@@ -273,8 +305,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
obj_id2 = random_object_id()
|
||||
self.client1.create(obj_id2, 1000)
|
||||
# Don't seal.
|
||||
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
|
||||
num_returns=1)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id2, obj_id1], timeout=100, num_returns=1)
|
||||
self.assertEqual(set(ready), set([obj_id1]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
@@ -287,8 +319,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
t = threading.Timer(0.1, finish)
|
||||
t.start()
|
||||
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
|
||||
timeout=1000, num_returns=2)
|
||||
ready, waiting = self.client1.wait(
|
||||
[obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
|
||||
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
|
||||
self.assertEqual(set(waiting), set([obj_id2]))
|
||||
|
||||
@@ -319,26 +351,26 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
waiting = object_ids
|
||||
retrieved = []
|
||||
for i in range(1, n + 1):
|
||||
ready, waiting = self.client1.wait(waiting, timeout=1000,
|
||||
num_returns=i)
|
||||
ready, waiting = self.client1.wait(
|
||||
waiting, timeout=1000, num_returns=i)
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client1.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
ready, waiting = self.client1.wait(
|
||||
object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
# Try waiting for all of the object IDs on the second client.
|
||||
waiting = object_ids
|
||||
retrieved = []
|
||||
for i in range(1, n + 1):
|
||||
ready, waiting = self.client2.wait(waiting, timeout=1000,
|
||||
num_returns=i)
|
||||
ready, waiting = self.client2.wait(
|
||||
waiting, timeout=1000, num_returns=i)
|
||||
self.assertEqual(len(ready), i)
|
||||
retrieved += ready
|
||||
self.assertEqual(set(retrieved), set(object_ids))
|
||||
ready, waiting = self.client2.wait(object_ids, timeout=1000,
|
||||
num_returns=len(object_ids))
|
||||
ready, waiting = self.client2.wait(
|
||||
object_ids, timeout=1000, num_returns=len(object_ids))
|
||||
self.assertEqual(set(ready), set(object_ids))
|
||||
self.assertEqual(waiting, [])
|
||||
|
||||
@@ -363,9 +395,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
num_attempts = 100
|
||||
for _ in range(100):
|
||||
# Create an object.
|
||||
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
|
||||
2000,
|
||||
2000)
|
||||
object_id1, memory_buffer1, metadata1 = create_object(
|
||||
self.client1, 2000, 2000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a
|
||||
# race condition on the create and transfer of the object, so keep
|
||||
# trying until the object appears on the second Plasma store.
|
||||
@@ -379,9 +410,13 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
del buff
|
||||
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id1, memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id1,
|
||||
memory_buffer=memory_buffer1,
|
||||
metadata=metadata1)
|
||||
# # Transfer the buffer again.
|
||||
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
|
||||
# # Compare the two buffers.
|
||||
@@ -391,8 +426,8 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
# metadata=metadata1)
|
||||
|
||||
# Create an object.
|
||||
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
|
||||
20000, 20000)
|
||||
object_id2, memory_buffer2, metadata2 = create_object(
|
||||
self.client2, 20000, 20000)
|
||||
# Transfer the buffer to the the other Plasma store. There is a
|
||||
# race condition on the create and transfer of the object, so keep
|
||||
# trying until the object appears on the second Plasma store.
|
||||
@@ -406,9 +441,13 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
del buff
|
||||
|
||||
# Compare the two buffers.
|
||||
assert_get_object_equal(self, self.client1, self.client2,
|
||||
object_id2, memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
assert_get_object_equal(
|
||||
self,
|
||||
self.client1,
|
||||
self.client2,
|
||||
object_id2,
|
||||
memory_buffer=memory_buffer2,
|
||||
metadata=metadata2)
|
||||
|
||||
def test_illegal_functionality(self):
|
||||
# Create an object id string.
|
||||
@@ -437,7 +476,6 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
|
||||
|
||||
class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start a Plasma store.
|
||||
self.store_name, self.p2 = start_plasma_store(
|
||||
@@ -446,9 +484,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
self.redis_address, _ = services.start_redis("127.0.0.1")
|
||||
# Start a PlasmaManagers.
|
||||
manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager(
|
||||
self.store_name,
|
||||
self.redis_address,
|
||||
use_valgrind=USE_VALGRIND)
|
||||
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
|
||||
# Connect a PlasmaClient.
|
||||
self.client = plasma.connect(self.store_name, manager_name, 64)
|
||||
|
||||
@@ -501,8 +537,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
|
||||
client2 = plasma.connect(self.store_name, manager_name, 64)
|
||||
ready, waiting = [], object_ids
|
||||
while True:
|
||||
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
|
||||
timeout=0)
|
||||
ready, waiting = client2.wait(
|
||||
object_ids, num_returns=num_objects, timeout=0)
|
||||
if len(ready) == len(object_ids):
|
||||
break
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ def generate_metadata(length):
|
||||
metadata_buffer[0] = random.randint(0, 255)
|
||||
metadata_buffer[-1] = random.randint(0, 255)
|
||||
for _ in range(100):
|
||||
metadata_buffer[random.randint(0, length - 1)] = (
|
||||
random.randint(0, 255))
|
||||
metadata_buffer[random.randint(0, length - 1)] = (random.randint(
|
||||
0, 255))
|
||||
return metadata_buffer
|
||||
|
||||
|
||||
@@ -32,7 +32,10 @@ def write_to_data_buffer(buff, length):
|
||||
array[random.randint(0, length - 1)] = random.randint(0, 255)
|
||||
|
||||
|
||||
def create_object_with_id(client, object_id, data_size, metadata_size,
|
||||
def create_object_with_id(client,
|
||||
object_id,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True):
|
||||
metadata = generate_metadata(metadata_size)
|
||||
memory_buffer = client.create(object_id, data_size, metadata)
|
||||
@@ -44,7 +47,6 @@ def create_object_with_id(client, object_id, data_size, metadata_size,
|
||||
|
||||
def create_object(client, data_size, metadata_size, seal=True):
|
||||
object_id = random_object_id()
|
||||
memory_buffer, metadata = create_object_with_id(client, object_id,
|
||||
data_size, metadata_size,
|
||||
seal=seal)
|
||||
memory_buffer, metadata = create_object_with_id(
|
||||
client, object_id, data_size, metadata_size, seal=seal)
|
||||
return object_id, memory_buffer, metadata
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
"""Ray constants used in the Python code."""
|
||||
|
||||
|
||||
# Abort autoscaling if more than this number of errors are encountered. This
|
||||
# is a safety feature to prevent e.g. runaway node launches.
|
||||
AUTOSCALER_MAX_NUM_FAILURES = 5
|
||||
|
||||
+172
-81
@@ -7,8 +7,8 @@ import json
|
||||
import subprocess
|
||||
|
||||
import ray.services as services
|
||||
from ray.autoscaler.commands import (
|
||||
create_or_update_cluster, teardown_cluster, get_head_node_ip)
|
||||
from ray.autoscaler.commands import (create_or_update_cluster,
|
||||
teardown_cluster, get_head_node_ip)
|
||||
|
||||
|
||||
def check_no_existing_redis_clients(node_ip_address, redis_client):
|
||||
@@ -41,58 +41,116 @@ def cli():
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--node-ip-address", required=False, type=str,
|
||||
help="the IP address of this node")
|
||||
@click.option("--redis-address", required=False, type=str,
|
||||
help="the address to use for connecting to Redis")
|
||||
@click.option("--redis-port", required=False, type=str,
|
||||
help="the port to use for starting Redis")
|
||||
@click.option("--num-redis-shards", required=False, type=int,
|
||||
help=("the number of additional Redis shards to use in "
|
||||
"addition to the primary Redis shard"))
|
||||
@click.option("--redis-max-clients", required=False, type=int,
|
||||
help=("If provided, attempt to configure Redis with this "
|
||||
"maximum number of clients."))
|
||||
@click.option("--redis-shard-ports", required=False, type=str,
|
||||
help="the port to use for the Redis shards other than the "
|
||||
"primary Redis shard")
|
||||
@click.option("--object-manager-port", required=False, type=int,
|
||||
help="the port to use for starting the object manager")
|
||||
@click.option("--object-store-memory", required=False, type=int,
|
||||
help="the maximum amount of memory (in bytes) to allow the "
|
||||
"object store to use")
|
||||
@click.option("--num-workers", required=False, type=int,
|
||||
help=("The initial number of workers to start on this node, "
|
||||
"note that the local scheduler may start additional "
|
||||
"workers. If you wish to control the total number of "
|
||||
"concurent tasks, then use --resources instead and "
|
||||
"specify the CPU field."))
|
||||
@click.option("--num-cpus", required=False, type=int,
|
||||
help="the number of CPUs on this node")
|
||||
@click.option("--num-gpus", required=False, type=int,
|
||||
help="the number of GPUs on this node")
|
||||
@click.option("--resources", required=False, default="{}", type=str,
|
||||
help="a JSON serialized dictionary mapping resource name to "
|
||||
"resource quantity")
|
||||
@click.option("--head", is_flag=True, default=False,
|
||||
help="provide this argument for the head node")
|
||||
@click.option("--no-ui", is_flag=True, default=False,
|
||||
help="provide this argument if the UI should not be started")
|
||||
@click.option("--block", is_flag=True, default=False,
|
||||
help="provide this argument to block forever in this command")
|
||||
@click.option("--plasma-directory", required=False, type=str,
|
||||
help="object store directory for memory mapped files")
|
||||
@click.option("--huge-pages", is_flag=True, default=False,
|
||||
help="enable support for huge pages in the object store")
|
||||
@click.option("--autoscaling-config", required=False, type=str,
|
||||
help="the file that contains the autoscaling config")
|
||||
@click.option("--use-raylet", is_flag=True, default=False,
|
||||
help="use the raylet code path, this is not supported yet")
|
||||
@click.option(
|
||||
"--node-ip-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the IP address of this node")
|
||||
@click.option(
|
||||
"--redis-address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for connecting to Redis")
|
||||
@click.option(
|
||||
"--redis-port",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the port to use for starting Redis")
|
||||
@click.option(
|
||||
"--num-redis-shards",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("the number of additional Redis shards to use in "
|
||||
"addition to the primary Redis shard"))
|
||||
@click.option(
|
||||
"--redis-max-clients",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("If provided, attempt to configure Redis with this "
|
||||
"maximum number of clients."))
|
||||
@click.option(
|
||||
"--redis-shard-ports",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the port to use for the Redis shards other than the "
|
||||
"primary Redis shard")
|
||||
@click.option(
|
||||
"--object-manager-port",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the port to use for starting the object manager")
|
||||
@click.option(
|
||||
"--object-store-memory",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the maximum amount of memory (in bytes) to allow the "
|
||||
"object store to use")
|
||||
@click.option(
|
||||
"--num-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("The initial number of workers to start on this node, "
|
||||
"note that the local scheduler may start additional "
|
||||
"workers. If you wish to control the total number of "
|
||||
"concurent tasks, then use --resources instead and "
|
||||
"specify the CPU field."))
|
||||
@click.option(
|
||||
"--num-cpus",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the number of CPUs on this node")
|
||||
@click.option(
|
||||
"--num-gpus",
|
||||
required=False,
|
||||
type=int,
|
||||
help="the number of GPUs on this node")
|
||||
@click.option(
|
||||
"--resources",
|
||||
required=False,
|
||||
default="{}",
|
||||
type=str,
|
||||
help="a JSON serialized dictionary mapping resource name to "
|
||||
"resource quantity")
|
||||
@click.option(
|
||||
"--head",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument for the head node")
|
||||
@click.option(
|
||||
"--no-ui",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument if the UI should not be started")
|
||||
@click.option(
|
||||
"--block",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="provide this argument to block forever in this command")
|
||||
@click.option(
|
||||
"--plasma-directory",
|
||||
required=False,
|
||||
type=str,
|
||||
help="object store directory for memory mapped files")
|
||||
@click.option(
|
||||
"--huge-pages",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="enable support for huge pages in the object store")
|
||||
@click.option(
|
||||
"--autoscaling-config",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the file that contains the autoscaling config")
|
||||
@click.option(
|
||||
"--use-raylet",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="use the raylet code path, this is not supported yet")
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients, redis_shard_ports, object_manager_port,
|
||||
object_store_memory, num_workers, num_cpus, num_gpus, resources,
|
||||
head, no_ui, block, plasma_directory, huge_pages,
|
||||
autoscaling_config, use_raylet):
|
||||
head, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
|
||||
use_raylet):
|
||||
# Convert hostnames to numerical IP address.
|
||||
if node_ip_address is not None:
|
||||
node_ip_address = services.address_to_ip(node_ip_address)
|
||||
@@ -245,33 +303,54 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
|
||||
@click.command()
|
||||
def stop():
|
||||
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler raylet"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler raylet"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the Redis process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PIDs of the worker processes and kill them.
|
||||
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
|
||||
shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the Ray log monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
subprocess.call(
|
||||
[
|
||||
"kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"
|
||||
],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the jupyter process and kill it.
|
||||
try:
|
||||
from notebook.notebookapp import list_running_servers
|
||||
pids = [str(server["pid"]) for server in list_running_servers()
|
||||
if "/tmp/raylogs" in server["notebook_dir"]]
|
||||
subprocess.call(["kill {} 2> /dev/null".format(
|
||||
" ".join(pids))], shell=True)
|
||||
pids = [
|
||||
str(server["pid"]) for server in list_running_servers()
|
||||
if "/tmp/raylogs" in server["notebook_dir"]
|
||||
]
|
||||
subprocess.call(
|
||||
["kill {} 2> /dev/null".format(" ".join(pids))], shell=True)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -279,29 +358,41 @@ def stop():
|
||||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--no-restart", is_flag=True, default=False, help=(
|
||||
"Whether to skip restarting Ray services during the update. "
|
||||
"This avoids interrupting running jobs."))
|
||||
"--no-restart",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Whether to skip restarting Ray services during the update. "
|
||||
"This avoids interrupting running jobs."))
|
||||
@click.option(
|
||||
"--min-workers", required=False, type=int, help=(
|
||||
"Override the configured min worker node count for the cluster."))
|
||||
"--min-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("Override the configured min worker node count for the cluster."))
|
||||
@click.option(
|
||||
"--max-workers", required=False, type=int, help=(
|
||||
"Override the configured max worker node count for the cluster."))
|
||||
"--max-workers",
|
||||
required=False,
|
||||
type=int,
|
||||
help=("Override the configured max worker node count for the cluster."))
|
||||
@click.option(
|
||||
"--yes", "-y", is_flag=True, default=False, help=(
|
||||
"Don't ask for confirmation."))
|
||||
def create_or_update(
|
||||
cluster_config_file, min_workers, max_workers, no_restart, yes):
|
||||
create_or_update_cluster(
|
||||
cluster_config_file, min_workers, max_workers, no_restart, yes)
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Don't ask for confirmation."))
|
||||
def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
|
||||
yes):
|
||||
create_or_update_cluster(cluster_config_file, min_workers, max_workers,
|
||||
no_restart, yes)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("cluster_config_file", required=True, type=str)
|
||||
@click.option(
|
||||
"--yes", "-y", is_flag=True, default=False, help=(
|
||||
"Don't ask for confirmation."))
|
||||
"--yes",
|
||||
"-y",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help=("Don't ask for confirmation."))
|
||||
def teardown(cluster_config_file, yes):
|
||||
teardown_cluster(cluster_config_file, yes)
|
||||
|
||||
|
||||
+237
-200
@@ -41,16 +41,12 @@ PROCESS_TYPE_WEB_UI = "web_ui"
|
||||
# important because it determines the order in which these processes will be
|
||||
# terminated when Ray exits, and certain orders will cause errors to be logged
|
||||
# to the screen.
|
||||
all_processes = OrderedDict([(PROCESS_TYPE_MONITOR, []),
|
||||
(PROCESS_TYPE_LOG_MONITOR, []),
|
||||
(PROCESS_TYPE_WORKER, []),
|
||||
(PROCESS_TYPE_RAYLET, []),
|
||||
(PROCESS_TYPE_LOCAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_PLASMA_MANAGER, []),
|
||||
(PROCESS_TYPE_PLASMA_STORE, []),
|
||||
(PROCESS_TYPE_GLOBAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_REDIS_SERVER, []),
|
||||
(PROCESS_TYPE_WEB_UI, [])],)
|
||||
all_processes = OrderedDict(
|
||||
[(PROCESS_TYPE_MONITOR, []), (PROCESS_TYPE_LOG_MONITOR, []),
|
||||
(PROCESS_TYPE_WORKER, []), (PROCESS_TYPE_RAYLET, []),
|
||||
(PROCESS_TYPE_LOCAL_SCHEDULER, []), (PROCESS_TYPE_PLASMA_MANAGER, []),
|
||||
(PROCESS_TYPE_PLASMA_STORE, []), (PROCESS_TYPE_GLOBAL_SCHEDULER, []),
|
||||
(PROCESS_TYPE_REDIS_SERVER, []), (PROCESS_TYPE_WEB_UI, [])], )
|
||||
|
||||
# True if processes are run in the valgrind profiler.
|
||||
RUN_RAYLET_PROFILER = False
|
||||
@@ -82,17 +78,15 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"core/src/ray/raylet/raylet_monitor")
|
||||
RAYLET_EXECUTABLE = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"core/src/ray/raylet/raylet")
|
||||
os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
|
||||
|
||||
# ObjectStoreAddress tuples contain all information necessary to connect to an
|
||||
# object store. The fields are:
|
||||
# - name: The socket name for the object store
|
||||
# - manager_name: The socket name for the object store manager
|
||||
# - manager_port: The Internet port that the object store manager listens on
|
||||
ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name",
|
||||
"manager_name",
|
||||
"manager_port"])
|
||||
ObjectStoreAddress = namedtuple("ObjectStoreAddress",
|
||||
["name", "manager_name", "manager_port"])
|
||||
|
||||
|
||||
def address(ip_address, port):
|
||||
@@ -133,8 +127,10 @@ def kill_process(p):
|
||||
if p.poll() is not None:
|
||||
# The process has already terminated.
|
||||
return True
|
||||
if any([RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
|
||||
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER]):
|
||||
if any([
|
||||
RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
|
||||
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER
|
||||
]):
|
||||
# Give process signal to write profiler data.
|
||||
os.kill(p.pid, signal.SIGINT)
|
||||
# Wait for profiling data to be written.
|
||||
@@ -260,8 +256,8 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
|
||||
for log_file in log_files:
|
||||
if log_file is not None:
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
# The name of the key storing the list of log filenames for this IP
|
||||
# address.
|
||||
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
|
||||
@@ -304,8 +300,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
while counter < num_retries:
|
||||
try:
|
||||
# Run some random command and see if it worked.
|
||||
print("Waiting for redis server at {}:{} to respond..."
|
||||
.format(redis_ip_address, redis_port))
|
||||
print("Waiting for redis server at {}:{} to respond...".format(
|
||||
redis_ip_address, redis_port))
|
||||
redis_client.client_list()
|
||||
except redis.ConnectionError as e:
|
||||
# Wait a little bit.
|
||||
@@ -427,17 +423,19 @@ def start_credis(node_ip_address,
|
||||
"""
|
||||
|
||||
components = ["credis_master", "credis_head", "credis_tail"]
|
||||
modules = [CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE,
|
||||
CREDIS_MEMBER_MODULE]
|
||||
modules = [
|
||||
CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE
|
||||
]
|
||||
ports = []
|
||||
|
||||
for i, component in enumerate(components):
|
||||
stdout_file, stderr_file = new_log_files(
|
||||
component, redirect_output)
|
||||
stdout_file, stderr_file = new_log_files(component, redirect_output)
|
||||
|
||||
new_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
stdout_file=stdout_file, stderr_file=stderr_file,
|
||||
node_ip_address=node_ip_address,
|
||||
port=port,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file,
|
||||
cleanup=cleanup,
|
||||
module=modules[i],
|
||||
executable=CREDIS_EXECUTABLE)
|
||||
@@ -456,8 +454,7 @@ def start_credis(node_ip_address,
|
||||
|
||||
# Register credis master in redis
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
|
||||
redis_client.set("credis_address", credis_address)
|
||||
|
||||
return credis_address
|
||||
@@ -509,9 +506,11 @@ def start_redis(node_ip_address,
|
||||
"number of Redis shards.")
|
||||
|
||||
assigned_port, _ = start_redis_instance(
|
||||
node_ip_address=node_ip_address, port=port,
|
||||
node_ip_address=node_ip_address,
|
||||
port=port,
|
||||
redis_max_clients=redis_max_clients,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if port is not None:
|
||||
assert assigned_port == port
|
||||
@@ -540,7 +539,8 @@ def start_redis(node_ip_address,
|
||||
node_ip_address=node_ip_address,
|
||||
port=redis_shard_ports[i],
|
||||
redis_max_clients=redis_max_clients,
|
||||
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
|
||||
stdout_file=redis_stdout_file,
|
||||
stderr_file=redis_stderr_file,
|
||||
cleanup=cleanup)
|
||||
if redis_shard_ports[i] is not None:
|
||||
assert redis_shard_port == redis_shard_ports[i]
|
||||
@@ -601,11 +601,13 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
||||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Redis failed to start, retrying now.")
|
||||
p = subprocess.Popen([executable,
|
||||
"--port", str(port),
|
||||
"--loglevel", "warning",
|
||||
"--loadmodule", module],
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
p = subprocess.Popen(
|
||||
[
|
||||
executable, "--port",
|
||||
str(port), "--loglevel", "warning", "--loadmodule", module
|
||||
],
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
# Check if Redis successfully started (or at least if it the executable
|
||||
# did not exit within 0.1 seconds).
|
||||
@@ -652,8 +654,8 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
||||
# Increase the hard and soft limits for the redis client pubsub buffer to
|
||||
# 128MB. This is a hack to make it less likely for pubsub messages to be
|
||||
# dropped and for pubsub connections to therefore be killed.
|
||||
cur_config = (redis_client.config_get("client-output-buffer-limit")
|
||||
["client-output-buffer-limit"])
|
||||
cur_config = (redis_client.config_get("client-output-buffer-limit")[
|
||||
"client-output-buffer-limit"])
|
||||
cur_config_list = cur_config.split()
|
||||
assert len(cur_config_list) == 12
|
||||
cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"]
|
||||
@@ -662,13 +664,17 @@ def start_redis_instance(node_ip_address="127.0.0.1",
|
||||
# Put a time stamp in Redis to indicate when it was started.
|
||||
redis_client.set("redis_start_time", time.time())
|
||||
# Record the log files in Redis.
|
||||
record_log_files_in_redis(address(node_ip_address, port), node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
record_log_files_in_redis(
|
||||
address(node_ip_address, port), node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
return port, p
|
||||
|
||||
|
||||
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=cleanup):
|
||||
def start_log_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=cleanup):
|
||||
"""Start a log monitor process.
|
||||
|
||||
Args:
|
||||
@@ -684,20 +690,25 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
Python process that imported services exits.
|
||||
"""
|
||||
log_monitor_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"log_monitor.py")
|
||||
p = subprocess.Popen([sys.executable, "-u", log_monitor_filepath,
|
||||
"--redis-address", redis_address,
|
||||
"--node-ip-address", node_ip_address],
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
|
||||
p = subprocess.Popen(
|
||||
[
|
||||
sys.executable, "-u", log_monitor_filepath, "--redis-address",
|
||||
redis_address, "--node-ip-address", node_ip_address
|
||||
],
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
def start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
@@ -712,10 +723,11 @@ def start_global_scheduler(redis_address, node_ip_address,
|
||||
then this process will be killed by services.cleanup() when the
|
||||
Python process that imported services exits.
|
||||
"""
|
||||
p = global_scheduler.start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
p = global_scheduler.start_global_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=stdout_file,
|
||||
stderr_file=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p)
|
||||
record_log_files_in_redis(redis_address, node_ip_address,
|
||||
@@ -737,8 +749,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
|
||||
"""
|
||||
new_env = os.environ.copy()
|
||||
notebook_filepath = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"WebUI.ipynb")
|
||||
os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb")
|
||||
# We copy the notebook file so that the original doesn't get modified by
|
||||
# the user.
|
||||
random_ui_id = random.randint(0, 100000)
|
||||
@@ -759,19 +770,23 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
|
||||
# We generate the token used for authentication ourselves to avoid
|
||||
# querying the jupyter server.
|
||||
token = binascii.hexlify(os.urandom(24)).decode("ascii")
|
||||
command = ["jupyter", "notebook", "--no-browser",
|
||||
"--port={}".format(port),
|
||||
"--NotebookApp.iopub_data_rate_limit=10000000000",
|
||||
"--NotebookApp.open_browser=False",
|
||||
"--NotebookApp.token={}".format(token)]
|
||||
command = [
|
||||
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
|
||||
"--NotebookApp.iopub_data_rate_limit=10000000000",
|
||||
"--NotebookApp.open_browser=False",
|
||||
"--NotebookApp.token={}".format(token)
|
||||
]
|
||||
# If the user is root, add the --allow-root flag.
|
||||
if os.geteuid() == 0:
|
||||
command.append("--allow-root")
|
||||
|
||||
try:
|
||||
ui_process = subprocess.Popen(command, env=new_env,
|
||||
cwd=new_notebook_directory,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
ui_process = subprocess.Popen(
|
||||
command,
|
||||
env=new_env,
|
||||
cwd=new_notebook_directory,
|
||||
stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
except Exception:
|
||||
print("Failed to start the UI, you may need to run "
|
||||
"'pip install jupyter'.")
|
||||
@@ -836,8 +851,8 @@ def start_local_scheduler(redis_address,
|
||||
|
||||
# Check that the number of GPUs that the local scheduler wants doesn't
|
||||
# excede the amount allowed by CUDA_VISIBLE_DEVICES.
|
||||
if ("GPU" in resources and gpu_ids is not None and
|
||||
resources["GPU"] > len(gpu_ids)):
|
||||
if ("GPU" in resources and gpu_ids is not None
|
||||
and resources["GPU"] > len(gpu_ids)):
|
||||
raise Exception("Attempting to start local scheduler with {} GPUs, "
|
||||
"but CUDA_VISIBLE_DEVICES contains {}.".format(
|
||||
resources["GPU"], gpu_ids))
|
||||
@@ -906,21 +921,14 @@ def start_raylet(redis_address,
|
||||
"--node-ip-address={} "
|
||||
"--object-store-name={} "
|
||||
"--raylet-name={} "
|
||||
"--redis-address={}"
|
||||
.format(sys.executable,
|
||||
worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
raylet_name,
|
||||
redis_address))
|
||||
"--redis-address={}".format(
|
||||
sys.executable, worker_path, node_ip_address,
|
||||
plasma_store_name, raylet_name, redis_address))
|
||||
|
||||
command = [RAYLET_EXECUTABLE,
|
||||
raylet_name,
|
||||
plasma_store_name,
|
||||
node_ip_address,
|
||||
gcs_ip_address,
|
||||
gcs_port,
|
||||
start_worker_command]
|
||||
command = [
|
||||
RAYLET_EXECUTABLE, raylet_name, plasma_store_name, node_ip_address,
|
||||
gcs_ip_address, gcs_port, start_worker_command
|
||||
]
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
|
||||
if cleanup:
|
||||
@@ -931,12 +939,18 @@ def start_raylet(redis_address,
|
||||
return raylet_name
|
||||
|
||||
|
||||
def start_objstore(node_ip_address, redis_address,
|
||||
object_manager_port=None, store_stdout_file=None,
|
||||
store_stderr_file=None, manager_stdout_file=None,
|
||||
manager_stderr_file=None, objstore_memory=None,
|
||||
cleanup=True, plasma_directory=None,
|
||||
huge_pages=False, use_raylet=False):
|
||||
def start_objstore(node_ip_address,
|
||||
redis_address,
|
||||
object_manager_port=None,
|
||||
store_stdout_file=None,
|
||||
store_stderr_file=None,
|
||||
manager_stdout_file=None,
|
||||
manager_stderr_file=None,
|
||||
objstore_memory=None,
|
||||
cleanup=True,
|
||||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
use_raylet=False):
|
||||
"""This method starts an object store process.
|
||||
|
||||
Args:
|
||||
@@ -1013,24 +1027,24 @@ def start_objstore(node_ip_address, redis_address,
|
||||
if object_manager_port is not None:
|
||||
(plasma_manager_name, p2,
|
||||
plasma_manager_port) = ray.plasma.start_plasma_manager(
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
plasma_manager_port=object_manager_port,
|
||||
node_ip_address=node_ip_address,
|
||||
num_retries=1,
|
||||
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
|
||||
stdout_file=manager_stdout_file,
|
||||
stderr_file=manager_stderr_file)
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
plasma_manager_port=object_manager_port,
|
||||
node_ip_address=node_ip_address,
|
||||
num_retries=1,
|
||||
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
|
||||
stdout_file=manager_stdout_file,
|
||||
stderr_file=manager_stderr_file)
|
||||
assert plasma_manager_port == object_manager_port
|
||||
else:
|
||||
(plasma_manager_name, p2,
|
||||
plasma_manager_port) = ray.plasma.start_plasma_manager(
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
node_ip_address=node_ip_address,
|
||||
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
|
||||
stdout_file=manager_stdout_file,
|
||||
stderr_file=manager_stderr_file)
|
||||
plasma_store_name,
|
||||
redis_address,
|
||||
node_ip_address=node_ip_address,
|
||||
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
|
||||
stdout_file=manager_stdout_file,
|
||||
stderr_file=manager_stderr_file)
|
||||
else:
|
||||
plasma_manager_port = None
|
||||
plasma_manager_name = None
|
||||
@@ -1049,9 +1063,15 @@ def start_objstore(node_ip_address, redis_address,
|
||||
plasma_manager_port)
|
||||
|
||||
|
||||
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
local_scheduler_name, redis_address, worker_path,
|
||||
stdout_file=None, stderr_file=None, cleanup=True):
|
||||
def start_worker(node_ip_address,
|
||||
object_store_name,
|
||||
object_store_manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address,
|
||||
worker_path,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""This method starts a worker process.
|
||||
|
||||
Args:
|
||||
@@ -1072,14 +1092,14 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
Python process that imported services exits. This is True by
|
||||
default.
|
||||
"""
|
||||
command = [sys.executable,
|
||||
"-u",
|
||||
worker_path,
|
||||
"--node-ip-address=" + node_ip_address,
|
||||
"--object-store-name=" + object_store_name,
|
||||
"--object-store-manager-name=" + object_store_manager_name,
|
||||
"--local-scheduler-name=" + local_scheduler_name,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
command = [
|
||||
sys.executable, "-u", worker_path,
|
||||
"--node-ip-address=" + node_ip_address,
|
||||
"--object-store-name=" + object_store_name,
|
||||
"--object-store-manager-name=" + object_store_manager_name,
|
||||
"--local-scheduler-name=" + local_scheduler_name,
|
||||
"--redis-address=" + str(redis_address)
|
||||
]
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_WORKER].append(p)
|
||||
@@ -1087,8 +1107,12 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True, autoscaling_config=None):
|
||||
def start_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True,
|
||||
autoscaling_config=None):
|
||||
"""Run a process to monitor the other processes.
|
||||
|
||||
Args:
|
||||
@@ -1105,12 +1129,12 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
default.
|
||||
autoscaling_config: path to autoscaling config file.
|
||||
"""
|
||||
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"monitor.py")
|
||||
command = [sys.executable,
|
||||
"-u",
|
||||
monitor_path,
|
||||
"--redis-address=" + str(redis_address)]
|
||||
monitor_path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
|
||||
command = [
|
||||
sys.executable, "-u", monitor_path,
|
||||
"--redis-address=" + str(redis_address)
|
||||
]
|
||||
if autoscaling_config:
|
||||
command.append("--autoscaling-config=" + str(autoscaling_config))
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
@@ -1120,8 +1144,10 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
|
||||
[stdout_file, stderr_file])
|
||||
|
||||
|
||||
def start_raylet_monitor(redis_address, stdout_file=None,
|
||||
stderr_file=None, cleanup=True):
|
||||
def start_raylet_monitor(redis_address,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=True):
|
||||
"""Run a process to monitor the other processes.
|
||||
|
||||
Args:
|
||||
@@ -1136,9 +1162,7 @@ def start_raylet_monitor(redis_address, stdout_file=None,
|
||||
default.
|
||||
"""
|
||||
gcs_ip_address, gcs_port = redis_address.split(":")
|
||||
command = [RAYLET_MONITOR_EXECUTABLE,
|
||||
gcs_ip_address,
|
||||
gcs_port]
|
||||
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
|
||||
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
if cleanup:
|
||||
all_processes[PROCESS_TYPE_MONITOR].append(p)
|
||||
@@ -1238,16 +1262,17 @@ def start_ray_processes(address_info=None,
|
||||
workers_per_local_scheduler = []
|
||||
for resource_dict in resources:
|
||||
cpus = resource_dict.get("CPU")
|
||||
workers_per_local_scheduler.append(cpus if cpus is not None
|
||||
else psutil.cpu_count())
|
||||
workers_per_local_scheduler.append(cpus if cpus is not None else
|
||||
psutil.cpu_count())
|
||||
|
||||
if address_info is None:
|
||||
address_info = {}
|
||||
address_info["node_ip_address"] = node_ip_address
|
||||
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
"workers/default_worker.py")
|
||||
worker_path = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"workers/default_worker.py")
|
||||
|
||||
# Start Redis if there isn't already an instance running. TODO(rkn): We are
|
||||
# suppressing the output of Redis because on Linux it prints a bunch of
|
||||
@@ -1257,7 +1282,8 @@ def start_ray_processes(address_info=None,
|
||||
redis_shards = address_info.get("redis_shards", [])
|
||||
if redis_address is None:
|
||||
redis_address, redis_shards = start_redis(
|
||||
node_ip_address, port=redis_port,
|
||||
node_ip_address,
|
||||
port=redis_port,
|
||||
redis_shard_ports=redis_shard_ports,
|
||||
num_redis_shards=num_redis_shards,
|
||||
redis_max_clients=redis_max_clients,
|
||||
@@ -1274,23 +1300,25 @@ def start_ray_processes(address_info=None,
|
||||
# Start monitoring the processes.
|
||||
monitor_stdout_file, monitor_stderr_file = new_log_files(
|
||||
"monitor", redirect_output)
|
||||
start_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup,
|
||||
autoscaling_config=autoscaling_config)
|
||||
start_monitor(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup,
|
||||
autoscaling_config=autoscaling_config)
|
||||
if use_raylet:
|
||||
start_raylet_monitor(redis_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
start_raylet_monitor(
|
||||
redis_address,
|
||||
stdout_file=monitor_stdout_file,
|
||||
stderr_file=monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
|
||||
if redis_shards == []:
|
||||
# Get redis shards from primary redis instance.
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
redis_client = redis.StrictRedis(
|
||||
host=redis_ip_address, port=redis_port)
|
||||
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
redis_shards = [shard.decode("ascii") for shard in redis_shards]
|
||||
address_info["redis_shards"] = redis_shards
|
||||
@@ -1299,21 +1327,23 @@ def start_ray_processes(address_info=None,
|
||||
if include_log_monitor:
|
||||
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
|
||||
"log_monitor", redirect_output=True)
|
||||
start_log_monitor(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=log_monitor_stdout_file,
|
||||
stderr_file=log_monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
start_log_monitor(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=log_monitor_stdout_file,
|
||||
stderr_file=log_monitor_stderr_file,
|
||||
cleanup=cleanup)
|
||||
|
||||
# Start the global scheduler, if necessary.
|
||||
if include_global_scheduler and not use_raylet:
|
||||
global_scheduler_stdout_file, global_scheduler_stderr_file = (
|
||||
new_log_files("global_scheduler", redirect_output))
|
||||
start_global_scheduler(redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=global_scheduler_stdout_file,
|
||||
stderr_file=global_scheduler_stderr_file,
|
||||
cleanup=cleanup)
|
||||
start_global_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
stdout_file=global_scheduler_stdout_file,
|
||||
stderr_file=global_scheduler_stderr_file,
|
||||
cleanup=cleanup)
|
||||
|
||||
# Initialize with existing services.
|
||||
if "object_store_addresses" not in address_info:
|
||||
@@ -1324,9 +1354,8 @@ def start_ray_processes(address_info=None,
|
||||
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
|
||||
|
||||
# Get the ports to use for the object managers if any are provided.
|
||||
object_manager_ports = (address_info["object_manager_ports"]
|
||||
if "object_manager_ports" in address_info
|
||||
else None)
|
||||
object_manager_ports = (address_info["object_manager_ports"] if
|
||||
"object_manager_ports" in address_info else None)
|
||||
if not isinstance(object_manager_ports, list):
|
||||
object_manager_ports = num_local_schedulers * [object_manager_ports]
|
||||
assert len(object_manager_ports) == num_local_schedulers
|
||||
@@ -1347,7 +1376,8 @@ def start_ray_processes(address_info=None,
|
||||
manager_stdout_file=plasma_manager_stdout_file,
|
||||
manager_stderr_file=plasma_manager_stderr_file,
|
||||
objstore_memory=object_store_memory,
|
||||
cleanup=cleanup, plasma_directory=plasma_directory,
|
||||
cleanup=cleanup,
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages,
|
||||
use_raylet=use_raylet)
|
||||
object_store_addresses.append(object_store_address)
|
||||
@@ -1355,8 +1385,8 @@ def start_ray_processes(address_info=None,
|
||||
|
||||
# Start any local schedulers that do not yet exist.
|
||||
if not use_raylet:
|
||||
for i in range(len(local_scheduler_socket_names),
|
||||
num_local_schedulers):
|
||||
for i in range(
|
||||
len(local_scheduler_socket_names), num_local_schedulers):
|
||||
# Connect the local scheduler to the object store at the same
|
||||
# index.
|
||||
object_store_address = object_store_addresses[i]
|
||||
@@ -1374,8 +1404,9 @@ def start_ray_processes(address_info=None,
|
||||
# redirect the worker output, then we cannot redirect the local
|
||||
# scheduler output.
|
||||
local_scheduler_stdout_file, local_scheduler_stderr_file = (
|
||||
new_log_files("local_scheduler_{}".format(i),
|
||||
redirect_output=redirect_worker_output))
|
||||
new_log_files(
|
||||
"local_scheduler_{}".format(i),
|
||||
redirect_output=redirect_worker_output))
|
||||
local_scheduler_name = start_local_scheduler(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
@@ -1398,17 +1429,18 @@ def start_ray_processes(address_info=None,
|
||||
else:
|
||||
# Start the raylet. TODO(rkn): Modify this to allow starting
|
||||
# multiple raylets on the same machine.
|
||||
raylet_stdout_file, raylet_stderr_file = (
|
||||
new_log_files("raylet_{}".format(i),
|
||||
redirect_output=redirect_output))
|
||||
address_info["raylet_socket_names"] = [start_raylet(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
object_store_addresses[i].name,
|
||||
worker_path,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=cleanup)]
|
||||
raylet_stdout_file, raylet_stderr_file = (new_log_files(
|
||||
"raylet_{}".format(i), redirect_output=redirect_output))
|
||||
address_info["raylet_socket_names"] = [
|
||||
start_raylet(
|
||||
redis_address,
|
||||
node_ip_address,
|
||||
object_store_addresses[i].name,
|
||||
worker_path,
|
||||
stdout_file=None,
|
||||
stderr_file=None,
|
||||
cleanup=cleanup)
|
||||
]
|
||||
|
||||
if not use_raylet:
|
||||
# Start any workers that the local scheduler has not already started.
|
||||
@@ -1419,28 +1451,30 @@ def start_ray_processes(address_info=None,
|
||||
for j in range(num_local_scheduler_workers):
|
||||
worker_stdout_file, worker_stderr_file = new_log_files(
|
||||
"worker_{}_{}".format(i, j), redirect_output)
|
||||
start_worker(node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address,
|
||||
worker_path,
|
||||
stdout_file=worker_stdout_file,
|
||||
stderr_file=worker_stderr_file,
|
||||
cleanup=cleanup)
|
||||
start_worker(
|
||||
node_ip_address,
|
||||
object_store_address.name,
|
||||
object_store_address.manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address,
|
||||
worker_path,
|
||||
stdout_file=worker_stdout_file,
|
||||
stderr_file=worker_stderr_file,
|
||||
cleanup=cleanup)
|
||||
workers_per_local_scheduler[i] -= 1
|
||||
|
||||
# Make sure that we've started all the workers.
|
||||
assert(sum(workers_per_local_scheduler) == 0)
|
||||
assert (sum(workers_per_local_scheduler) == 0)
|
||||
|
||||
# Try to start the web UI.
|
||||
if include_webui:
|
||||
ui_stdout_file, ui_stderr_file = new_log_files(
|
||||
"webui", redirect_output=True)
|
||||
address_info["webui_url"] = start_ui(redis_address,
|
||||
stdout_file=ui_stdout_file,
|
||||
stderr_file=ui_stderr_file,
|
||||
cleanup=cleanup)
|
||||
address_info["webui_url"] = start_ui(
|
||||
redis_address,
|
||||
stdout_file=ui_stdout_file,
|
||||
stderr_file=ui_stderr_file,
|
||||
cleanup=cleanup)
|
||||
else:
|
||||
address_info["webui_url"] = ""
|
||||
# Return the addresses of the relevant processes.
|
||||
@@ -1500,21 +1534,24 @@ def start_ray_node(node_ip_address,
|
||||
A dictionary of the address information for the processes that were
|
||||
started.
|
||||
"""
|
||||
address_info = {"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports}
|
||||
return start_ray_processes(address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
object_store_memory=object_store_memory,
|
||||
worker_path=worker_path,
|
||||
include_log_monitor=True,
|
||||
cleanup=cleanup,
|
||||
redirect_worker_output=redirect_worker_output,
|
||||
redirect_output=redirect_output,
|
||||
resources=resources,
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages)
|
||||
address_info = {
|
||||
"redis_address": redis_address,
|
||||
"object_manager_ports": object_manager_ports
|
||||
}
|
||||
return start_ray_processes(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
num_workers=num_workers,
|
||||
num_local_schedulers=num_local_schedulers,
|
||||
object_store_memory=object_store_memory,
|
||||
worker_path=worker_path,
|
||||
include_log_monitor=True,
|
||||
cleanup=cleanup,
|
||||
redirect_worker_output=redirect_worker_output,
|
||||
redirect_output=redirect_output,
|
||||
resources=resources,
|
||||
plasma_directory=plasma_directory,
|
||||
huge_pages=huge_pages)
|
||||
|
||||
|
||||
def start_ray_head(address_info=None,
|
||||
|
||||
+19
-17
@@ -7,11 +7,10 @@ import funcsigs
|
||||
|
||||
from ray.utils import is_cython
|
||||
|
||||
FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
|
||||
"arg_defaults",
|
||||
"arg_is_positionals",
|
||||
"keyword_names",
|
||||
"function_name"])
|
||||
FunctionSignature = namedtuple("FunctionSignature", [
|
||||
"arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
|
||||
"function_name"
|
||||
])
|
||||
"""This class is used to represent a function signature.
|
||||
|
||||
Attributes:
|
||||
@@ -49,13 +48,16 @@ def get_signature_params(func):
|
||||
# The first condition for Cython functions, the latter for Cython instance
|
||||
# methods
|
||||
if is_cython(func):
|
||||
attrs = ["__code__", "__annotations__",
|
||||
"__defaults__", "__kwdefaults__"]
|
||||
attrs = [
|
||||
"__code__", "__annotations__", "__defaults__", "__kwdefaults__"
|
||||
]
|
||||
|
||||
if all([hasattr(func, attr) for attr in attrs]):
|
||||
original_func = func
|
||||
|
||||
def func(): return
|
||||
def func():
|
||||
return
|
||||
|
||||
for attr in attrs:
|
||||
setattr(func, attr, getattr(original_func, attr))
|
||||
else:
|
||||
@@ -130,8 +132,8 @@ def extract_signature(func, ignore_first=False):
|
||||
if ignore_first:
|
||||
if len(sig_params) == 0:
|
||||
raise Exception("Methods must take a 'self' argument, but the "
|
||||
"method '{}' does not have one."
|
||||
.format(func.__name__))
|
||||
"method '{}' does not have one.".format(
|
||||
func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
|
||||
# Extract the names of the keyword arguments.
|
||||
@@ -183,8 +185,8 @@ def extend_args(function_signature, args, kwargs):
|
||||
for keyword_name in kwargs:
|
||||
if keyword_name not in keyword_names:
|
||||
raise Exception("The name '{}' is not a valid keyword argument "
|
||||
"for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
"for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
# Fill in the remaining arguments.
|
||||
zipped_info = list(zip(arg_names, arg_defaults,
|
||||
@@ -201,12 +203,12 @@ def extend_args(function_signature, args, kwargs):
|
||||
# can be omitted.
|
||||
if not is_positional:
|
||||
raise Exception("No value was provided for the argument "
|
||||
"'{}' for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
"'{}' for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
too_many_arguments = (len(args) > len(arg_names) and
|
||||
(len(arg_is_positionals) == 0 or
|
||||
not arg_is_positionals[-1]))
|
||||
too_many_arguments = (len(args) > len(arg_names)
|
||||
and (len(arg_is_positionals) == 0
|
||||
or not arg_is_positionals[-1]))
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
|
||||
@@ -13,6 +13,7 @@ import numpy as np
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
|
||||
# Test timing
|
||||
|
||||
|
||||
@@ -25,6 +26,7 @@ def empty_function():
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
|
||||
@@ -42,6 +44,7 @@ def keyword_fct2(a="hello", b="world"):
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
|
||||
# Test variable numbers of arguments
|
||||
|
||||
|
||||
@@ -56,17 +59,21 @@ def varargs_fct2(a, *b):
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
|
||||
kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
varargs_and_kwargs_exception_thrown = True
|
||||
@@ -88,6 +95,7 @@ def throw_exception_fct2():
|
||||
def throw_exception_fct3(x):
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
|
||||
|
||||
# test Python mode
|
||||
|
||||
|
||||
@@ -101,6 +109,7 @@ def python_mode_g(x):
|
||||
x[0] = 1
|
||||
return x
|
||||
|
||||
|
||||
# test no return values
|
||||
|
||||
|
||||
|
||||
@@ -48,8 +48,8 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
|
||||
if num_ready_nodes > num_nodes:
|
||||
# Too many nodes have joined. Something must be wrong.
|
||||
raise Exception("{} nodes have joined the cluster, but we were "
|
||||
"expecting {} nodes.".format(num_ready_nodes,
|
||||
num_nodes))
|
||||
"expecting {} nodes.".format(
|
||||
num_ready_nodes, num_nodes))
|
||||
time.sleep(0.1)
|
||||
|
||||
# If we get here then we timed out.
|
||||
|
||||
@@ -9,14 +9,7 @@ from ray.tune.result import TrainingResult
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.variant_generator import grid_search
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Trainable",
|
||||
"TrainingResult",
|
||||
"TuneError",
|
||||
"grid_search",
|
||||
"register_env",
|
||||
"register_trainable",
|
||||
"run_experiments",
|
||||
"Experiment"
|
||||
"Trainable", "TrainingResult", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run_experiments", "Experiment"
|
||||
]
|
||||
|
||||
@@ -35,10 +35,13 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
halving rate, specified by the reduction factor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=100,
|
||||
grace_period=10, reduction_factor=3, brackets=3):
|
||||
def __init__(self,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
max_t=100,
|
||||
grace_period=10,
|
||||
reduction_factor=3,
|
||||
brackets=3):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
assert max_t >= grace_period, "grace_period must be <= max_t!"
|
||||
assert grace_period > 0, "grace_period must be positive!"
|
||||
@@ -51,8 +54,10 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
self._trial_info = {} # Stores Trial -> Bracket
|
||||
|
||||
# Tracks state for new trial add
|
||||
self._brackets = [_Bracket(
|
||||
grace_period, max_t, reduction_factor, s) for s in range(brackets)]
|
||||
self._brackets = [
|
||||
_Bracket(grace_period, max_t, reduction_factor, s)
|
||||
for s in range(brackets)
|
||||
]
|
||||
self._counter = 0 # for
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
@@ -60,7 +65,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
|
||||
def on_trial_add(self, trial_runner, trial):
|
||||
sizes = np.array([len(b._rungs) for b in self._brackets])
|
||||
probs = np.e ** (sizes - sizes.max())
|
||||
probs = np.e**(sizes - sizes.max())
|
||||
normalized = probs / probs.sum()
|
||||
idx = np.random.choice(len(self._brackets), p=normalized)
|
||||
self._trial_info[trial.trial_id] = self._brackets[idx]
|
||||
@@ -71,28 +76,23 @@ class AsyncHyperBandScheduler(FIFOScheduler):
|
||||
action = TrialScheduler.STOP
|
||||
else:
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
action = bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
action = bracket.on_result(trial, getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
if action == TrialScheduler.STOP:
|
||||
self._num_stopped += 1
|
||||
return action
|
||||
|
||||
def on_trial_complete(self, trial_runner, trial, result):
|
||||
bracket = self._trial_info[trial.trial_id]
|
||||
bracket.on_result(
|
||||
trial,
|
||||
getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
bracket.on_result(trial, getattr(result, self._time_attr),
|
||||
getattr(result, self._reward_attr))
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def on_trial_remove(self, trial_runner, trial):
|
||||
del self._trial_info[trial.trial_id]
|
||||
|
||||
def debug_string(self):
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(
|
||||
self._num_stopped)
|
||||
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
|
||||
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
|
||||
return out
|
||||
|
||||
@@ -111,6 +111,7 @@ class _Bracket():
|
||||
>>> b.on_result(trial3, 1, 1) # STOP
|
||||
>>> b.cutoff(b._rungs[0][1]) == 2.0
|
||||
"""
|
||||
|
||||
def __init__(self, min_t, max_t, reduction_factor, s):
|
||||
self.rf = reduction_factor
|
||||
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
|
||||
@@ -140,9 +141,10 @@ class _Bracket():
|
||||
return action
|
||||
|
||||
def debug_str(self):
|
||||
iters = " | ".join(
|
||||
["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
for milestone, recorded in self._rungs])
|
||||
iters = " | ".join([
|
||||
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
|
||||
for milestone, recorded in self._rungs
|
||||
])
|
||||
return "Bracket: " + iters
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
@@ -24,8 +23,8 @@ def json_to_resources(data):
|
||||
"Unknown resource type {}, must be one of {}".format(
|
||||
k, Resources._fields))
|
||||
return Resources(
|
||||
data.get("cpu", 1), data.get("gpu", 0),
|
||||
data.get("extra_cpu", 0), data.get("extra_gpu", 0))
|
||||
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
|
||||
data.get("extra_gpu", 0))
|
||||
|
||||
|
||||
def resources_to_json(resources):
|
||||
@@ -50,59 +49,85 @@ def make_parser(**kwargs):
|
||||
|
||||
# Note: keep this in sync with rllib/train.py
|
||||
parser.add_argument(
|
||||
"--run", default=None, type=str,
|
||||
"--run",
|
||||
default=None,
|
||||
type=str,
|
||||
help="The algorithm or model to train. This may refer to the name "
|
||||
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
|
||||
"user-defined trainable function or class registered in the "
|
||||
"tune registry.")
|
||||
parser.add_argument(
|
||||
"--stop", default="{}", type=json.loads,
|
||||
"--stop",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="The stopping criteria, specified in JSON. The keys may be any "
|
||||
"field in TrainingResult, e.g. "
|
||||
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
|
||||
"after 600 seconds or 100k timesteps, whichever is reached first.")
|
||||
parser.add_argument(
|
||||
"--config", default="{}", type=json.loads,
|
||||
"--config",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="Algorithm-specific configuration (e.g. env, hyperparams), "
|
||||
"specified in JSON.")
|
||||
parser.add_argument(
|
||||
"--resources", help="Deprecated, use --trial-resources.",
|
||||
type=lambda v: _tune_error(
|
||||
"The `resources` argument is no longer supported. "
|
||||
"Use `trial_resources` or --trial-resources instead."))
|
||||
"--resources",
|
||||
help="Deprecated, use --trial-resources.",
|
||||
type=lambda v: _tune_error("The `resources` argument is no longer "
|
||||
"supported. Use `trial_resources` or "
|
||||
"--trial-resources instead."))
|
||||
parser.add_argument(
|
||||
"--trial-resources", default='{"cpu": 1}', type=json_to_resources,
|
||||
"--trial-resources",
|
||||
default='{"cpu": 1}',
|
||||
type=json_to_resources,
|
||||
help="Machine resources to allocate per trial, e.g. "
|
||||
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
|
||||
"unless you specify them here.")
|
||||
parser.add_argument(
|
||||
"--repeat", default=1, type=int,
|
||||
"--repeat",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of times to repeat each trial.")
|
||||
parser.add_argument(
|
||||
"--local-dir", default=DEFAULT_RESULTS_DIR, type=str,
|
||||
"--local-dir",
|
||||
default=DEFAULT_RESULTS_DIR,
|
||||
type=str,
|
||||
help="Local dir to save training results to. Defaults to '{}'.".format(
|
||||
DEFAULT_RESULTS_DIR))
|
||||
parser.add_argument(
|
||||
"--upload-dir", default="", type=str,
|
||||
"--upload-dir",
|
||||
default="",
|
||||
type=str,
|
||||
help="Optional URI to sync training results to (e.g. s3://bucket).")
|
||||
parser.add_argument(
|
||||
"--checkpoint-freq", default=0, type=int,
|
||||
"--checkpoint-freq",
|
||||
default=0,
|
||||
type=int,
|
||||
help="How many training iterations between checkpoints. "
|
||||
"A value of 0 (default) disables checkpointing.")
|
||||
parser.add_argument(
|
||||
"--max-failures", default=3, type=int,
|
||||
"--max-failures",
|
||||
default=3,
|
||||
type=int,
|
||||
help="Try to recover a trial from its last checkpoint at least this "
|
||||
"many times. Only applies if checkpointing is enabled.")
|
||||
parser.add_argument(
|
||||
"--scheduler", default="FIFO", type=str,
|
||||
"--scheduler",
|
||||
default="FIFO",
|
||||
type=str,
|
||||
help="FIFO (default), MedianStopping, AsyncHyperBand,"
|
||||
"HyperBand, or HyperOpt.")
|
||||
"HyperBand, or HyperOpt.")
|
||||
parser.add_argument(
|
||||
"--scheduler-config", default="{}", type=json.loads,
|
||||
"--scheduler-config",
|
||||
default="{}",
|
||||
type=json.loads,
|
||||
help="Config options to pass to the scheduler.")
|
||||
|
||||
# Note: this currently only makes sense when running a single trial
|
||||
parser.add_argument("--restore", default=None, type=str,
|
||||
help="If specified, restore from this checkpoint.")
|
||||
parser.add_argument(
|
||||
"--restore",
|
||||
default=None,
|
||||
type=str,
|
||||
help="If specified, restore from this checkpoint.")
|
||||
|
||||
return parser
|
||||
|
||||
@@ -60,18 +60,27 @@ if __name__ == "__main__":
|
||||
# `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
ahb = AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
grace_period=5, max_t=100)
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="episode_reward_mean",
|
||||
grace_period=5,
|
||||
max_t=100)
|
||||
|
||||
run_experiments({
|
||||
"asynchyperband_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
|
||||
"repeat": 20,
|
||||
"trial_resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
}, scheduler=ahb)
|
||||
run_experiments(
|
||||
{
|
||||
"asynchyperband_test": {
|
||||
"run": "my_class",
|
||||
"stop": {
|
||||
"training_iteration": 1 if args.smoke_test else 99999
|
||||
},
|
||||
"repeat": 20,
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"config": {
|
||||
"width": lambda spec: 10 + int(90 * random.random()),
|
||||
"height": lambda spec: int(100 * random.random()),
|
||||
},
|
||||
}
|
||||
},
|
||||
scheduler=ahb)
|
||||
|
||||
@@ -59,7 +59,8 @@ if __name__ == "__main__":
|
||||
# Hyperband early stopping, configured with `episode_reward_mean` as the
|
||||
# objective and `timesteps_total` as the time unit.
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="episode_reward_mean",
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="episode_reward_mean",
|
||||
max_t=100)
|
||||
|
||||
exp = Experiment(
|
||||
|
||||
@@ -12,8 +12,8 @@ def easy_objective(config, reporter):
|
||||
time.sleep(0.2)
|
||||
reporter(
|
||||
timesteps_total=1,
|
||||
episode_reward_mean=-((config["height"]-14) ** 2
|
||||
+ abs(config["width"]-3)))
|
||||
episode_reward_mean=-(
|
||||
(config["height"] - 14)**2 + abs(config["width"] - 3)))
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
@@ -34,12 +34,18 @@ if __name__ == '__main__':
|
||||
'height': hp.uniform('height', -100, 100),
|
||||
}
|
||||
|
||||
config = {"my_exp": {
|
||||
config = {
|
||||
"my_exp": {
|
||||
"run": "exp",
|
||||
"repeat": 5 if args.smoke_test else 1000,
|
||||
"stop": {"training_iteration": 1},
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"config": {
|
||||
"space": space}}}
|
||||
"space": space
|
||||
}
|
||||
}
|
||||
}
|
||||
hpo_sched = HyperOptScheduler()
|
||||
|
||||
run_experiments(config, verbose=False, scheduler=hpo_sched)
|
||||
|
||||
@@ -42,8 +42,11 @@ class MyTrainableClass(Trainable):
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
with open(path, "w") as f:
|
||||
f.write(json.dumps(
|
||||
{"timestep": self.timestep, "value": self.current_value}))
|
||||
f.write(
|
||||
json.dumps({
|
||||
"timestep": self.timestep,
|
||||
"value": self.current_value
|
||||
}))
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
@@ -63,7 +66,8 @@ if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration", reward_attr="episode_reward_mean",
|
||||
time_attr="training_iteration",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=10,
|
||||
hyperparam_mutations={
|
||||
# Allow for scaling-based perturbations, with a uniform backing
|
||||
@@ -74,15 +78,23 @@ if __name__ == "__main__":
|
||||
})
|
||||
|
||||
# Try to find the best factor 1 and factor 2
|
||||
run_experiments({
|
||||
"pbt_test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
|
||||
"repeat": 10,
|
||||
"trial_resources": {"cpu": 1, "gpu": 0},
|
||||
"config": {
|
||||
"factor_1": 4.0,
|
||||
"factor_2": 1.0,
|
||||
},
|
||||
}
|
||||
}, scheduler=pbt, verbose=False)
|
||||
run_experiments(
|
||||
{
|
||||
"pbt_test": {
|
||||
"run": "my_class",
|
||||
"stop": {
|
||||
"training_iteration": 2 if args.smoke_test else 99999
|
||||
},
|
||||
"repeat": 10,
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"config": {
|
||||
"factor_1": 4.0,
|
||||
"factor_2": 1.0,
|
||||
},
|
||||
}
|
||||
},
|
||||
scheduler=pbt,
|
||||
verbose=False)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
"""Example of using PBT with RLlib.
|
||||
|
||||
Note that this requires a cluster with at least 8 GPUs in order for all trials
|
||||
@@ -30,7 +29,8 @@ if __name__ == "__main__":
|
||||
return config
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=120,
|
||||
resample_probability=0.25,
|
||||
# Specifies the mutations of these hyperparams
|
||||
@@ -45,26 +45,40 @@ if __name__ == "__main__":
|
||||
custom_explore_fn=explore)
|
||||
|
||||
ray.init()
|
||||
run_experiments({
|
||||
"pbt_humanoid_test": {
|
||||
"run": "PPO",
|
||||
"env": "Humanoid-v1",
|
||||
"repeat": 8,
|
||||
"trial_resources": {"cpu": 4, "gpu": 1},
|
||||
"config": {
|
||||
"kl_coeff": 1.0,
|
||||
"num_workers": 8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {"free_log_std": True},
|
||||
# These params are tuned from a fixed starting value.
|
||||
"lambda": 0.95,
|
||||
"clip_param": 0.2,
|
||||
"sgd_stepsize": 1e-4,
|
||||
# These params start off randomly drawn from a set.
|
||||
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
run_experiments(
|
||||
{
|
||||
"pbt_humanoid_test": {
|
||||
"run": "PPO",
|
||||
"env": "Humanoid-v1",
|
||||
"repeat": 8,
|
||||
"trial_resources": {
|
||||
"cpu": 4,
|
||||
"gpu": 1
|
||||
},
|
||||
"config": {
|
||||
"kl_coeff":
|
||||
1.0,
|
||||
"num_workers":
|
||||
8,
|
||||
"devices": ["/gpu:0"],
|
||||
"model": {
|
||||
"free_log_std": True
|
||||
},
|
||||
# These params are tuned from a fixed starting value.
|
||||
"lambda":
|
||||
0.95,
|
||||
"clip_param":
|
||||
0.2,
|
||||
"sgd_stepsize":
|
||||
1e-4,
|
||||
# These params start off randomly drawn from a set.
|
||||
"num_sgd_iter":
|
||||
lambda spec: random.choice([10, 20, 30]),
|
||||
"sgd_batchsize":
|
||||
lambda spec: random.choice([128, 512, 2048]),
|
||||
"timesteps_per_batch":
|
||||
lambda spec: random.choice([10000, 20000, 40000])
|
||||
},
|
||||
},
|
||||
},
|
||||
}, scheduler=pbt)
|
||||
scheduler=pbt)
|
||||
|
||||
@@ -29,12 +29,10 @@ from ray.tune import Trainable
|
||||
from ray.tune import TrainingResult
|
||||
from ray.tune.pbt import PopulationBasedTraining
|
||||
|
||||
|
||||
num_classes = 10
|
||||
|
||||
|
||||
class Cifar10Model(Trainable):
|
||||
|
||||
def _read_data(self):
|
||||
# The data, split between train and test sets:
|
||||
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
|
||||
@@ -54,27 +52,51 @@ class Cifar10Model(Trainable):
|
||||
x = Input(shape=(32, 32, 3))
|
||||
y = x
|
||||
y = Convolution2D(
|
||||
filters=64, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=64,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=64, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=64,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Convolution2D(
|
||||
filters=128, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=128,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=128, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=128,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Convolution2D(
|
||||
filters=256, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=256,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = Convolution2D(
|
||||
filters=256, kernel_size=3, strides=1, padding="same",
|
||||
activation="relu", kernel_initializer="he_normal")(y)
|
||||
filters=256,
|
||||
kernel_size=3,
|
||||
strides=1,
|
||||
padding="same",
|
||||
activation="relu",
|
||||
kernel_initializer="he_normal")(y)
|
||||
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
|
||||
|
||||
y = Flatten()(y)
|
||||
@@ -91,9 +113,10 @@ class Cifar10Model(Trainable):
|
||||
model = self._build_model(x_train.shape[1:])
|
||||
|
||||
opt = tf.keras.optimizers.Adadelta()
|
||||
model.compile(loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
metrics=["accuracy"])
|
||||
model.compile(
|
||||
loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
metrics=["accuracy"])
|
||||
self.model = model
|
||||
|
||||
def _train(self):
|
||||
@@ -134,8 +157,7 @@ class Cifar10Model(Trainable):
|
||||
|
||||
# loss, accuracy
|
||||
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
|
||||
return TrainingResult(timesteps_this_iter=10,
|
||||
mean_accuracy=accuracy)
|
||||
return TrainingResult(timesteps_this_iter=10, mean_accuracy=accuracy)
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
file_path = checkpoint_dir + "/model"
|
||||
@@ -154,15 +176,17 @@ class Cifar10Model(Trainable):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--smoke-test",
|
||||
action="store_true",
|
||||
help="Finish quickly for testing")
|
||||
parser.add_argument(
|
||||
"--smoke-test", action="store_true", help="Finish quickly for testing")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
register_trainable("train_cifar10", Cifar10Model)
|
||||
train_spec = {
|
||||
"run": "train_cifar10",
|
||||
"trial_resources": {"cpu": 1, "gpu": 1},
|
||||
"trial_resources": {
|
||||
"cpu": 1,
|
||||
"gpu": 1
|
||||
},
|
||||
"stop": {
|
||||
"mean_accuracy": 0.80,
|
||||
"timesteps_total": 300,
|
||||
@@ -170,7 +194,7 @@ if __name__ == "__main__":
|
||||
"config": {
|
||||
"epochs": 1,
|
||||
"batch_size": 64,
|
||||
"lr": grid_search([10 ** -4, 10 ** -5]),
|
||||
"lr": grid_search([10**-4, 10**-5]),
|
||||
"decay": lambda spec: spec.config.lr / 100.0,
|
||||
"dropout": grid_search([0.25, 0.5]),
|
||||
},
|
||||
@@ -178,17 +202,17 @@ if __name__ == "__main__":
|
||||
}
|
||||
|
||||
if args.smoke_test:
|
||||
train_spec["config"]["lr"] = 10 ** -4
|
||||
train_spec["config"]["lr"] = 10**-4
|
||||
train_spec["config"]["dropout"] = 0.5
|
||||
|
||||
ray.init()
|
||||
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy",
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="mean_accuracy",
|
||||
perturbation_interval=10,
|
||||
hyperparam_mutations={
|
||||
"dropout": lambda _: np.random.uniform(0, 1),
|
||||
})
|
||||
|
||||
run_experiments({"pbt_cifar10": train_spec},
|
||||
scheduler=pbt)
|
||||
run_experiments({"pbt_cifar10": train_spec}, scheduler=pbt)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
|
||||
See extensive documentation at
|
||||
@@ -42,7 +41,7 @@ import tensorflow as tf
|
||||
|
||||
FLAGS = None
|
||||
status_reporter = None # used to report training status back to Ray
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
|
||||
|
||||
def deepnn(x):
|
||||
@@ -90,7 +89,7 @@ def deepnn(x):
|
||||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
@@ -173,7 +172,10 @@ def main(_):
|
||||
batch = mnist.train.next_batch(50)
|
||||
if i % 10 == 0:
|
||||
train_accuracy = accuracy.eval(feed_dict={
|
||||
x: batch[0], y_: batch[1], keep_prob: 1.0})
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 1.0
|
||||
})
|
||||
|
||||
# !!! Report status to ray.tune !!!
|
||||
if status_reporter:
|
||||
@@ -181,11 +183,17 @@ def main(_):
|
||||
timesteps_total=i, mean_accuracy=train_accuracy)
|
||||
|
||||
print('step %d, training accuracy %g' % (i, train_accuracy))
|
||||
train_step.run(
|
||||
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
|
||||
train_step.run(feed_dict={
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 0.5
|
||||
})
|
||||
|
||||
print('test accuracy %g' % accuracy.eval(feed_dict={
|
||||
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
|
||||
x: mnist.test.images,
|
||||
y_: mnist.test.labels,
|
||||
keep_prob: 1.0
|
||||
}))
|
||||
|
||||
|
||||
# !!! Entrypoint for ray.tune !!!
|
||||
@@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
|
||||
activation_fn = getattr(tf.nn, config['activation'])
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
@@ -213,8 +223,8 @@ if __name__ == '__main__':
|
||||
'run': 'train_mnist',
|
||||
'repeat': 10,
|
||||
'stop': {
|
||||
'mean_accuracy': 0.99,
|
||||
'timesteps_total': 600,
|
||||
'mean_accuracy': 0.99,
|
||||
'timesteps_total': 600,
|
||||
},
|
||||
'config': {
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
@@ -228,8 +238,12 @@ if __name__ == '__main__':
|
||||
ray.init()
|
||||
|
||||
from ray.tune.async_hyperband import AsyncHyperBandScheduler
|
||||
run_experiments({'tune_mnist_test': mnist_spec},
|
||||
scheduler=AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="mean_accuracy",
|
||||
max_t=600,))
|
||||
run_experiments(
|
||||
{
|
||||
'tune_mnist_test': mnist_spec
|
||||
},
|
||||
scheduler=AsyncHyperBandScheduler(
|
||||
time_attr="timesteps_total",
|
||||
reward_attr="mean_accuracy",
|
||||
max_t=600,
|
||||
))
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
|
||||
See extensive documentation at
|
||||
@@ -42,7 +41,7 @@ import tensorflow as tf
|
||||
|
||||
FLAGS = None
|
||||
status_reporter = None # used to report training status back to Ray
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
|
||||
|
||||
def deepnn(x):
|
||||
@@ -90,7 +89,7 @@ def deepnn(x):
|
||||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
@@ -173,7 +172,10 @@ def main(_):
|
||||
batch = mnist.train.next_batch(50)
|
||||
if i % 10 == 0:
|
||||
train_accuracy = accuracy.eval(feed_dict={
|
||||
x: batch[0], y_: batch[1], keep_prob: 1.0})
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 1.0
|
||||
})
|
||||
|
||||
# !!! Report status to ray.tune !!!
|
||||
if status_reporter:
|
||||
@@ -181,11 +183,17 @@ def main(_):
|
||||
timesteps_total=i, mean_accuracy=train_accuracy)
|
||||
|
||||
print('step %d, training accuracy %g' % (i, train_accuracy))
|
||||
train_step.run(
|
||||
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
|
||||
train_step.run(feed_dict={
|
||||
x: batch[0],
|
||||
y_: batch[1],
|
||||
keep_prob: 0.5
|
||||
})
|
||||
|
||||
print('test accuracy %g' % accuracy.eval(feed_dict={
|
||||
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
|
||||
x: mnist.test.images,
|
||||
y_: mnist.test.labels,
|
||||
keep_prob: 1.0
|
||||
}))
|
||||
|
||||
|
||||
# !!! Entrypoint for ray.tune !!!
|
||||
@@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
|
||||
activation_fn = getattr(tf.nn, config['activation'])
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
@@ -212,8 +222,8 @@ if __name__ == '__main__':
|
||||
mnist_spec = {
|
||||
'run': 'train_mnist',
|
||||
'stop': {
|
||||
'mean_accuracy': 0.99,
|
||||
'time_total_s': 600,
|
||||
'mean_accuracy': 0.99,
|
||||
'time_total_s': 600,
|
||||
},
|
||||
'config': {
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A deep MNIST classifier using convolutional layers.
|
||||
See extensive documentation at
|
||||
https://www.tensorflow.org/get_started/mnist/pros
|
||||
@@ -39,7 +38,7 @@ from tensorflow.examples.tutorials.mnist import input_data
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
activation_fn = None # e.g. tf.nn.relu
|
||||
|
||||
|
||||
def setupCNN(x):
|
||||
@@ -85,7 +84,7 @@ def setupCNN(x):
|
||||
W_fc1 = weight_variable([7 * 7 * 64, 1024])
|
||||
b_fc1 = bias_variable([1024])
|
||||
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
|
||||
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
|
||||
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
|
||||
|
||||
# Dropout - controls the complexity of the model, prevents co-adaptation of
|
||||
@@ -182,14 +181,18 @@ class TrainMNIST(Trainable):
|
||||
self.sess.run(
|
||||
self.train_step,
|
||||
feed_dict={
|
||||
self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5
|
||||
self.x: batch[0],
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 0.5
|
||||
})
|
||||
|
||||
batch = self.mnist.train.next_batch(50)
|
||||
train_accuracy = self.sess.run(
|
||||
self.accuracy,
|
||||
feed_dict={
|
||||
self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0
|
||||
self.x: batch[0],
|
||||
self.y_: batch[1],
|
||||
self.keep_prob: 1.0
|
||||
})
|
||||
|
||||
self.iterations += 1
|
||||
@@ -215,11 +218,11 @@ if __name__ == '__main__':
|
||||
mnist_spec = {
|
||||
'run': 'my_class',
|
||||
'stop': {
|
||||
'mean_accuracy': 0.99,
|
||||
'time_total_s': 600,
|
||||
'mean_accuracy': 0.99,
|
||||
'time_total_s': 600,
|
||||
},
|
||||
'config': {
|
||||
'learning_rate': lambda spec: 10 ** np.random.uniform(-5, -3),
|
||||
'learning_rate': lambda spec: 10**np.random.uniform(-5, -3),
|
||||
'activation': grid_search(['relu', 'elu', 'tanh']),
|
||||
},
|
||||
"repeat": 10,
|
||||
@@ -231,8 +234,6 @@ if __name__ == '__main__':
|
||||
|
||||
ray.init()
|
||||
hyperband = HyperBandScheduler(
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy",
|
||||
max_t=100)
|
||||
time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100)
|
||||
|
||||
run_experiments(
|
||||
{'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
|
||||
run_experiments({'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
|
||||
|
||||
@@ -35,14 +35,26 @@ class Experiment(object):
|
||||
checkpoint at least this many times. Only applies if
|
||||
checkpointing is enabled. Defaults to 3.
|
||||
"""
|
||||
def __init__(self, name, run, stop=None, config=None,
|
||||
trial_resources=None, repeat=1, local_dir=None,
|
||||
upload_dir="", checkpoint_freq=0, max_failures=3):
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
run,
|
||||
stop=None,
|
||||
config=None,
|
||||
trial_resources=None,
|
||||
repeat=1,
|
||||
local_dir=None,
|
||||
upload_dir="",
|
||||
checkpoint_freq=0,
|
||||
max_failures=3):
|
||||
spec = {
|
||||
"run": run,
|
||||
"stop": stop or {},
|
||||
"config": config or {},
|
||||
"trial_resources": trial_resources or {"cpu": 1, "gpu": 0},
|
||||
"trial_resources": trial_resources or {
|
||||
"cpu": 1,
|
||||
"gpu": 0
|
||||
},
|
||||
"repeat": repeat,
|
||||
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
|
||||
"upload_dir": upload_dir,
|
||||
|
||||
@@ -91,8 +91,8 @@ class FunctionRunner(Trainable):
|
||||
for k in self._default_config:
|
||||
if k in scrubbed_config:
|
||||
del scrubbed_config[k]
|
||||
self._runner = _RunnerThread(
|
||||
entrypoint, scrubbed_config, self._status_reporter)
|
||||
self._runner = _RunnerThread(entrypoint, scrubbed_config,
|
||||
self._status_reporter)
|
||||
self._start_time = time.time()
|
||||
self._last_reported_timestep = 0
|
||||
self._runner.start()
|
||||
@@ -104,9 +104,8 @@ class FunctionRunner(Trainable):
|
||||
|
||||
def _train(self):
|
||||
time.sleep(
|
||||
self.config.get(
|
||||
"script_min_iter_time_s",
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
self.config.get("script_min_iter_time_s",
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None:
|
||||
time.sleep(1)
|
||||
|
||||
@@ -102,9 +102,8 @@ class HyperOptScheduler(FIFOScheduler):
|
||||
self._hpopt_trials.refresh()
|
||||
|
||||
# Get new suggestion from
|
||||
new_trials = self.algo(
|
||||
new_ids, self.domain, self._hpopt_trials,
|
||||
self.rstate.randint(2 ** 31 - 1))
|
||||
new_trials = self.algo(new_ids, self.domain, self._hpopt_trials,
|
||||
self.rstate.randint(2**31 - 1))
|
||||
self._hpopt_trials.insert_trial_docs(new_trials)
|
||||
self._hpopt_trials.refresh()
|
||||
new_trial = new_trials[0]
|
||||
@@ -112,8 +111,11 @@ class HyperOptScheduler(FIFOScheduler):
|
||||
suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
|
||||
new_cfg.update(suggested_config)
|
||||
|
||||
kv_str = "_".join(["{}={}".format(k, str(v)[:5])
|
||||
for k, v in sorted(suggested_config.items())])
|
||||
kv_str = "_".join([
|
||||
"{}={}".format(k,
|
||||
str(v)[:5])
|
||||
for k, v in sorted(suggested_config.items())
|
||||
])
|
||||
experiment_tag = "{}_{}".format(new_trial_id, kv_str)
|
||||
|
||||
# Keep this consistent with tune.variant_generator
|
||||
@@ -166,8 +168,7 @@ class HyperOptScheduler(FIFOScheduler):
|
||||
del self._tune_to_hp[trial]
|
||||
|
||||
def _to_hyperopt_result(self, result):
|
||||
return {"loss": -getattr(result, self._reward_attr),
|
||||
"status": "ok"}
|
||||
return {"loss": -getattr(result, self._reward_attr), "status": "ok"}
|
||||
|
||||
def _get_hyperopt_trial(self, tid):
|
||||
return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]
|
||||
@@ -183,8 +184,9 @@ class HyperOptScheduler(FIFOScheduler):
|
||||
experiments and trials left to run. If self._max_concurrent is None,
|
||||
scheduler will add new trial if there is none that are pending.
|
||||
"""
|
||||
pending = [t for t in trial_runner.get_trials()
|
||||
if t.status == Trial.PENDING]
|
||||
pending = [
|
||||
t for t in trial_runner.get_trials() if t.status == Trial.PENDING
|
||||
]
|
||||
if self._num_trials_left <= 0:
|
||||
return
|
||||
if self._max_concurrent is None:
|
||||
|
||||
@@ -66,9 +66,10 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
mentioned in the original HyperBand paper.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean', max_t=81):
|
||||
def __init__(self,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='episode_reward_mean',
|
||||
max_t=81):
|
||||
assert max_t > 0, "Max (time_attr) not valid!"
|
||||
FIFOScheduler.__init__(self)
|
||||
self._eta = 3
|
||||
@@ -78,13 +79,12 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
self._get_n0 = lambda s: int(
|
||||
np.ceil(self._s_max_1/(s+1) * self._eta**s))
|
||||
# bracket initial iterations
|
||||
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
|
||||
self._get_r0 = lambda s: int((max_t * self._eta**(-s)))
|
||||
self._hyperbands = [[]] # list of hyperband iterations
|
||||
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
|
||||
|
||||
# Tracks state for new trial add
|
||||
self._state = {"bracket": None,
|
||||
"band_idx": 0}
|
||||
self._state = {"bracket": None, "band_idx": 0}
|
||||
self._num_stopped = 0
|
||||
self._reward_attr = reward_attr
|
||||
self._time_attr = time_attr
|
||||
@@ -116,9 +116,9 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
cur_bracket = None
|
||||
else:
|
||||
retry = False
|
||||
cur_bracket = Bracket(
|
||||
self._time_attr, self._get_n0(s), self._get_r0(s),
|
||||
self._max_t_attr, self._eta, s)
|
||||
cur_bracket = Bracket(self._time_attr, self._get_n0(s),
|
||||
self._get_r0(s), self._max_t_attr,
|
||||
self._eta, s)
|
||||
cur_band.append(cur_bracket)
|
||||
self._state["bracket"] = cur_bracket
|
||||
|
||||
@@ -217,11 +217,11 @@ class HyperBandScheduler(FIFOScheduler):
|
||||
"""
|
||||
|
||||
for hyperband in self._hyperbands:
|
||||
for bracket in sorted(hyperband,
|
||||
key=lambda b: b.completion_percentage()):
|
||||
for bracket in sorted(
|
||||
hyperband, key=lambda b: b.completion_percentage()):
|
||||
for trial in bracket.current_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
@@ -258,6 +258,7 @@ class Bracket():
|
||||
|
||||
Also keeps track of progress to ensure good scheduling.
|
||||
"""
|
||||
|
||||
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
|
||||
self._live_trials = {} # maps trial -> current result
|
||||
self._all_trials = []
|
||||
@@ -287,8 +288,9 @@ class Bracket():
|
||||
"""Checks if all iterations have completed.
|
||||
|
||||
TODO(rliaw): also check that `t.iterations == self._r`"""
|
||||
return all(self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
return all(
|
||||
self._get_result_time(result) >= self._cumul_r
|
||||
for result in self._live_trials.values())
|
||||
|
||||
def finished(self):
|
||||
return self._halves == 0 and self.cur_iter_done()
|
||||
@@ -379,7 +381,7 @@ class Bracket():
|
||||
def _calculate_total_work(self, n, r, s):
|
||||
work = 0
|
||||
cumulative_r = r
|
||||
for i in range(s+1):
|
||||
for i in range(s + 1):
|
||||
work += int(n) * int(r)
|
||||
n /= self._eta
|
||||
n = int(np.ceil(n))
|
||||
@@ -389,11 +391,11 @@ class Bracket():
|
||||
|
||||
def __repr__(self):
|
||||
status = ", ".join([
|
||||
"Max Size (n)={}".format(self._n),
|
||||
"Milestone (r)={}".format(self._cumul_r),
|
||||
"completed={:.1%}".format(self.completion_percentage())
|
||||
])
|
||||
"Max Size (n)={}".format(self._n), "Milestone (r)={}".format(
|
||||
self._cumul_r), "completed={:.1%}".format(
|
||||
self.completion_percentage())
|
||||
])
|
||||
counts = collections.Counter([t.status for t in self._all_trials])
|
||||
trial_statuses = ", ".join(sorted(
|
||||
["{}: {}".format(k, v) for k, v in counts.items()]))
|
||||
trial_statuses = ", ".join(
|
||||
sorted(["{}: {}".format(k, v) for k, v in counts.items()]))
|
||||
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
|
||||
|
||||
@@ -13,7 +13,6 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
# Map from (logdir, remote_dir) -> syncer
|
||||
_syncers = {}
|
||||
|
||||
@@ -69,9 +68,8 @@ class _LogSyncer(object):
|
||||
def sync_now(self, force=False):
|
||||
self.last_sync_time = time.time()
|
||||
if not self.worker_ip:
|
||||
print(
|
||||
"Worker ip unknown, skipping log sync for {}".format(
|
||||
self.local_dir))
|
||||
print("Worker ip unknown, skipping log sync for {}".format(
|
||||
self.local_dir))
|
||||
return
|
||||
|
||||
if self.worker_ip == self.local_ip:
|
||||
@@ -80,23 +78,21 @@ class _LogSyncer(object):
|
||||
ssh_key = get_ssh_key()
|
||||
ssh_user = get_ssh_user()
|
||||
if ssh_key is None or ssh_user is None:
|
||||
print(
|
||||
"Error: log sync requires cluster to be setup with "
|
||||
"`ray create_or_update`.")
|
||||
print("Error: log sync requires cluster to be setup with "
|
||||
"`ray create_or_update`.")
|
||||
return
|
||||
if not distutils.spawn.find_executable("rsync"):
|
||||
print("Error: log sync requires rsync to be installed.")
|
||||
return
|
||||
worker_to_local_sync_cmd = (
|
||||
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
|
||||
worker_to_local_sync_cmd = ((
|
||||
"""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
|
||||
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
|
||||
ssh_key, ssh_user, self.worker_ip,
|
||||
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
|
||||
|
||||
if self.remote_dir:
|
||||
local_to_remote_sync_cmd = (
|
||||
"aws s3 sync '{}' '{}'".format(
|
||||
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
|
||||
local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format(
|
||||
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
|
||||
else:
|
||||
local_to_remote_sync_cmd = None
|
||||
|
||||
|
||||
@@ -110,9 +110,9 @@ def to_tf_values(result, path):
|
||||
for attr, value in result.items():
|
||||
if value is not None:
|
||||
if type(value) in [int, float]:
|
||||
values.append(tf.Summary.Value(
|
||||
tag="/".join(path + [attr]),
|
||||
simple_value=value))
|
||||
values.append(
|
||||
tf.Summary.Value(
|
||||
tag="/".join(path + [attr]), simple_value=value))
|
||||
elif type(value) is dict:
|
||||
values.extend(to_tf_values(value, path + [attr]))
|
||||
return values
|
||||
@@ -125,8 +125,8 @@ class _TFLogger(Logger):
|
||||
def on_result(self, result):
|
||||
tmp = result._asdict()
|
||||
for k in [
|
||||
"config", "pid", "timestamp", "time_total_s",
|
||||
"timesteps_total"]:
|
||||
"config", "pid", "timestamp", "time_total_s", "timesteps_total"
|
||||
]:
|
||||
del tmp[k] # not useful to tf log these
|
||||
values = to_tf_values(tmp, ["ray", "tune"])
|
||||
train_stats = tf.Summary(value=values)
|
||||
@@ -165,9 +165,9 @@ class _CustomEncoder(json.JSONEncoder):
|
||||
return repr(o) if not np.isnan(o) else nan_str
|
||||
|
||||
_iterencode = json.encoder._make_iterencode(
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
None, self.default, _encoder, self.indent, floatstr,
|
||||
self.key_separator, self.item_separator, self.sort_keys,
|
||||
self.skipkeys, _one_shot)
|
||||
return _iterencode(o, 0)
|
||||
|
||||
def default(self, value):
|
||||
|
||||
@@ -32,10 +32,13 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
time a trial reports. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
grace_period=60.0, min_samples_required=3,
|
||||
hard_stop=True, verbose=True):
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
grace_period=60.0,
|
||||
min_samples_required=3,
|
||||
hard_stop=True,
|
||||
verbose=True):
|
||||
FIFOScheduler.__init__(self)
|
||||
self._stopped_trials = set()
|
||||
self._completed_trials = set()
|
||||
@@ -103,9 +106,10 @@ class MedianStoppingRule(FIFOScheduler):
|
||||
results = self._results[trial]
|
||||
# TODO(ekl) we could do interpolation to be more precise, but for now
|
||||
# assume len(results) is large and the time diffs are roughly equal
|
||||
return np.mean(
|
||||
[getattr(r, self._reward_attr)
|
||||
for r in results if getattr(r, self._time_attr) <= t_max])
|
||||
return np.mean([
|
||||
getattr(r, self._reward_attr) for r in results
|
||||
if getattr(r, self._time_attr) <= t_max
|
||||
])
|
||||
|
||||
def _best_result(self, trial):
|
||||
results = self._results[trial]
|
||||
|
||||
+26
-26
@@ -11,7 +11,6 @@ from ray.tune.trial import Trial
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.variant_generator import _format_vars
|
||||
|
||||
|
||||
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
|
||||
# the bottom PBT_QUANTILE fraction.
|
||||
PBT_QUANTILE = 0.25
|
||||
@@ -27,9 +26,8 @@ class PBTTrialState(object):
|
||||
self.last_perturbation_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
return str((
|
||||
self.last_score, self.last_checkpoint,
|
||||
self.last_perturbation_time))
|
||||
return str((self.last_score, self.last_checkpoint,
|
||||
self.last_perturbation_time))
|
||||
|
||||
|
||||
def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
@@ -51,12 +49,13 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
config[key] not in distribution:
|
||||
new_config[key] = random.choice(distribution)
|
||||
elif random.random() > 0.5:
|
||||
new_config[key] = distribution[
|
||||
max(0, distribution.index(config[key]) - 1)]
|
||||
new_config[key] = distribution[max(
|
||||
0,
|
||||
distribution.index(config[key]) - 1)]
|
||||
else:
|
||||
new_config[key] = distribution[
|
||||
min(len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
new_config[key] = distribution[min(
|
||||
len(distribution) - 1,
|
||||
distribution.index(config[key]) + 1)]
|
||||
else:
|
||||
if random.random() < resample_probability:
|
||||
new_config[key] = distribution()
|
||||
@@ -70,8 +69,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
|
||||
new_config = custom_explore_fn(new_config)
|
||||
assert new_config is not None, \
|
||||
"Custom explore fn failed to return new config"
|
||||
print(
|
||||
"[explore] perturbed config from {} -> {}".format(config, new_config))
|
||||
print("[explore] perturbed config from {} -> {}".format(
|
||||
config, new_config))
|
||||
return new_config
|
||||
|
||||
|
||||
@@ -148,10 +147,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
>>> run_experiments({...}, scheduler=pbt)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0, hyperparam_mutations={},
|
||||
resample_probability=0.25, custom_explore_fn=None):
|
||||
def __init__(self,
|
||||
time_attr="time_total_s",
|
||||
reward_attr="episode_reward_mean",
|
||||
perturbation_interval=60.0,
|
||||
hyperparam_mutations={},
|
||||
resample_probability=0.25,
|
||||
custom_explore_fn=None):
|
||||
if not hyperparam_mutations and not custom_explore_fn:
|
||||
raise TuneError(
|
||||
"You must specify at least one of `hyperparam_mutations` or "
|
||||
@@ -209,14 +211,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
if not new_state.last_checkpoint:
|
||||
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
|
||||
return
|
||||
new_config = explore(
|
||||
trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability, self._custom_explore_fn)
|
||||
print(
|
||||
"[exploit] transferring weights from trial "
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
|
||||
self._resample_probability,
|
||||
self._custom_explore_fn)
|
||||
print("[exploit] transferring weights from trial "
|
||||
"{} (score {}) -> {} (score {})".format(
|
||||
trial_to_clone, new_state.last_score, trial,
|
||||
trial_state.last_score))
|
||||
# TODO(ekl) restarting the trial is expensive. We should implement a
|
||||
# lighter way reset() method that can alter the trial config.
|
||||
trial.stop(stop_logger=False)
|
||||
@@ -242,9 +243,8 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
if len(trials) <= 1:
|
||||
return [], []
|
||||
else:
|
||||
return (
|
||||
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
|
||||
return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))],
|
||||
trials[int(math.floor(-len(trials) * PBT_QUANTILE)):])
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
"""Ensures all trials get fair share of time (as defined by time_attr).
|
||||
|
||||
@@ -14,7 +14,8 @@ ENV_CREATOR = "env_creator"
|
||||
RLLIB_MODEL = "rllib_model"
|
||||
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
||||
KNOWN_CATEGORIES = [
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR]
|
||||
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
|
||||
]
|
||||
|
||||
|
||||
def register_trainable(name, trainable):
|
||||
@@ -32,8 +33,8 @@ def register_trainable(name, trainable):
|
||||
if isinstance(trainable, FunctionType):
|
||||
trainable = wrap_function(trainable)
|
||||
if not issubclass(trainable, Trainable):
|
||||
raise TypeError(
|
||||
"Second argument must be convertable to Trainable", trainable)
|
||||
raise TypeError("Second argument must be convertable to Trainable",
|
||||
trainable)
|
||||
_default_registry.register(TRAINABLE_CLASS, name, trainable)
|
||||
|
||||
|
||||
@@ -46,8 +47,7 @@ def register_env(name, env_creator):
|
||||
"""
|
||||
|
||||
if not isinstance(env_creator, FunctionType):
|
||||
raise TypeError(
|
||||
"Second argument must be a function.", env_creator)
|
||||
raise TypeError("Second argument must be a function.", env_creator)
|
||||
_default_registry.register(ENV_CREATOR, name, env_creator)
|
||||
|
||||
|
||||
|
||||
+54
-51
@@ -4,8 +4,6 @@ from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
"""
|
||||
When using ray.tune with custom training scripts, you must periodically report
|
||||
training status back to Ray by calling reporter(result).
|
||||
@@ -18,73 +16,78 @@ In RLlib, the supplied algorithms fill in TrainingResult for you.
|
||||
# Where ray.tune writes result files by default
|
||||
DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results")
|
||||
|
||||
TrainingResult = namedtuple(
|
||||
"TrainingResult",
|
||||
[
|
||||
# (Required) Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
|
||||
TrainingResult = namedtuple("TrainingResult", [
|
||||
# (Required) Accumulated timesteps for this entire experiment.
|
||||
"timesteps_total",
|
||||
# (Optional) If training is terminated.
|
||||
"done",
|
||||
|
||||
# (Optional) If training is terminated.
|
||||
"done",
|
||||
# (Optional) Custom metadata to report for this iteration.
|
||||
"info",
|
||||
|
||||
# (Optional) Custom metadata to report for this iteration.
|
||||
"info",
|
||||
# (Optional) The mean episode reward if applicable.
|
||||
"episode_reward_mean",
|
||||
|
||||
# (Optional) The mean episode reward if applicable.
|
||||
"episode_reward_mean",
|
||||
# (Optional) The mean episode length if applicable.
|
||||
"episode_len_mean",
|
||||
|
||||
# (Optional) The mean episode length if applicable.
|
||||
"episode_len_mean",
|
||||
# (Optional) The number of episodes total.
|
||||
"episodes_total",
|
||||
|
||||
# (Optional) The number of episodes total.
|
||||
"episodes_total",
|
||||
# (Optional) The current training accuracy if applicable.
|
||||
"mean_accuracy",
|
||||
|
||||
# (Optional) The current training accuracy if applicable.
|
||||
"mean_accuracy",
|
||||
# (Optional) The current validation accuracy if applicable.
|
||||
"mean_validation_accuracy",
|
||||
|
||||
# (Optional) The current validation accuracy if applicable.
|
||||
"mean_validation_accuracy",
|
||||
# (Optional) The current training loss if applicable.
|
||||
"mean_loss",
|
||||
|
||||
# (Optional) The current training loss if applicable.
|
||||
"mean_loss",
|
||||
# (Auto-filled) The negated current training loss.
|
||||
"neg_mean_loss",
|
||||
|
||||
# (Auto-filled) The negated current training loss.
|
||||
"neg_mean_loss",
|
||||
# (Auto-filled) Unique string identifier for this experiment.
|
||||
# This id is preserved across checkpoint / restore calls.
|
||||
"experiment_id",
|
||||
|
||||
# (Auto-filled) Unique string identifier for this experiment. This id is
|
||||
# preserved across checkpoint / restore calls.
|
||||
"experiment_id",
|
||||
# (Auto-filled) The index of this training iteration,
|
||||
# e.g. call to train().
|
||||
"training_iteration",
|
||||
|
||||
# (Auto-filled) The index of this training iteration, e.g. call to train().
|
||||
"training_iteration",
|
||||
# (Auto-filled) Number of timesteps in the simulator
|
||||
# in this iteration.
|
||||
"timesteps_this_iter",
|
||||
|
||||
# (Auto-filled) Number of timesteps in the simulator in this iteration.
|
||||
"timesteps_this_iter",
|
||||
# (Auto-filled) Time in seconds this iteration took to run. This may
|
||||
# be overriden in order to override the system-computed
|
||||
# time difference.
|
||||
"time_this_iter_s",
|
||||
|
||||
# (Auto-filled) Time in seconds this iteration took to run. This may be
|
||||
# overriden in order to override the system-computed time difference.
|
||||
"time_this_iter_s",
|
||||
# (Auto-filled) Accumulated time in seconds for this entire experiment.
|
||||
"time_total_s",
|
||||
|
||||
# (Auto-filled) Accumulated time in seconds for this entire experiment.
|
||||
"time_total_s",
|
||||
# (Auto-filled) The pid of the training process.
|
||||
"pid",
|
||||
|
||||
# (Auto-filled) The pid of the training process.
|
||||
"pid",
|
||||
# (Auto-filled) A formatted date of when the result was processed.
|
||||
"date",
|
||||
|
||||
# (Auto-filled) A formatted date of when the result was processed.
|
||||
"date",
|
||||
# (Auto-filled) A UNIX timestamp of when the result was processed.
|
||||
"timestamp",
|
||||
|
||||
# (Auto-filled) A UNIX timestamp of when the result was processed.
|
||||
"timestamp",
|
||||
# (Auto-filled) The hostname of the machine hosting the
|
||||
# training process.
|
||||
"hostname",
|
||||
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
"hostname",
|
||||
# (Auto-filled) The node ip of the machine hosting the
|
||||
# training process.
|
||||
"node_ip",
|
||||
|
||||
# (Auto-filled) The node ip of the machine hosting the training process.
|
||||
"node_ip",
|
||||
# (Auto=filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
|
||||
# (Auto=filled) The current hyperparameter configuration.
|
||||
"config",
|
||||
])
|
||||
|
||||
|
||||
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
|
||||
TrainingResult.__new__.__defaults__ = (None, ) * len(TrainingResult._fields)
|
||||
|
||||
@@ -20,7 +20,9 @@ if __name__ == "__main__":
|
||||
run_experiments({
|
||||
"test": {
|
||||
"run": "my_class",
|
||||
"stop": {"training_iteration": 1}
|
||||
"stop": {
|
||||
"training_iteration": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
assert 'ray.rllib' not in sys.modules, "RLlib should not be imported"
|
||||
|
||||
@@ -60,163 +60,209 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
def testRewriteEnv(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
[trial] = run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"env": "CartPole-v0",
|
||||
}})
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"env": "CartPole-v0",
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.config["env"], "CartPole-v0")
|
||||
|
||||
def testConfigPurity(self):
|
||||
def train(config, reporter):
|
||||
assert config == {"a": "b"}, config
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"a": "b"
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testLogdir(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {"a": "b"},
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a": "b"
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testLongFilename(self):
|
||||
def train(config, reporter):
|
||||
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
|
||||
reporter(timesteps_total=1)
|
||||
|
||||
register_trainable("f1", train)
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"local_dir": "/tmp/logdir",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
def testBadParams(self):
|
||||
def f():
|
||||
run_experiments({"foo": {}})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams2(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "asdf",
|
||||
"bah": "this param is not allowed",
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "asdf",
|
||||
"bah": "this param is not allowed",
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams3(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": grid_search("invalid grid search"),
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": grid_search("invalid grid search"),
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams4(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "asdf",
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "asdf",
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams5(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "PPO",
|
||||
"stop": {"asdf": 1}
|
||||
}})
|
||||
run_experiments({"foo": {"run": "PPO", "stop": {"asdf": 1}}})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadParams6(self):
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "PPO",
|
||||
"trial_resources": {"asdf": 1}
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "PPO",
|
||||
"trial_resources": {
|
||||
"asdf": 1
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testBadReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter()
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testEarlyReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100, done=True)
|
||||
time.sleep(99999)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testAbruptReturn(self):
|
||||
def train(config, reporter):
|
||||
reporter(timesteps_total=100)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testErrorReturn(self):
|
||||
def train(config, reporter):
|
||||
raise Exception("uh oh")
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
def f():
|
||||
run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
self.assertRaises(TuneError, f)
|
||||
|
||||
def testSuccess(self):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
[trial] = run_experiments({"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}})
|
||||
[trial] = run_experiments({
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0,
|
||||
},
|
||||
}
|
||||
})
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
self.assertEqual(trial.last_result.timesteps_total, 99)
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
|
||||
@@ -228,6 +274,7 @@ class RunExperimentTest(unittest.TestCase):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
trials = run_experiments({
|
||||
"foo": {
|
||||
@@ -251,13 +298,14 @@ class RunExperimentTest(unittest.TestCase):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
[trial] = run_experiments(exp1)
|
||||
self.assertEqual(trial.status, Trial.TERMINATED)
|
||||
@@ -267,20 +315,21 @@ class RunExperimentTest(unittest.TestCase):
|
||||
def train(config, reporter):
|
||||
for i in range(100):
|
||||
reporter(timesteps_total=i)
|
||||
|
||||
register_trainable("f1", train)
|
||||
exp1 = Experiment(**{
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
"name": "foo",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
exp2 = Experiment(**{
|
||||
"name": "bar",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
"name": "bar",
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"script_min_iter_time_s": 0
|
||||
}
|
||||
})
|
||||
trials = run_experiments([exp1, exp2])
|
||||
for trial in trials:
|
||||
@@ -306,9 +355,8 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].trainable_name, "PPO")
|
||||
self.assertEqual(trials[0].experiment_tag, "0")
|
||||
self.assertEqual(trials[0].max_failures, 5)
|
||||
self.assertEqual(
|
||||
trials[0].local_dir,
|
||||
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
|
||||
self.assertEqual(trials[0].local_dir,
|
||||
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
|
||||
self.assertEqual(trials[1].experiment_tag, "1")
|
||||
|
||||
def testEval(self):
|
||||
@@ -392,11 +440,13 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
trials = generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"x": grid_search([
|
||||
"x":
|
||||
grid_search([
|
||||
lambda spec: spec.config.y * 100,
|
||||
lambda spec: spec.config.y * 200
|
||||
]),
|
||||
"y": lambda spec: 1,
|
||||
"y":
|
||||
lambda spec: 1,
|
||||
},
|
||||
})
|
||||
trials = list(trials)
|
||||
@@ -406,12 +456,13 @@ class VariantGeneratorTest(unittest.TestCase):
|
||||
|
||||
def testRecursiveDep(self):
|
||||
try:
|
||||
list(generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"foo": lambda spec: spec.config.foo,
|
||||
},
|
||||
}))
|
||||
list(
|
||||
generate_trials({
|
||||
"run": "PPO",
|
||||
"config": {
|
||||
"foo": lambda spec: spec.config.foo,
|
||||
},
|
||||
}))
|
||||
except RecursiveDependencyError as e:
|
||||
assert "`foo` recursively depends on" in str(e), e
|
||||
else:
|
||||
@@ -442,12 +493,15 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
|
||||
register_trainable("f1", train)
|
||||
|
||||
experiments = {"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40},
|
||||
}}
|
||||
experiments = {
|
||||
"foo": {
|
||||
"run": "f1",
|
||||
"config": {
|
||||
"a" * 50: lambda spec: 5.0 / 7,
|
||||
"b" * 50: lambda spec: "long" * 40
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
for name, spec in experiments.items():
|
||||
for trial in generate_trials(spec, name):
|
||||
@@ -468,12 +522,12 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=0, extra_cpu=3, extra_gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -489,12 +543,12 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -518,12 +572,12 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -547,13 +601,13 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
|
||||
trials = [
|
||||
Trial("asdf", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
@@ -644,7 +698,9 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 1},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 1
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
@@ -675,7 +731,9 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 2},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
@@ -692,7 +750,9 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=1, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 2},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 2
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
runner.add_trial(Trial("__fake", **kwargs))
|
||||
@@ -721,14 +781,17 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
ray.init(num_cpus=4, num_gpus=2)
|
||||
runner = TrialRunner()
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 5},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 5
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
Trial("__fake", **kwargs)
|
||||
]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
runner.step()
|
||||
|
||||
@@ -19,9 +19,8 @@ _register_all()
|
||||
|
||||
|
||||
def result(t, rew):
|
||||
return TrainingResult(time_total_s=t,
|
||||
episode_reward_mean=rew,
|
||||
training_iteration=int(t))
|
||||
return TrainingResult(
|
||||
time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
|
||||
|
||||
|
||||
class EarlyStoppingSuite(unittest.TestCase):
|
||||
@@ -76,8 +75,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
||||
rule.on_trial_result(None, t3, result(2, 10)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingMinSamples(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
|
||||
@@ -89,8 +87,7 @@ class EarlyStoppingSuite(unittest.TestCase):
|
||||
TrialScheduler.CONTINUE)
|
||||
rule.on_trial_complete(None, t2, result(10, 1000))
|
||||
self.assertEqual(
|
||||
rule.on_trial_result(None, t3, result(3, 10)),
|
||||
TrialScheduler.STOP)
|
||||
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
|
||||
|
||||
def testMedianStoppingUsesMedian(self):
|
||||
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
|
||||
@@ -124,8 +121,10 @@ class EarlyStoppingSuite(unittest.TestCase):
|
||||
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||
|
||||
rule = MedianStoppingRule(
|
||||
grace_period=0, min_samples_required=1,
|
||||
time_attr='training_iteration', reward_attr='neg_mean_loss')
|
||||
grace_period=0,
|
||||
min_samples_required=1,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss')
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
for i in range(10):
|
||||
@@ -185,7 +184,6 @@ class _MockTrialRunner():
|
||||
|
||||
|
||||
class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
def schedulerSetup(self, num_trials):
|
||||
"""Setup a scheduler and Runner with max Iter = 9
|
||||
|
||||
@@ -206,7 +204,10 @@ class HyperbandSuite(unittest.TestCase):
|
||||
"""Default statistics for HyperBand"""
|
||||
sched = HyperBandScheduler()
|
||||
res = {
|
||||
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
|
||||
str(s): {
|
||||
"n": sched._get_n0(s),
|
||||
"r": sched._get_r0(s)
|
||||
}
|
||||
for s in range(sched._s_max_1)
|
||||
}
|
||||
res["max_trials"] = sum(v["n"] for v in res.values())
|
||||
@@ -298,8 +299,8 @@ class HyperbandSuite(unittest.TestCase):
|
||||
|
||||
# Provides results from 0 to 8 in order, keeping last one running
|
||||
for i, trl in enumerate(trials):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
if i < current_length - 1:
|
||||
self.assertEqual(action, TrialScheduler.PAUSE)
|
||||
mock_runner.process_action(trl, action)
|
||||
@@ -321,8 +322,8 @@ class HyperbandSuite(unittest.TestCase):
|
||||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(1)]["r"]
|
||||
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
@@ -338,8 +339,8 @@ class HyperbandSuite(unittest.TestCase):
|
||||
# # Provides result in reverse order, killing the last one
|
||||
cur_units = stats[str(0)]["r"]
|
||||
for i, trl in enumerate(big_bracket.current_trials()):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, trl, result(cur_units, i))
|
||||
action = sched.on_trial_result(mock_runner, trl,
|
||||
result(cur_units, i))
|
||||
mock_runner.process_action(trl, action)
|
||||
|
||||
self.assertEqual(action, TrialScheduler.STOP)
|
||||
@@ -354,14 +355,12 @@ class HyperbandSuite(unittest.TestCase):
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_error(mock_runner, t3)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t1,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t2,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialErrored2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
@@ -371,13 +370,14 @@ class HyperbandSuite(unittest.TestCase):
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
sched.on_trial_result(mock_runner, t, result(
|
||||
stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_error(mock_runner, trials[-1])
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
self.assertEqual(
|
||||
len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testTrialEndedEarly(self):
|
||||
"""Check successive halving happened even when one trial failed"""
|
||||
@@ -390,14 +390,12 @@ class HyperbandSuite(unittest.TestCase):
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
sched.on_trial_complete(mock_runner, t3, result(1, 12))
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t1, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(
|
||||
mock_runner, t2, result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t1,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t2,
|
||||
result(stats[str(1)]["r"], 10)))
|
||||
|
||||
def testTrialEndedEarly2(self):
|
||||
"""Check successive halving happened even when last trial failed"""
|
||||
@@ -407,13 +405,14 @@ class HyperbandSuite(unittest.TestCase):
|
||||
trials = sched._state["bracket"].current_trials()
|
||||
for t in trials[:-1]:
|
||||
mock_runner._launch_trial(t)
|
||||
sched.on_trial_result(
|
||||
mock_runner, t, result(stats[str(1)]["r"], 10))
|
||||
sched.on_trial_result(mock_runner, t, result(
|
||||
stats[str(1)]["r"], 10))
|
||||
|
||||
mock_runner._launch_trial(trials[-1])
|
||||
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
self.assertEqual(
|
||||
len(sched._state["bracket"].current_trials()),
|
||||
self.downscale(stats[str(1)]["n"], sched))
|
||||
|
||||
def testAddAfterHalving(self):
|
||||
stats = self.default_statistics()
|
||||
@@ -426,8 +425,8 @@ class HyperbandSuite(unittest.TestCase):
|
||||
mock_runner._launch_trial(t)
|
||||
|
||||
for i, t in enumerate(bracket_trials):
|
||||
action = sched.on_trial_result(
|
||||
mock_runner, t, result(init_units, i))
|
||||
action = sched.on_trial_result(mock_runner, t, result(
|
||||
init_units, i))
|
||||
self.assertEqual(action, TrialScheduler.CONTINUE)
|
||||
t = Trial("__fake")
|
||||
sched.on_trial_add(None, t)
|
||||
@@ -435,13 +434,13 @@ class HyperbandSuite(unittest.TestCase):
|
||||
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
|
||||
|
||||
# Make sure that newly added trial gets fair computation (not just 1)
|
||||
self.assertEqual(
|
||||
TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
|
||||
self.assertEqual(TrialScheduler.CONTINUE,
|
||||
sched.on_trial_result(mock_runner, t,
|
||||
result(init_units, 12)))
|
||||
new_units = init_units + int(init_units * sched._eta)
|
||||
self.assertEqual(
|
||||
TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
|
||||
self.assertEqual(TrialScheduler.PAUSE,
|
||||
sched.on_trial_result(mock_runner, t,
|
||||
result(new_units, 12)))
|
||||
|
||||
def testAlternateMetrics(self):
|
||||
"""Checking that alternate metrics will pass."""
|
||||
@@ -539,7 +538,6 @@ class _MockTrial(Trial):
|
||||
|
||||
|
||||
class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
|
||||
def basicSetup(self, resample_prob=0.0, explore=None):
|
||||
pbt = PopulationBasedTraining(
|
||||
time_attr="training_iteration",
|
||||
@@ -554,9 +552,12 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
runner = _MockTrialRunner(pbt)
|
||||
for i in range(5):
|
||||
trial = _MockTrial(
|
||||
i,
|
||||
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
|
||||
"int_factor": 10})
|
||||
i, {
|
||||
"id_factor": i,
|
||||
"float_factor": 2.0,
|
||||
"const_factor": 3,
|
||||
"int_factor": 10
|
||||
})
|
||||
runner.add_trial(trial)
|
||||
trial.status = Trial.RUNNING
|
||||
self.assertEqual(
|
||||
@@ -570,27 +571,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
trials = runner.get_trials()
|
||||
|
||||
# no checkpoint: haven't hit next perturbation interval yet
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 0)
|
||||
|
||||
# checkpoint: both past interval and upper quantile
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, 200)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 1)
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(30, 201)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [200, 201, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
|
||||
self.assertEqual(pbt._num_checkpoints, 2)
|
||||
|
||||
# not upper quantile any more
|
||||
@@ -608,8 +605,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(15, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
|
||||
self.assertEqual(pbt._num_perturbations, 0)
|
||||
|
||||
@@ -617,8 +613,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[0], result(20, -100)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
|
||||
self.assertTrue("@perturbed" in trials[0].experiment_tag)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertEqual(pbt._num_perturbations, 1)
|
||||
@@ -627,8 +622,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[2], result(20, 40)),
|
||||
TrialScheduler.CONTINUE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
|
||||
self.assertEqual(pbt._num_perturbations, 2)
|
||||
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
|
||||
self.assertTrue("@perturbed" in trials[2].experiment_tag)
|
||||
@@ -662,7 +656,6 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(trials[0].config["const_factor"], 3)
|
||||
|
||||
def testPerturbationValues(self):
|
||||
|
||||
def assertProduces(fn, values):
|
||||
random.seed(0)
|
||||
seen = set()
|
||||
@@ -712,8 +705,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
|
||||
TrialScheduler.PAUSE)
|
||||
self.assertEqual(
|
||||
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
|
||||
self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
|
||||
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
|
||||
|
||||
def testSchedulesMostBehindTrialToRun(self):
|
||||
@@ -748,6 +740,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
|
||||
new_config["id_factor"] = 42
|
||||
new_config["float_factor"] = 43
|
||||
return new_config
|
||||
|
||||
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
|
||||
trials = runner.get_trials()
|
||||
self.assertEqual(
|
||||
@@ -774,8 +767,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
||||
return t1, t2
|
||||
|
||||
def testAsyncHBOnComplete(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=1)
|
||||
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
|
||||
t1, t2 = self.basicSetup(scheduler)
|
||||
t3 = Trial("PPO")
|
||||
scheduler.on_trial_add(None, t3)
|
||||
@@ -803,8 +795,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
||||
TrialScheduler.STOP)
|
||||
|
||||
def testAsyncHBAllCompletes(self):
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
max_t=10, brackets=10)
|
||||
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
|
||||
trials = [Trial("PPO") for i in range(10)]
|
||||
for t in trials:
|
||||
scheduler.on_trial_add(None, t)
|
||||
@@ -834,8 +825,10 @@ class AsyncHyperBandSuite(unittest.TestCase):
|
||||
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
|
||||
|
||||
scheduler = AsyncHyperBandScheduler(
|
||||
grace_period=1, time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss', brackets=1)
|
||||
grace_period=1,
|
||||
time_attr='training_iteration',
|
||||
reward_attr='neg_mean_loss',
|
||||
brackets=1)
|
||||
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
|
||||
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
|
||||
scheduler.on_trial_add(None, t1)
|
||||
|
||||
@@ -30,16 +30,15 @@ class TuneServerSuite(unittest.TestCase):
|
||||
def basicSetup(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
port = get_valid_port()
|
||||
self.runner = TrialRunner(
|
||||
launch_web_server=True, server_port=port)
|
||||
self.runner = TrialRunner(launch_web_server=True, server_port=port)
|
||||
runner = self.runner
|
||||
kwargs = {
|
||||
"stopping_criterion": {"training_iteration": 3},
|
||||
"stopping_criterion": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"resources": Resources(cpu=1, gpu=1),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)]
|
||||
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
client = TuneClient("localhost:{}".format(port))
|
||||
@@ -61,7 +60,9 @@ class TuneServerSuite(unittest.TestCase):
|
||||
runner.step()
|
||||
spec = {
|
||||
"run": "__fake",
|
||||
"stop": {"training_iteration": 3},
|
||||
"stop": {
|
||||
"training_iteration": 3
|
||||
},
|
||||
"trial_resources": dict(cpu=1, gpu=1),
|
||||
}
|
||||
client.add_trial("test", spec)
|
||||
|
||||
@@ -114,8 +114,8 @@ class Trainable(object):
|
||||
time_this_iter = time.time() - start
|
||||
|
||||
if result.timesteps_this_iter is None:
|
||||
raise TuneError(
|
||||
"Must specify timesteps_this_iter in result", result)
|
||||
raise TuneError("Must specify timesteps_this_iter in result",
|
||||
result)
|
||||
|
||||
self._time_total += time_this_iter
|
||||
self._timesteps_total += result.timesteps_this_iter
|
||||
@@ -159,10 +159,10 @@ class Trainable(object):
|
||||
"""
|
||||
|
||||
checkpoint_path = self._save(checkpoint_dir or self.logdir)
|
||||
pickle.dump(
|
||||
[self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total],
|
||||
open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
pickle.dump([
|
||||
self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total
|
||||
], open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def save_to_object(self):
|
||||
@@ -186,8 +186,10 @@ class Trainable(object):
|
||||
out = io.BytesIO()
|
||||
with gzip.GzipFile(fileobj=out, mode="wb") as f:
|
||||
compressed = pickle.dumps({
|
||||
"checkpoint_name": os.path.basename(checkpoint_prefix),
|
||||
"data": data,
|
||||
"checkpoint_name":
|
||||
os.path.basename(checkpoint_prefix),
|
||||
"data":
|
||||
data,
|
||||
})
|
||||
if len(compressed) > 10e6: # getting pretty large
|
||||
print("Checkpoint size is {} bytes".format(len(compressed)))
|
||||
|
||||
+37
-32
@@ -42,12 +42,12 @@ class Resources(
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0):
|
||||
return super(Resources, cls).__new__(
|
||||
cls, cpu, gpu, extra_cpu, extra_gpu)
|
||||
return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu,
|
||||
extra_gpu)
|
||||
|
||||
def summary_string(self):
|
||||
return "{} CPUs, {} GPUs".format(
|
||||
self.cpu + self.extra_cpu, self.gpu + self.extra_gpu)
|
||||
return "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu,
|
||||
self.gpu + self.extra_gpu)
|
||||
|
||||
def cpu_total(self):
|
||||
return self.cpu + self.extra_cpu
|
||||
@@ -77,11 +77,17 @@ class Trial(object):
|
||||
TERMINATED = "TERMINATED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
def __init__(
|
||||
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag="", resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion=None, checkpoint_freq=0,
|
||||
restore_path=None, upload_dir=None, max_failures=0):
|
||||
def __init__(self,
|
||||
trainable_name,
|
||||
config=None,
|
||||
local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag="",
|
||||
resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion=None,
|
||||
checkpoint_freq=0,
|
||||
restore_path=None,
|
||||
upload_dir=None,
|
||||
max_failures=0):
|
||||
"""Initialize a new trial.
|
||||
|
||||
The args here take the same meaning as the command line flags defined
|
||||
@@ -166,19 +172,20 @@ class Trial(object):
|
||||
try:
|
||||
if error_msg and self.logdir:
|
||||
self.num_failures += 1
|
||||
error_file = os.path.join(
|
||||
self.logdir, "error_{}.txt".format(date_str()))
|
||||
error_file = os.path.join(self.logdir, "error_{}.txt".format(
|
||||
date_str()))
|
||||
with open(error_file, "w") as f:
|
||||
f.write(error_msg)
|
||||
self.error_file = error_file
|
||||
if self.runner:
|
||||
stop_tasks = []
|
||||
stop_tasks.append(self.runner.stop.remote())
|
||||
stop_tasks.append(self.runner.__ray_terminate__.remote(
|
||||
self.runner._ray_actor_id.id()))
|
||||
stop_tasks.append(
|
||||
self.runner.__ray_terminate__.remote(
|
||||
self.runner._ray_actor_id.id()))
|
||||
# TODO(ekl) seems like wait hangs when killing actors
|
||||
_, unfinished = ray.wait(
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
stop_tasks, num_returns=2, timeout=250)
|
||||
except Exception:
|
||||
print("Error stopping runner:", traceback.format_exc())
|
||||
self.status = Trial.ERROR
|
||||
@@ -252,12 +259,12 @@ class Trial(object):
|
||||
return '{} pid={}'.format(hostname, pid)
|
||||
|
||||
pieces = [
|
||||
'{} [{}]'.format(
|
||||
self._status_string(),
|
||||
location_string(
|
||||
self.last_result.hostname, self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)),
|
||||
'{} ts'.format(int(self.last_result.timesteps_total))]
|
||||
'{} [{}]'.format(self._status_string(),
|
||||
location_string(self.last_result.hostname,
|
||||
self.last_result.pid)),
|
||||
'{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format(
|
||||
int(self.last_result.timesteps_total))
|
||||
]
|
||||
|
||||
if self.last_result.episode_reward_mean is not None:
|
||||
pieces.append('{} rew'.format(
|
||||
@@ -274,10 +281,8 @@ class Trial(object):
|
||||
return ', '.join(pieces)
|
||||
|
||||
def _status_string(self):
|
||||
return "{}{}".format(
|
||||
self.status,
|
||||
", {} failures: {}".format(self.num_failures, self.error_file)
|
||||
if self.error_file else "")
|
||||
return "{}{}".format(self.status, ", {} failures: {}".format(
|
||||
self.num_failures, self.error_file) if self.error_file else "")
|
||||
|
||||
def has_checkpoint(self):
|
||||
return self._checkpoint_path is not None or \
|
||||
@@ -335,9 +340,8 @@ class Trial(object):
|
||||
def update_last_result(self, result, terminate=False):
|
||||
if terminate:
|
||||
result = result._replace(done=True)
|
||||
if self.verbose and (
|
||||
terminate or
|
||||
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
|
||||
if self.verbose and (terminate or time.time() - self.last_debug >
|
||||
DEBUG_PRINT_INTERVAL):
|
||||
print("TrainingResult for {}:".format(self))
|
||||
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
|
||||
self.last_debug = time.time()
|
||||
@@ -358,8 +362,8 @@ class Trial(object):
|
||||
prefix="{}_{}".format(
|
||||
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
|
||||
dir=self.local_dir)
|
||||
self.result_logger = UnifiedLogger(
|
||||
self.config, self.logdir, self.upload_dir)
|
||||
self.result_logger = UnifiedLogger(self.config, self.logdir,
|
||||
self.upload_dir)
|
||||
remote_logdir = self.logdir
|
||||
|
||||
def logger_creator(config):
|
||||
@@ -372,7 +376,8 @@ class Trial(object):
|
||||
# Logging for trials is handled centrally by TrialRunner, so
|
||||
# configure the remote runner to use a noop-logger.
|
||||
self.runner = cls.remote(
|
||||
config=self.config, registry=ray.tune.registry.get_registry(),
|
||||
config=self.config,
|
||||
registry=ray.tune.registry.get_registry(),
|
||||
logger_creator=logger_creator)
|
||||
|
||||
def set_verbose(self, verbose):
|
||||
@@ -387,8 +392,8 @@ class Trial(object):
|
||||
def __str__(self):
|
||||
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
|
||||
if "env" in self.config:
|
||||
identifier = "{}_{}".format(
|
||||
self.trainable_name, self.config["env"])
|
||||
identifier = "{}_{}".format(self.trainable_name,
|
||||
self.config["env"])
|
||||
else:
|
||||
identifier = self.trainable_name
|
||||
if self.experiment_tag:
|
||||
|
||||
@@ -13,7 +13,6 @@ from ray.tune.web_server import TuneServer
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
MAX_DEBUG_TRIALS = 20
|
||||
|
||||
|
||||
@@ -39,8 +38,11 @@ class TrialRunner(object):
|
||||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
def __init__(self, scheduler=None, launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
def __init__(self,
|
||||
scheduler=None,
|
||||
launch_web_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True):
|
||||
"""Initializes a new TrialRunner.
|
||||
|
||||
Args:
|
||||
@@ -73,9 +75,8 @@ class TrialRunner(object):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
if self._total_time > self._global_time_limit:
|
||||
print(
|
||||
"Exceeded global time limit {} / {}".format(
|
||||
self._total_time, self._global_time_limit))
|
||||
print("Exceeded global time limit {} / {}".format(
|
||||
self._total_time, self._global_time_limit))
|
||||
return True
|
||||
|
||||
for t in self._trials:
|
||||
@@ -98,12 +99,12 @@ class TrialRunner(object):
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
if not self.has_resources(trial.resources):
|
||||
raise TuneError((
|
||||
"Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster only has {} "
|
||||
"available.").format(
|
||||
trial.resources.summary_string(),
|
||||
self._avail_resources.summary_string()))
|
||||
raise TuneError(
|
||||
("Insufficient cluster resources to launch trial: "
|
||||
"trial requested {} but the cluster only has {} "
|
||||
"available.").format(
|
||||
trial.resources.summary_string(),
|
||||
self._avail_resources.summary_string()))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
raise TuneError(
|
||||
"There are paused trials, but no more pending "
|
||||
@@ -165,24 +166,20 @@ class TrialRunner(object):
|
||||
for state, trials in sorted(states.items()):
|
||||
limit = limit_per_state[state]
|
||||
messages.append("{} trials:".format(state))
|
||||
for t in sorted(
|
||||
trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
|
||||
messages.append(" - {}:\t{}".format(t, t.progress_string()))
|
||||
if len(trials) > limit:
|
||||
messages.append(" ... {} more not shown".format(
|
||||
len(trials) - limit))
|
||||
messages.append(
|
||||
" ... {} more not shown".format(len(trials) - limit))
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def _debug_messages(self):
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
if self._resources_initialized:
|
||||
messages.append(
|
||||
"Resources used: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu,
|
||||
self._avail_resources.cpu,
|
||||
self._committed_resources.gpu,
|
||||
self._avail_resources.gpu))
|
||||
messages.append("Resources used: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu, self._avail_resources.cpu,
|
||||
self._committed_resources.gpu, self._avail_resources.gpu))
|
||||
return messages
|
||||
|
||||
def has_resources(self, resources):
|
||||
@@ -190,9 +187,8 @@ class TrialRunner(object):
|
||||
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
return (
|
||||
resources.cpu_total() <= cpu_avail and
|
||||
resources.gpu_total() <= gpu_avail)
|
||||
return (resources.cpu_total() <= cpu_avail
|
||||
and resources.gpu_total() <= gpu_avail)
|
||||
|
||||
def _get_next_trial(self):
|
||||
self._update_avail_resources()
|
||||
@@ -307,8 +303,9 @@ class TrialRunner(object):
|
||||
self._scheduler_alg.on_trial_remove(self, trial)
|
||||
elif trial.status is Trial.RUNNING:
|
||||
# NOTE: There should only be one...
|
||||
result_id = [rid for rid, t in self._running.items()
|
||||
if t is trial][0]
|
||||
result_id = [
|
||||
rid for rid, t in self._running.items() if t is trial
|
||||
][0]
|
||||
self._running.pop(result_id)
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
@@ -339,9 +336,8 @@ class TrialRunner(object):
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
local_schedulers = [
|
||||
entry for client in clients.values() for entry in client
|
||||
if (entry['ClientType'] == 'local_scheduler' and not
|
||||
entry['Deleted'])
|
||||
entry for client in clients.values() for entry in client if
|
||||
(entry['ClientType'] == 'local_scheduler' and not entry['Deleted'])
|
||||
]
|
||||
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
|
||||
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
|
||||
|
||||
@@ -99,12 +99,12 @@ class FIFOScheduler(TrialScheduler):
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PAUSED and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
if (trial.status == Trial.PAUSED
|
||||
and trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
|
||||
+16
-11
@@ -16,7 +16,6 @@ from ray.tune.trial_scheduler import FIFOScheduler
|
||||
from ray.tune.web_server import TuneServer
|
||||
from ray.tune.experiment import Experiment
|
||||
|
||||
|
||||
_SCHEDULERS = {
|
||||
"FIFO": FIFOScheduler,
|
||||
"MedianStopping": MedianStoppingRule,
|
||||
@@ -30,13 +29,15 @@ def _make_scheduler(args):
|
||||
if args.scheduler in _SCHEDULERS:
|
||||
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
|
||||
else:
|
||||
raise TuneError(
|
||||
"Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
|
||||
args.scheduler, _SCHEDULERS.keys()))
|
||||
|
||||
|
||||
def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT, verbose=True):
|
||||
def run_experiments(experiments,
|
||||
scheduler=None,
|
||||
with_server=False,
|
||||
server_port=TuneServer.DEFAULT_PORT,
|
||||
verbose=True):
|
||||
"""Tunes experiments.
|
||||
|
||||
Args:
|
||||
@@ -54,17 +55,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
|
||||
scheduler = FIFOScheduler()
|
||||
|
||||
runner = TrialRunner(
|
||||
scheduler, launch_web_server=with_server, server_port=server_port,
|
||||
scheduler,
|
||||
launch_web_server=with_server,
|
||||
server_port=server_port,
|
||||
verbose=verbose)
|
||||
exp_list = experiments
|
||||
if isinstance(experiments, Experiment):
|
||||
exp_list = [experiments]
|
||||
elif type(experiments) is dict:
|
||||
exp_list = [Experiment.from_json(name, spec)
|
||||
for name, spec in experiments.items()]
|
||||
exp_list = [
|
||||
Experiment.from_json(name, spec)
|
||||
for name, spec in experiments.items()
|
||||
]
|
||||
|
||||
if (type(exp_list) is list and
|
||||
all(isinstance(exp, Experiment) for exp in exp_list)):
|
||||
if (type(exp_list) is list
|
||||
and all(isinstance(exp, Experiment) for exp in exp_list)):
|
||||
for experiment in exp_list:
|
||||
scheduler.add_experiment(experiment, runner)
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,6 @@ import base64
|
||||
import ray
|
||||
from ray.tune.registry import _to_pinnable, _from_pinnable
|
||||
|
||||
|
||||
_pinned_objects = []
|
||||
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
||||
|
||||
@@ -15,14 +14,15 @@ PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
||||
def pin_in_object_store(obj):
|
||||
obj_id = ray.put(_to_pinnable(obj))
|
||||
_pinned_objects.append(ray.get(obj_id))
|
||||
return "{}{}".format(
|
||||
PINNED_OBJECT_PREFIX, base64.b64encode(obj_id.id()).decode("utf-8"))
|
||||
return "{}{}".format(PINNED_OBJECT_PREFIX,
|
||||
base64.b64encode(obj_id.id()).decode("utf-8"))
|
||||
|
||||
|
||||
def get_pinned_object(pinned_id):
|
||||
from ray.local_scheduler import ObjectID
|
||||
return _from_pinnable(ray.get(ObjectID(
|
||||
base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
return _from_pinnable(
|
||||
ray.get(
|
||||
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -163,8 +163,8 @@ def _generate_variants(spec):
|
||||
for path, value in grid_vars:
|
||||
resolved_vars[path] = _get_value(spec, path)
|
||||
for k, v in resolved.items():
|
||||
if (k in resolved_vars and v != resolved_vars[k] and
|
||||
_is_resolved(resolved_vars[k])):
|
||||
if (k in resolved_vars and v != resolved_vars[k]
|
||||
and _is_resolved(resolved_vars[k])):
|
||||
raise ValueError(
|
||||
"The variable `{}` could not be unambiguously "
|
||||
"resolved to a single value. Consider simplifying "
|
||||
@@ -262,16 +262,16 @@ def _unresolved_values(spec):
|
||||
for k, v in spec.items():
|
||||
resolved, v = _try_resolve(v)
|
||||
if not resolved:
|
||||
found[(k,)] = v
|
||||
found[(k, )] = v
|
||||
elif isinstance(v, dict):
|
||||
# Recurse into a dict
|
||||
for (path, value) in _unresolved_values(v).items():
|
||||
found[(k,) + path] = value
|
||||
found[(k, ) + path] = value
|
||||
elif isinstance(v, list):
|
||||
# Recurse into a list
|
||||
for i, elem in enumerate(v):
|
||||
for (path, value) in _unresolved_values({i: elem}).items():
|
||||
found[(k,) + path] = value
|
||||
found[(k, ) + path] = value
|
||||
return found
|
||||
|
||||
|
||||
|
||||
@@ -61,8 +61,10 @@ def _resolve(directory, result_fname):
|
||||
|
||||
|
||||
def load_results_to_df(directory, result_name="result.json"):
|
||||
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
|
||||
for f in files if f == result_name]
|
||||
exp_directories = [
|
||||
dirpath for dirpath, dirs, files in os.walk(directory) for f in files
|
||||
if f == result_name
|
||||
]
|
||||
data = [_resolve(d, result_name) for d in exp_directories]
|
||||
data = [d for d in data if d]
|
||||
return pd.DataFrame(data)
|
||||
@@ -76,8 +78,9 @@ def generate_plotly_dim_dict(df, field):
|
||||
dim_dict["values"] = column
|
||||
elif is_string_dtype(column):
|
||||
texts = column.unique()
|
||||
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0]
|
||||
for x in column]
|
||||
dim_dict["values"] = [
|
||||
np.argwhere(texts == x).flatten()[0] for x in column
|
||||
]
|
||||
dim_dict["tickvals"] = list(range(len(texts)))
|
||||
dim_dict["ticktext"] = texts
|
||||
else:
|
||||
|
||||
@@ -39,28 +39,30 @@ class TuneClient(object):
|
||||
|
||||
def get_all_trials(self):
|
||||
"""Returns a list of all trials (trial_id, config, status)."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.GET_LIST})
|
||||
return self._get_response({"command": TuneClient.GET_LIST})
|
||||
|
||||
def get_trial(self, trial_id):
|
||||
"""Returns the last result for queried trial."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.GET_TRIAL,
|
||||
"trial_id": trial_id})
|
||||
return self._get_response({
|
||||
"command": TuneClient.GET_TRIAL,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
|
||||
def add_trial(self, name, trial_spec):
|
||||
"""Adds a trial of `name` with configurations."""
|
||||
# TODO(rliaw): have better way of specifying a new trial
|
||||
return self._get_response(
|
||||
{"command": TuneClient.ADD,
|
||||
"name": name,
|
||||
"spec": trial_spec})
|
||||
return self._get_response({
|
||||
"command": TuneClient.ADD,
|
||||
"name": name,
|
||||
"spec": trial_spec
|
||||
})
|
||||
|
||||
def stop_trial(self, trial_id):
|
||||
"""Requests to stop trial."""
|
||||
return self._get_response(
|
||||
{"command": TuneClient.STOP,
|
||||
"trial_id": trial_id})
|
||||
return self._get_response({
|
||||
"command": TuneClient.STOP,
|
||||
"trial_id": trial_id
|
||||
})
|
||||
|
||||
def _get_response(self, data):
|
||||
payload = json.dumps(data).encode()
|
||||
@@ -71,7 +73,6 @@ class TuneClient(object):
|
||||
|
||||
def RunnerHandler(runner):
|
||||
class Handler(SimpleHTTPRequestHandler):
|
||||
|
||||
def do_GET(self):
|
||||
content_len = int(self.headers.get('Content-Length'), 0)
|
||||
raw_body = self.rfile.read(content_len)
|
||||
@@ -82,8 +83,7 @@ def RunnerHandler(runner):
|
||||
else:
|
||||
self.send_response(400)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(
|
||||
response).encode())
|
||||
self.wfile.write(json.dumps(response).encode())
|
||||
|
||||
def trial_info(self, trial):
|
||||
if trial.last_result:
|
||||
@@ -112,8 +112,9 @@ def RunnerHandler(runner):
|
||||
response = {}
|
||||
try:
|
||||
if command == TuneClient.GET_LIST:
|
||||
response["trials"] = [self.trial_info(t)
|
||||
for t in runner.get_trials()]
|
||||
response["trials"] = [
|
||||
self.trial_info(t) for t in runner.get_trials()
|
||||
]
|
||||
elif command == TuneClient.GET_TRIAL:
|
||||
trial = get_trial()
|
||||
response["trial_info"] = self.trial_info(trial)
|
||||
@@ -147,8 +148,7 @@ class TuneServer(threading.Thread):
|
||||
self._port = port if port else self.DEFAULT_PORT
|
||||
address = ('localhost', self._port)
|
||||
print("Starting Tune Server...")
|
||||
self._server = HTTPServer(
|
||||
address, RunnerHandler(runner))
|
||||
self._server = HTTPServer(address, RunnerHandler(runner))
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
+11
-8
@@ -46,7 +46,10 @@ def format_error_message(exception_message, task_exception=False):
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
|
||||
def push_error_to_driver(redis_client,
|
||||
error_type,
|
||||
message,
|
||||
driver_id=None,
|
||||
data=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
@@ -64,9 +67,11 @@ def push_error_to_driver(redis_client, error_type, message, driver_id=None,
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
|
||||
data = {} if data is None else data
|
||||
redis_client.hmset(error_key, {"type": error_type,
|
||||
"message": message,
|
||||
"data": data})
|
||||
redis_client.hmset(error_key, {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
"data": data
|
||||
})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
|
||||
@@ -134,10 +139,8 @@ def hex_to_binary(hex_identifier):
|
||||
return binascii.unhexlify(hex_identifier)
|
||||
|
||||
|
||||
FunctionProperties = collections.namedtuple("FunctionProperties",
|
||||
["num_return_vals",
|
||||
"resources",
|
||||
"max_calls"])
|
||||
FunctionProperties = collections.namedtuple(
|
||||
"FunctionProperties", ["num_return_vals", "resources", "max_calls"])
|
||||
"""FunctionProperties: A named tuple storing remote functions information."""
|
||||
|
||||
|
||||
|
||||
+398
-338
File diff suppressed because it is too large
Load Diff
@@ -8,34 +8,51 @@ import traceback
|
||||
import ray
|
||||
import ray.actor
|
||||
|
||||
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
|
||||
"to connect to."))
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
help="the ip address of the worker's node")
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument("--object-store-name", required=True, type=str,
|
||||
help="the object store's name")
|
||||
parser.add_argument("--object-store-manager-name", required=False, type=str,
|
||||
help="the object store manager's name")
|
||||
parser.add_argument("--local-scheduler-name", required=False, type=str,
|
||||
help="the local scheduler's name")
|
||||
parser.add_argument("--raylet-name", required=False, type=str,
|
||||
help="the raylet's name")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=("Parse addresses for the worker "
|
||||
"to connect to."))
|
||||
parser.add_argument(
|
||||
"--node-ip-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the ip address of the worker's node")
|
||||
parser.add_argument(
|
||||
"--redis-address",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--object-store-name",
|
||||
required=True,
|
||||
type=str,
|
||||
help="the object store's name")
|
||||
parser.add_argument(
|
||||
"--object-store-manager-name",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the object store manager's name")
|
||||
parser.add_argument(
|
||||
"--local-scheduler-name",
|
||||
required=False,
|
||||
type=str,
|
||||
help="the local scheduler's name")
|
||||
parser.add_argument(
|
||||
"--raylet-name", required=False, type=str, help="the raylet's name")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name,
|
||||
"raylet_socket_name": args.raylet_name}
|
||||
info = {
|
||||
"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name,
|
||||
"raylet_socket_name": args.raylet_name
|
||||
}
|
||||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE,
|
||||
use_raylet=(args.raylet_name is not None))
|
||||
ray.worker.connect(
|
||||
info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None))
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
@@ -54,8 +71,8 @@ if __name__ == "__main__":
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
# Create a Redis client.
|
||||
redis_client = ray.services.create_redis_client(args.redis_address)
|
||||
ray.utils.push_error_to_driver(redis_client, "worker_crash",
|
||||
traceback_str, driver_id=None)
|
||||
ray.utils.push_error_to_driver(
|
||||
redis_client, "worker_crash", traceback_str, driver_id=None)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
||||
+42
-41
@@ -18,13 +18,11 @@ import setuptools.command.build_ext as _build_ext
|
||||
ray_files = [
|
||||
"ray/core/src/common/thirdparty/redis/src/redis-server",
|
||||
"ray/core/src/common/redis_module/libray_redis_module.so",
|
||||
"ray/core/src/plasma/plasma_store",
|
||||
"ray/core/src/plasma/plasma_manager",
|
||||
"ray/core/src/plasma/plasma_store", "ray/core/src/plasma/plasma_manager",
|
||||
"ray/core/src/local_scheduler/local_scheduler",
|
||||
"ray/core/src/local_scheduler/liblocal_scheduler_library.so",
|
||||
"ray/core/src/global_scheduler/global_scheduler",
|
||||
"ray/core/src/ray/raylet/raylet_monitor",
|
||||
"ray/core/src/ray/raylet/raylet",
|
||||
"ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet",
|
||||
"ray/WebUI.ipynb"
|
||||
]
|
||||
|
||||
@@ -35,14 +33,14 @@ ray_ui_files = [
|
||||
"ray/core/src/catapult_files/trace_viewer_full.html"
|
||||
]
|
||||
|
||||
ray_autoscaler_files = [
|
||||
"ray/autoscaler/aws/example-full.yaml"
|
||||
]
|
||||
ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"]
|
||||
|
||||
if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
|
||||
ray_files += ["ray/core/src/credis/build/src/libmember.so",
|
||||
"ray/core/src/credis/build/src/libmaster.so",
|
||||
"ray/core/src/credis/redis/src/redis-server"]
|
||||
ray_files += [
|
||||
"ray/core/src/credis/build/src/libmember.so",
|
||||
"ray/core/src/credis/build/src/libmaster.so",
|
||||
"ray/core/src/credis/redis/src/redis-server"
|
||||
]
|
||||
|
||||
# The UI files are mandatory if the INCLUDE_UI environment variable equals 1.
|
||||
# Otherwise, they are optional.
|
||||
@@ -54,9 +52,8 @@ else:
|
||||
optional_ray_files += ray_autoscaler_files
|
||||
|
||||
extras = {
|
||||
"rllib": [
|
||||
"tensorflow", "pyyaml", "gym[atari]", "opencv-python",
|
||||
"lz4", "scipy"]
|
||||
"rllib":
|
||||
["tensorflow", "pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"]
|
||||
}
|
||||
|
||||
|
||||
@@ -73,8 +70,9 @@ class build_ext(_build_ext.build_ext):
|
||||
pyarrow_files = [
|
||||
os.path.join("ray/pyarrow_files/pyarrow", filename)
|
||||
for filename in os.listdir("./ray/pyarrow_files/pyarrow")
|
||||
if not os.path.isdir(os.path.join("ray/pyarrow_files/pyarrow",
|
||||
filename))]
|
||||
if not os.path.isdir(
|
||||
os.path.join("ray/pyarrow_files/pyarrow", filename))
|
||||
]
|
||||
|
||||
files_to_include = ray_files + pyarrow_files
|
||||
|
||||
@@ -84,8 +82,8 @@ class build_ext(_build_ext.build_ext):
|
||||
generated_python_directory = "ray/core/generated"
|
||||
for filename in os.listdir(generated_python_directory):
|
||||
if filename[-3:] == ".py":
|
||||
self.move_file(os.path.join(generated_python_directory,
|
||||
filename))
|
||||
self.move_file(
|
||||
os.path.join(generated_python_directory, filename))
|
||||
|
||||
# Try to copy over the optional files.
|
||||
for filename in optional_ray_files:
|
||||
@@ -114,27 +112,30 @@ class BinaryDistribution(Distribution):
|
||||
return True
|
||||
|
||||
|
||||
setup(name="ray",
|
||||
# The version string is also in __init__.py. TODO(pcm): Fix this.
|
||||
version="0.4.0",
|
||||
packages=find_packages(),
|
||||
cmdclass={"build_ext": build_ext},
|
||||
# The BinaryDistribution argument triggers build_ext.
|
||||
distclass=BinaryDistribution,
|
||||
install_requires=["numpy",
|
||||
"funcsigs",
|
||||
"click",
|
||||
"colorama",
|
||||
"psutil",
|
||||
"pytest",
|
||||
"pyyaml",
|
||||
"redis",
|
||||
# The six module is required by pyarrow.
|
||||
"six >= 1.0.0",
|
||||
"flatbuffers"],
|
||||
setup_requires=["cython >= 0.23"],
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
license="Apache 2.0")
|
||||
setup(
|
||||
name="ray",
|
||||
# The version string is also in __init__.py. TODO(pcm): Fix this.
|
||||
version="0.4.0",
|
||||
packages=find_packages(),
|
||||
cmdclass={"build_ext": build_ext},
|
||||
# The BinaryDistribution argument triggers build_ext.
|
||||
distclass=BinaryDistribution,
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"funcsigs",
|
||||
"click",
|
||||
"colorama",
|
||||
"psutil",
|
||||
"pytest",
|
||||
"pyyaml",
|
||||
"redis",
|
||||
# The six module is required by pyarrow.
|
||||
"six >= 1.0.0",
|
||||
"flatbuffers"
|
||||
],
|
||||
setup_requires=["cython >= 0.23"],
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
license="Apache 2.0")
|
||||
|
||||
Reference in New Issue
Block a user