Lint Python files with Yapf (#1872)

This commit is contained in:
Philipp Moritz
2018-04-11 10:11:35 -07:00
committed by Robert Nishihara
parent a3ddde398c
commit 74162d1492
97 changed files with 3927 additions and 3139 deletions
+21 -19
View File
@@ -12,8 +12,8 @@ if "pyarrow" in sys.modules:
# Add the directory containing pyarrow to the Python path so that we find the
# pyarrow version packaged with ray and not a pre-existing pyarrow.
pyarrow_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"pyarrow_files")
pyarrow_path = os.path.join(
os.path.abspath(os.path.dirname(__file__)), "pyarrow_files")
sys.path.insert(0, pyarrow_path)
# See https://github.com/ray-project/ray/issues/131.
@@ -27,29 +27,29 @@ If you are using Anaconda, try fixing this problem by running:
try:
import pyarrow # noqa: F401
except ImportError as e:
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
("libstdc++" in e.msg or "CXX" in e.msg))):
if ((hasattr(e, "msg") and isinstance(e.msg, str)
and ("libstdc++" in e.msg or "CXX" in e.msg))):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
elif (hasattr(e, "message") and isinstance(e.message, str)
and ("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
condition = (hasattr(e, "args") and isinstance(e.args, tuple)
and len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
e.args = (e.args[0] + helpful_message, )
else:
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args,)
e.args += (helpful_message,)
e.args = (e.args, )
e.args += (helpful_message, )
raise
from ray.local_scheduler import _config # noqa: E402
from ray.worker import (error_info, init, connect, disconnect,
get, put, wait, remote, log_event, log_span,
flush_log, get_gpu_ids, get_webui_url,
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
remote, log_event, log_span, flush_log, get_gpu_ids,
get_webui_url,
register_custom_serializer) # noqa: E402
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
SILENT_MODE) # noqa: E402
@@ -63,11 +63,13 @@ from ray.actor import method # noqa: E402
# Fix this.
__version__ = "0.4.0"
__all__ = ["error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE",
"global_state", "_config", "__version__"]
__all__ = [
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"get_gpu_ids", "get_webui_url", "register_custom_serializer",
"SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE", "SILENT_MODE", "global_state",
"_config", "__version__"
]
import ctypes # noqa: E402
# Windows only
+114 -95
View File
@@ -121,16 +121,17 @@ def save_and_log_checkpoint(worker, actor):
try:
actor.__ray_checkpoint__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
"checkpoint",
traceback_str,
driver_id=worker.task_driver_id.id(),
data={"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__})
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__
})
def restore_and_log_checkpoint(worker, actor):
@@ -144,8 +145,7 @@ def restore_and_log_checkpoint(worker, actor):
try:
checkpoint_resumed = actor.__ray_checkpoint_restore__()
except Exception:
traceback_str = ray.utils.format_error_message(
traceback.format_exc())
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
ray.utils.push_error_to_driver(
worker.redis_client,
@@ -154,8 +154,8 @@ def restore_and_log_checkpoint(worker, actor):
driver_id=worker.task_driver_id.id(),
data={
"actor_class": actor.__class__.__name__,
"function_name":
actor.__ray_checkpoint_restore__.__name__})
"function_name": actor.__ray_checkpoint_restore__.__name__
})
return checkpoint_resumed
@@ -197,15 +197,15 @@ def make_actor_method_executor(worker, method_name, method, actor_imported):
return
# Determine whether we should checkpoint the actor.
checkpointing_on = (actor_imported and
worker.actor_checkpoint_interval > 0)
checkpointing_on = (actor_imported
and worker.actor_checkpoint_interval > 0)
# We should checkpoint the actor if user checkpointing is on, we've
# executed checkpoint_interval tasks since the last checkpoint, and the
# method we're about to execute is not a checkpoint.
save_checkpoint = (checkpointing_on and
(worker.actor_task_counter %
worker.actor_checkpoint_interval == 0 and
method_name != "__ray_checkpoint__"))
save_checkpoint = (
checkpointing_on and
(worker.actor_task_counter % worker.actor_checkpoint_interval == 0
and method_name != "__ray_checkpoint__"))
# Execute the assigned method and save a checkpoint if necessary.
try:
@@ -238,14 +238,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
worker: The worker to use.
"""
actor_id_str = worker.actor_id
(driver_id, class_id, class_name,
module, pickled_class, checkpoint_interval,
actor_method_names,
(driver_id, class_id, class_name, module, pickled_class,
checkpoint_interval, actor_method_names,
actor_method_num_return_vals) = worker.redis_client.hmget(
actor_class_key, ["driver_id", "class_id", "class_name", "module",
"class", "checkpoint_interval",
"actor_method_names",
"actor_method_num_return_vals"])
actor_class_key, [
"driver_id", "class_id", "class_name", "module", "class",
"checkpoint_interval", "actor_method_names",
"actor_method_num_return_vals"
])
actor_name = class_name.decode("ascii")
module = module.decode("ascii")
@@ -259,12 +259,14 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# error messages and to prevent the driver from hanging).
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
worker.actor_checkpoint_interval = checkpoint_interval
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
# Register the actor method signatures.
register_actor_signatures(worker, driver_id, class_id, class_name,
actor_method_names, actor_method_num_return_vals)
@@ -272,10 +274,11 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
for actor_method_name in actor_method_names:
function_id = compute_actor_method_function_id(class_name,
actor_method_name).id()
temporary_executor = make_actor_method_executor(worker,
actor_method_name,
temporary_actor_method,
actor_imported=False)
temporary_executor = make_actor_method_executor(
worker,
actor_method_name,
temporary_actor_method,
actor_imported=False)
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_executor)
worker.num_task_executions[driver_id][function_id] = 0
@@ -288,9 +291,12 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# traceback and notify the scheduler of the failure.
traceback_str = ray.utils.format_error_message(traceback.format_exc())
# Log the error message.
push_error_to_driver(worker.redis_client, "register_actor_signatures",
traceback_str, driver_id,
data={"actor_id": actor_id_str})
push_error_to_driver(
worker.redis_client,
"register_actor_signatures",
traceback_str,
driver_id,
data={"actor_id": actor_id_str})
# TODO(rkn): In the future, it might make sense to have the worker exit
# here. However, currently that would lead to hanging if someone calls
# ray.get on a method invoked on the actor.
@@ -298,16 +304,17 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
actor_methods = inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x) or
is_cython(x))))
def pred(x):
return (inspect.isfunction(x) or inspect.ismethod(x)
or is_cython(x))
actor_methods = inspect.getmembers(unpickled_class, predicate=pred)
for actor_method_name, actor_method in actor_methods:
function_id = compute_actor_method_function_id(
class_name, actor_method_name).id()
executor = make_actor_method_executor(worker, actor_method_name,
actor_method,
actor_imported=True)
executor = make_actor_method_executor(
worker, actor_method_name, actor_method, actor_imported=True)
worker.functions[driver_id][function_id] = (actor_method_name,
executor)
# We do not set worker.function_properties[driver_id][function_id]
@@ -315,7 +322,10 @@ def fetch_and_register_actor(actor_class_key, resources, worker):
# for the actor.
def register_actor_signatures(worker, driver_id, class_id, class_name,
def register_actor_signatures(worker,
driver_id,
class_id,
class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources=None,
@@ -346,18 +356,20 @@ def register_actor_signatures(worker, driver_id, class_id, class_name,
# The extra return value is an actor dummy object.
# In the cases where actor_method_cpus is None, that value should
# never be used.
FunctionProperties(num_return_vals=num_return_vals + 1,
resources={"CPU": actor_method_cpus},
max_calls=0))
FunctionProperties(
num_return_vals=num_return_vals + 1,
resources={"CPU": actor_method_cpus},
max_calls=0))
if actor_creation_resources is not None:
# Also register the actor creation task.
function_id = compute_actor_creation_function_id(class_id)
worker.function_properties[driver_id][function_id.id()] = (
# The extra return value is an actor dummy object.
FunctionProperties(num_return_vals=0 + 1,
resources=actor_creation_resources,
max_calls=0))
FunctionProperties(
num_return_vals=0 + 1,
resources=actor_creation_resources,
max_calls=0))
def publish_actor_class_to_key(key, actor_class_info, worker):
@@ -380,8 +392,8 @@ def publish_actor_class_to_key(key, actor_class_info, worker):
def export_actor_class(class_id, Class, actor_method_names,
actor_method_num_return_vals,
checkpoint_interval, worker):
actor_method_num_return_vals, checkpoint_interval,
worker):
key = b"ActorClass:" + class_id
actor_class_info = {
"class_name": Class.__name__,
@@ -389,8 +401,9 @@ def export_actor_class(class_id, Class, actor_method_names,
"class": pickle.dumps(Class),
"checkpoint_interval": checkpoint_interval,
"actor_method_names": json.dumps(list(actor_method_names)),
"actor_method_num_return_vals": json.dumps(
actor_method_num_return_vals)}
"actor_method_num_return_vals":
json.dumps(actor_method_num_return_vals)
}
if worker.mode is None:
# This means that 'ray.init()' has not been called yet and so we must
@@ -433,7 +446,11 @@ def export_actor(actor_id, class_id, class_name, actor_method_names,
driver_id = worker.task_driver_id.id()
register_actor_signatures(
worker, driver_id, class_id, class_name, actor_method_names,
worker,
driver_id,
class_id,
class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources=actor_creation_resources,
actor_method_cpus=actor_method_cpus)
@@ -466,12 +483,14 @@ class ActorMethod(object):
def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead "
"of running 'object.{}()', try "
"'object.{}.remote()'."
.format(self._method_name, self._method_name))
"'object.{}.remote()'.".format(self._method_name,
self._method_name))
def remote(self, *args, **kwargs):
return self._actor._actor_method_call(
self._method_name, args=args, kwargs=kwargs,
self._method_name,
args=args,
kwargs=kwargs,
dependency=self._actor._ray_actor_cursor)
@@ -481,12 +500,13 @@ class ActorHandleWrapper(object):
This is essentially just a dictionary, but it is used so that the recipient
can tell that an argument is an ActorHandle.
"""
def __init__(self, actor_id, class_id, actor_handle_id, actor_cursor,
actor_counter, actor_method_names,
actor_method_num_return_vals, method_signatures,
checkpoint_interval, class_name,
actor_creation_dummy_object_id,
actor_creation_resources, actor_method_cpus):
actor_creation_dummy_object_id, actor_creation_resources,
actor_method_cpus):
# TODO(rkn): Some of these fields are probably not necessary. We should
# strip out the unnecessary fields to keep actor handles lightweight.
self.actor_id = actor_id
@@ -545,27 +565,20 @@ def unwrap_actor_handle(worker, wrapper):
The unwrapped ActorHandle instance.
"""
driver_id = worker.task_driver_id.id()
register_actor_signatures(worker, driver_id, wrapper.class_id,
wrapper.class_name, wrapper.actor_method_names,
wrapper.actor_method_num_return_vals,
wrapper.actor_creation_resources,
wrapper.actor_method_cpus)
register_actor_signatures(
worker, driver_id, wrapper.class_id, wrapper.class_name,
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
actor_handle_class = make_actor_handle_class(wrapper.class_name)
actor_object = actor_handle_class.__new__(actor_handle_class)
actor_object._manual_init(
wrapper.actor_id,
wrapper.class_id,
wrapper.actor_handle_id,
wrapper.actor_cursor,
wrapper.actor_counter,
wrapper.actor_method_names,
wrapper.actor_method_num_return_vals,
wrapper.method_signatures,
wrapper.checkpoint_interval,
wrapper.actor_id, wrapper.class_id, wrapper.actor_handle_id,
wrapper.actor_cursor, wrapper.actor_counter,
wrapper.actor_method_names, wrapper.actor_method_num_return_vals,
wrapper.method_signatures, wrapper.checkpoint_interval,
wrapper.actor_creation_dummy_object_id,
wrapper.actor_creation_resources,
wrapper.actor_method_cpus)
wrapper.actor_creation_resources, wrapper.actor_method_cpus)
return actor_object
@@ -612,7 +625,10 @@ def make_actor_handle_class(class_name):
self._ray_actor_creation_resources = actor_creation_resources
self._ray_actor_method_cpus = actor_method_cpus
def _actor_method_call(self, method_name, args=None, kwargs=None,
def _actor_method_call(self,
method_name,
args=None,
kwargs=None,
dependency=None):
"""Method execution stub for an actor handle.
@@ -663,7 +679,9 @@ def make_actor_handle_class(class_name):
function_id = compute_actor_method_function_id(
self._ray_class_name, method_name)
object_ids = ray.worker.global_worker.submit_task(
function_id, args, actor_id=self._ray_actor_id,
function_id,
args,
actor_id=self._ray_actor_id,
actor_handle_id=self._ray_actor_handle_id,
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method,
@@ -722,8 +740,8 @@ def make_actor_handle_class(class_name):
self._ray_actor_handle_id.id() == ray.worker.NIL_ACTOR_ID):
# TODO(rkn): Should we be passing in the actor cursor as a
# dependency here?
self._actor_method_call("__ray_terminate__",
args=[self._ray_actor_id.id()])
self._actor_method_call(
"__ray_terminate__", args=[self._ray_actor_id.id()])
return ActorHandle
@@ -735,7 +753,6 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
exported = []
class ActorHandle(actor_handle_class):
@classmethod
def remote(cls, *args, **kwargs):
if ray.worker.global_worker.mode is None:
@@ -754,11 +771,13 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
actor_cursor = None
# The number of actor method invocations that we've called so far.
actor_counter = 0
# Get the actor methods of the given class.
actor_methods = inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x) or
is_cython(x))))
def pred(x):
return (inspect.isfunction(x) or inspect.ismethod(x)
or is_cython(x))
actor_methods = inspect.getmembers(Class, predicate=pred)
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
@@ -773,8 +792,9 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
method_signatures[k] = signature.extract_signature(
v, ignore_first=True)
actor_method_names = [method_name for method_name, _ in
actor_methods]
actor_method_names = [
method_name for method_name, _ in actor_methods
]
actor_method_num_return_vals = []
for _, method in actor_methods:
if hasattr(method, "__ray_num_return_vals__"):
@@ -796,30 +816,29 @@ def actor_handle_from_class(Class, class_id, actor_creation_resources,
checkpoint_interval,
ray.worker.global_worker)
exported.append(0)
actor_cursor = export_actor(actor_id, class_id, class_name,
actor_method_names,
actor_method_num_return_vals,
actor_creation_resources,
actor_method_cpus,
ray.worker.global_worker)
actor_cursor = export_actor(
actor_id, class_id, class_name, actor_method_names,
actor_method_num_return_vals, actor_creation_resources,
actor_method_cpus, ray.worker.global_worker)
# Increment the actor counter to account for the creation task.
actor_counter += 1
# Instantiate the actor handle.
actor_object = cls.__new__(cls)
actor_object._manual_init(actor_id, class_id, actor_handle_id,
actor_cursor, actor_counter,
actor_method_names,
actor_method_num_return_vals,
method_signatures, checkpoint_interval,
actor_cursor, actor_creation_resources,
actor_method_cpus)
actor_object._manual_init(
actor_id, class_id, actor_handle_id, actor_cursor,
actor_counter, actor_method_names,
actor_method_num_return_vals, method_signatures,
checkpoint_interval, actor_cursor, actor_creation_resources,
actor_method_cpus)
# Call __init__ as a remote function.
if "__init__" in actor_object._ray_actor_method_names:
actor_object._actor_method_call("__init__", args=args,
kwargs=kwargs,
dependency=actor_cursor)
actor_object._actor_method_call(
"__init__",
args=args,
kwargs=kwargs,
dependency=actor_cursor)
else:
print("WARNING: this object has no __init__ method.")
+80 -79
View File
@@ -51,25 +51,32 @@ CLUSTER_CONFIG_SCHEMA = {
"idle_timeout_minutes": (int, OPTIONAL),
# Cloud-provider specific configuration.
"provider": ({
"type": (str, REQUIRED), # e.g. aws
"region": (str, OPTIONAL), # e.g. us-east-1
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
"module": (str, OPTIONAL), # module, if using external node provider
}, REQUIRED),
"provider": (
{
"type": (str, REQUIRED), # e.g. aws
"region": (str, OPTIONAL), # e.g. us-east-1
"availability_zone": (str, OPTIONAL), # e.g. us-east-1a
"module": (str,
OPTIONAL), # module, if using external node provider
},
REQUIRED),
# How Ray will authenticate with newly launched nodes.
"auth": ({
"ssh_user": (str, REQUIRED), # e.g. ubuntu
"ssh_private_key": (str, OPTIONAL),
}, REQUIRED),
"auth": (
{
"ssh_user": (str, REQUIRED), # e.g. ubuntu
"ssh_private_key": (str, OPTIONAL),
},
REQUIRED),
# Docker configuration. If this is specified, all setup and start commands
# will be executed in the container.
"docker": ({
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
"container_name": (str, OPTIONAL), # e.g., ray_docker
}, OPTIONAL),
"docker": (
{
"image": (str, OPTIONAL), # e.g. tensorflow/tensorflow:1.5.0-py3
"container_name": (str, OPTIONAL), # e.g., ray_docker
},
OPTIONAL),
# Provider-specific config for the head node, e.g. instance type.
"head_node": (dict, OPTIONAL),
@@ -137,9 +144,9 @@ class LoadMetrics(object):
for unwanted_key in unwanted:
del mapping[unwanted_key]
if unwanted:
print(
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
print("Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
prune(self.last_used_time_by_ip)
prune(self.static_resources_by_ip)
prune(self.dynamic_resources_by_ip)
@@ -148,10 +155,8 @@ class LoadMetrics(object):
return self._info()["NumNodesUsed"]
def debug_string(self):
return " - {}".format(
"\n - ".join(
["{}: {}".format(k, v)
for k, v in sorted(self._info().items())]))
return " - {}".format("\n - ".join(
["{}: {}".format(k, v) for k, v in sorted(self._info().items())]))
def _info(self):
nodes_used = 0.0
@@ -176,14 +181,19 @@ class LoadMetrics(object):
nodes_used += max_frac
idle_times = [now - t for t in self.last_used_time_by_ip.values()]
return {
"ResourceUsage": ", ".join([
"ResourceUsage":
", ".join([
"{}/{} {}".format(
round(resources_used[rid], 2),
round(resources_total[rid], 2), rid)
for rid in sorted(resources_used)]),
"NumNodesConnected": len(self.static_resources_by_ip),
"NumNodesUsed": round(nodes_used, 2),
"NodeIdleSeconds": "Min={} Mean={} Max={}".format(
for rid in sorted(resources_used)
]),
"NumNodesConnected":
len(self.static_resources_by_ip),
"NumNodesUsed":
round(nodes_used, 2),
"NodeIdleSeconds":
"Min={} Mean={} Max={}".format(
int(np.min(idle_times)) if idle_times else -1,
int(np.mean(idle_times)) if idle_times else -1,
int(np.max(idle_times)) if idle_times else -1),
@@ -208,18 +218,20 @@ class StandardAutoscaler(object):
until the target cluster size is met).
"""
def __init__(
self, config_path, load_metrics,
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess, verbose_updates=False,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
def __init__(self,
config_path,
load_metrics,
max_concurrent_launches=AUTOSCALER_MAX_CONCURRENT_LAUNCHES,
max_failures=AUTOSCALER_MAX_NUM_FAILURES,
process_runner=subprocess,
verbose_updates=False,
node_updater_cls=NodeUpdaterProcess,
update_interval_s=AUTOSCALER_UPDATE_INTERVAL_S):
self.config_path = config_path
self.reload_config(errors_fatal=True)
self.load_metrics = load_metrics
self.provider = get_node_provider(
self.config["provider"], self.config["cluster_name"])
self.provider = get_node_provider(self.config["provider"],
self.config["cluster_name"])
self.max_failures = max_failures
self.max_concurrent_launches = max_concurrent_launches
@@ -245,9 +257,8 @@ class StandardAutoscaler(object):
self.reload_config(errors_fatal=False)
self._update()
except Exception as e:
print(
"StandardAutoscaler: Error during autoscaling: {}",
traceback.format_exc())
print("StandardAutoscaler: Error during autoscaling: {}",
traceback.format_exc())
self.num_failures += 1
if self.num_failures > self.max_failures:
print("*** StandardAutoscaler: Too many errors, abort. ***")
@@ -274,15 +285,13 @@ class StandardAutoscaler(object):
if node_ip in last_used and last_used[node_ip] < horizon and \
len(nodes) - num_terminated > self.config["min_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
print("StandardAutoscaler: Terminating idle node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
elif not self.launch_config_ok(node_id):
num_terminated += 1
print(
"StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
print("StandardAutoscaler: Terminating outdated node: "
"{}".format(node_id))
self.provider.terminate_node(node_id)
if num_terminated > 0:
nodes = self.workers()
@@ -292,9 +301,8 @@ class StandardAutoscaler(object):
num_terminated = 0
while len(nodes) > self.config["max_workers"]:
num_terminated += 1
print(
"StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
print("StandardAutoscaler: Terminating unneeded node: "
"{}".format(nodes[-1]))
self.provider.terminate_node(nodes[-1])
nodes = nodes[:-1]
if num_terminated > 0:
@@ -339,13 +347,13 @@ class StandardAutoscaler(object):
with open(self.config_path) as f:
new_config = yaml.load(f.read())
validate_config(new_config)
new_launch_hash = hash_launch_conf(
new_config["worker_nodes"], new_config["auth"])
new_runtime_hash = hash_runtime_conf(
new_config["file_mounts"],
[new_config["setup_commands"],
new_config["worker_setup_commands"],
new_config["worker_start_ray_commands"]])
new_launch_hash = hash_launch_conf(new_config["worker_nodes"],
new_config["auth"])
new_runtime_hash = hash_runtime_conf(new_config["file_mounts"], [
new_config["setup_commands"],
new_config["worker_setup_commands"],
new_config["worker_start_ray_commands"]
])
self.config = new_config
self.launch_hash = new_launch_hash
self.runtime_hash = new_runtime_hash
@@ -353,17 +361,15 @@ class StandardAutoscaler(object):
if errors_fatal:
raise e
else:
print(
"StandardAutoscaler: Error parsing config: {}",
traceback.format_exc())
print("StandardAutoscaler: Error parsing config: {}",
traceback.format_exc())
def target_num_workers(self):
target_frac = self.config["target_utilization_fraction"]
cur_used = self.load_metrics.approx_workers_used()
ideal_num_workers = int(np.ceil(cur_used / float(target_frac)))
return min(
self.config["max_workers"],
max(self.config["min_workers"], ideal_num_workers))
return min(self.config["max_workers"],
max(self.config["min_workers"], ideal_num_workers))
def launch_config_ok(self, node_id):
launch_conf = self.provider.node_tags(node_id).get(
@@ -393,8 +399,7 @@ class StandardAutoscaler(object):
node_id,
self.config["provider"],
self.config["auth"],
self.config["cluster_name"],
{},
self.config["cluster_name"], {},
with_head_node_ip(self.config["worker_start_ray_commands"]),
self.runtime_hash,
redirect_output=not self.verbose_updates,
@@ -409,14 +414,12 @@ class StandardAutoscaler(object):
return
if self.config.get("no_restart", False) and \
self.num_successful_updates.get(node_id, 0) > 0:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"])
init_commands = (self.config["setup_commands"] +
self.config["worker_setup_commands"])
else:
init_commands = (
self.config["setup_commands"] +
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
init_commands = (self.config["setup_commands"] +
self.config["worker_setup_commands"] +
self.config["worker_start_ray_commands"])
updater = self.node_updater_cls(
node_id,
self.config["provider"],
@@ -445,14 +448,12 @@ class StandardAutoscaler(object):
print("StandardAutoscaler: Launching {} new nodes".format(count))
num_before = len(self.workers())
self.provider.create_node(
self.config["worker_nodes"],
{
self.config["worker_nodes"], {
TAG_NAME: "ray-{}-worker".format(self.config["cluster_name"]),
TAG_RAY_NODE_TYPE: "Worker",
TAG_RAY_NODE_STATUS: "Uninitialized",
TAG_RAY_LAUNCH_CONFIG: self.launch_hash,
},
count)
}, count)
# TODO(ekl) be less conservative in this check
assert len(self.workers()) > num_before, \
"Num nodes failed to increase after creating a new node"
@@ -472,8 +473,8 @@ class StandardAutoscaler(object):
suffix += " ({} failed to update)".format(
len(self.num_failed_updates))
return "StandardAutoscaler [{}]: {}/{} target nodes{}\n{}".format(
datetime.now(), len(nodes), self.target_num_workers(),
suffix, self.load_metrics.debug_string())
datetime.now(), len(nodes), self.target_num_workers(), suffix,
self.load_metrics.debug_string())
def typename(v):
@@ -507,9 +508,8 @@ def check_extraneous(config, schema):
raise ValueError("Config {} is not a dictionary".format(config))
for k in config:
if k not in schema:
raise ValueError(
"Unexpected config key `{}` not in {}".format(
k, list(schema.keys())))
raise ValueError("Unexpected config key `{}` not in {}".format(
k, list(schema.keys())))
v, kreq = schema[k]
if v is None:
continue
@@ -517,7 +517,8 @@ def check_extraneous(config, schema):
if not isinstance(config[k], v):
raise ValueError(
"Config key `{}` has wrong type {}, expected {}".format(
k, type(config[k]).__name__, v.__name__))
k,
type(config[k]).__name__, v.__name__))
else:
check_extraneous(config[k], v)
+42 -25
View File
@@ -25,12 +25,10 @@ assert StrictVersion(boto3.__version__) >= StrictVersion("1.4.8"), \
def key_pair(i, region):
"""Returns the ith default (aws_key_pair_name, key_pair_path)."""
if i == 0:
return (
"{}_{}".format(RAY, region),
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
return (
"{}_{}_{}".format(RAY, i, region),
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
return ("{}_{}".format(RAY, region),
os.path.expanduser("~/.ssh/{}_{}.pem".format(RAY, region)))
return ("{}_{}_{}".format(RAY, i, region),
os.path.expanduser("~/.ssh/{}_{}_{}.pem".format(RAY, i, region)))
# Suppress excessive connection dropped logs from boto
@@ -83,7 +81,9 @@ def _configure_iam_role(config):
"Statement": [
{
"Effect": "Allow",
"Principal": {"Service": "ec2.amazonaws.com"},
"Principal": {
"Service": "ec2.amazonaws.com"
},
"Action": "sts:AssumeRole",
},
],
@@ -97,8 +97,7 @@ def _configure_iam_role(config):
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation
print("Role not specified for head node, using {}".format(
profile.arn))
print("Role not specified for head node, using {}".format(profile.arn))
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}
return config
@@ -146,8 +145,10 @@ def _configure_key_pair(config):
def _configure_subnet(config):
ec2 = _resource("ec2", config)
subnets = sorted(
[s for s in ec2.subnets.all()
if s.state == "available" and s.map_public_ip_on_launch],
[
s for s in ec2.subnets.all()
if s.state == "available" and s.map_public_ip_on_launch
],
reverse=True, # sort from Z-A
key=lambda subnet: subnet.availability_zone)
if not subnets:
@@ -157,9 +158,9 @@ def _configure_subnet(config):
"and trying this again. Note that the subnet must map public IPs "
"on instance launch.")
if "availability_zone" in config["provider"]:
default_subnet = next((s for s in subnets
if s.availability_zone ==
config["provider"]["availability_zone"]),
default_subnet = next((
s for s in subnets
if s.availability_zone == config["provider"]["availability_zone"]),
None)
if not default_subnet:
raise Exception(
@@ -209,11 +210,21 @@ def _configure_security_group(config):
if not security_group.ip_permissions:
security_group.authorize_ingress(
IpPermissions=[
{"FromPort": -1, "ToPort": -1, "IpProtocol": "-1",
"UserIdGroupPairs": [{"GroupId": security_group.id}]},
{"FromPort": 22, "ToPort": 22, "IpProtocol": "TCP",
"IpRanges": [{"CidrIp": "0.0.0.0/0"}]}])
IpPermissions=[{
"FromPort": -1,
"ToPort": -1,
"IpProtocol": "-1",
"UserIdGroupPairs": [{
"GroupId": security_group.id
}]
}, {
"FromPort": 22,
"ToPort": 22,
"IpProtocol": "TCP",
"IpRanges": [{
"CidrIp": "0.0.0.0/0"
}]
}])
if "SecurityGroupIds" not in config["head_node"]:
print("SecurityGroupIds not specified for head node, using {}".format(
@@ -231,8 +242,10 @@ def _configure_security_group(config):
def _get_subnet_or_die(config, subnet_id):
ec2 = _resource("ec2", config)
subnet = list(
ec2.subnets.filter(Filters=[
{"Name": "subnet-id", "Values": [subnet_id]}]))
ec2.subnets.filter(Filters=[{
"Name": "subnet-id",
"Values": [subnet_id]
}]))
assert len(subnet) == 1, "Subnet not found"
subnet = subnet[0]
return subnet
@@ -241,8 +254,10 @@ def _get_subnet_or_die(config, subnet_id):
def _get_security_group(config, vpc_id, group_name):
ec2 = _resource("ec2", config)
existing_groups = list(
ec2.security_groups.filter(Filters=[
{"Name": "vpc-id", "Values": [vpc_id]}]))
ec2.security_groups.filter(Filters=[{
"Name": "vpc-id",
"Values": [vpc_id]
}]))
for sg in existing_groups:
if sg.group_name == group_name:
return sg
@@ -270,8 +285,10 @@ def _get_instance_profile(profile_name, config):
def _get_key(key_name, config):
ec2 = _resource("ec2", config)
for key in ec2.key_pairs.filter(
Filters=[{"Name": "key-name", "Values": [key_name]}]):
for key in ec2.key_pairs.filter(Filters=[{
"Name": "key-name",
"Values": [key_name]
}]):
if key.name == key_name:
return key
+15 -14
View File
@@ -84,7 +84,8 @@ class AWSNodeProvider(NodeProvider):
tag_pairs = []
for k, v in tags.items():
tag_pairs.append({
"Key": k, "Value": v,
"Key": k,
"Value": v,
})
node.create_tags(Tags=tag_pairs)
@@ -95,20 +96,20 @@ class AWSNodeProvider(NodeProvider):
"Value": self.cluster_name,
}]
for k, v in tags.items():
tag_pairs.append(
{
"Key": k,
"Value": v,
})
tag_pairs.append({
"Key": k,
"Value": v,
})
conf.update({
"MinCount": 1,
"MaxCount": count,
"TagSpecifications": conf.get("TagSpecifications", []) + [
{
"ResourceType": "instance",
"Tags": tag_pairs,
}
]
"MinCount":
1,
"MaxCount":
count,
"TagSpecifications":
conf.get("TagSpecifications", []) + [{
"ResourceType": "instance",
"Tags": tag_pairs,
}]
})
self.ec2.create_instances(**conf)
+22 -27
View File
@@ -23,9 +23,8 @@ from ray.autoscaler.tags import TAG_RAY_NODE_TYPE, TAG_RAY_LAUNCH_CONFIG, \
from ray.autoscaler.updater import NodeUpdaterProcess
def create_or_update_cluster(
config_file, override_min_workers, override_max_workers,
no_restart, yes):
def create_or_update_cluster(config_file, override_min_workers,
override_max_workers, no_restart, yes):
"""Create or updates an autoscaling Ray cluster from a config json."""
config = yaml.load(open(config_file).read())
@@ -39,8 +38,8 @@ def create_or_update_cluster(
importer = NODE_PROVIDERS.get(config["provider"]["type"])
if not importer:
raise NotImplementedError(
"Unsupported provider {}".format(config["provider"]))
raise NotImplementedError("Unsupported provider {}".format(
config["provider"]))
bootstrap_config, _ = importer()
config = bootstrap_config(config)
@@ -129,8 +128,10 @@ def get_or_create_head_node(config, no_restart, yes):
remote_config_file.write(json.dumps(remote_config))
remote_config_file.flush()
config["file_mounts"].update({
remote_key_path: config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml": remote_config_file.name
remote_key_path:
config["auth"]["ssh_private_key"],
"~/ray_bootstrap_config.yaml":
remote_config_file.name
})
if no_restart:
@@ -160,30 +161,24 @@ def get_or_create_head_node(config, no_restart, yes):
print("Error: updating {} failed".format(
provider.external_ip(head_node)))
sys.exit(1)
print(
"Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
print("Head node up-to-date, IP address is: {}".format(
provider.external_ip(head_node)))
monitor_str = "tail -f /tmp/raylogs/monitor-*"
for s in init_commands:
if ("ray start" in s and "docker exec" in s and
"--autoscaling-config" in s):
if ("ray start" in s and "docker exec" in s
and "--autoscaling-config" in s):
monitor_str = "docker exec {} /bin/sh -c {}".format(
config["docker"]["container_name"],
quote(monitor_str))
print(
"To monitor auto-scaling activity, you can run:\n\n"
" ssh -i {} {}@{} {}\n".format(
config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node),
quote(monitor_str)))
print(
"To login to the cluster, run:\n\n"
" ssh -i {} {}@{}\n".format(
config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
config["docker"]["container_name"], quote(monitor_str))
print("To monitor auto-scaling activity, you can run:\n\n"
" ssh -i {} {}@{} {}\n".format(config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node),
quote(monitor_str)))
print("To login to the cluster, run:\n\n"
" ssh -i {} {}@{}\n".format(config["auth"]["ssh_private_key"],
config["auth"]["ssh_user"],
provider.external_ip(head_node)))
def get_head_node_ip(config_file):
+32 -28
View File
@@ -22,24 +22,21 @@ def dockerize_if_needed(config):
assert cname, "Must provide container name!"
docker_mounts = {dst: dst for dst in config["file_mounts"]}
config["setup_commands"] = (
docker_install_cmds() +
docker_start_cmds(
config["auth"]["ssh_user"], docker_image,
docker_mounts, cname) +
with_docker_exec(
config["setup_commands"], container_name=cname))
docker_install_cmds() + docker_start_cmds(
config["auth"]["ssh_user"], docker_image, docker_mounts, cname) +
with_docker_exec(config["setup_commands"], container_name=cname))
config["head_setup_commands"] = with_docker_exec(
config["head_setup_commands"], container_name=cname)
config["head_start_ray_commands"] = (
docker_autoscaler_setup(cname) +
with_docker_exec(
docker_autoscaler_setup(cname) + with_docker_exec(
config["head_start_ray_commands"], container_name=cname))
config["worker_setup_commands"] = with_docker_exec(
config["worker_setup_commands"], container_name=cname)
config["worker_start_ray_commands"] = with_docker_exec(
config["worker_start_ray_commands"], container_name=cname,
config["worker_start_ray_commands"],
container_name=cname,
env_vars=["RAY_HEAD_IP"])
return config
@@ -50,21 +47,24 @@ def with_docker_exec(cmds, container_name, env_vars=None):
if env_vars:
env_str = " ".join(
["-e {env}=${env}".format(env=env) for env in env_vars])
return ["docker exec {} {} /bin/sh -c {} ".format(
env_str, container_name, quote(cmd)) for cmd in cmds]
return [
"docker exec {} {} /bin/sh -c {} ".format(env_str, container_name,
quote(cmd)) for cmd in cmds
]
def docker_install_cmds():
return [aptwait_cmd() + " && sudo apt-get update",
aptwait_cmd() + " && sudo apt-get install -y docker.io"]
return [
aptwait_cmd() + " && sudo apt-get update",
aptwait_cmd() + " && sudo apt-get install -y docker.io"
]
def aptwait_cmd():
return (
"while sudo fuser"
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
" >/dev/null 2>&1; "
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
return ("while sudo fuser"
" /var/{lib/{dpkg,apt/lists},cache/apt/archives}/lock"
" >/dev/null 2>&1; "
"do echo 'Waiting for release of dpkg/apt locks'; sleep 5; done")
def docker_start_cmds(user, image, mount, cname):
@@ -77,10 +77,12 @@ def docker_start_cmds(user, image, mount, cname):
# create flags
# ports for the redis, object manager, and tune client
port_flags = " ".join(["-p {port}:{port}".format(port=port)
for port in ["6379", "8076", "4321"]])
mount_flags = " ".join(["-v {src}:{dest}".format(src=k, dest=v)
for k, v in mount.items()])
port_flags = " ".join([
"-p {port}:{port}".format(port=port)
for port in ["6379", "8076", "4321"]
])
mount_flags = " ".join(
["-v {src}:{dest}".format(src=k, dest=v) for k, v in mount.items()])
# for click, used in ray cli
env_vars = {"LC_ALL": "C.UTF-8", "LANG": "C.UTF-8"}
@@ -88,9 +90,10 @@ def docker_start_cmds(user, image, mount, cname):
["-e {name}={val}".format(name=k, val=v) for k, v in env_vars.items()])
# docker run command
docker_run = ["docker", "run", "--rm", "--name {}".format(cname),
"-d", "-it", port_flags, mount_flags, env_flags,
"--net=host", image, "bash"]
docker_run = [
"docker", "run", "--rm", "--name {}".format(cname), "-d", "-it",
port_flags, mount_flags, env_flags, "--net=host", image, "bash"
]
cmds.append(" ".join(docker_run))
docker_update = []
docker_update.append("apt-get -y update")
@@ -107,7 +110,8 @@ def docker_autoscaler_setup(cname):
base_path = os.path.basename(path)
cmds.append("docker cp {path} {cname}:{dpath}".format(
path=path, dpath=base_path, cname=cname))
cmds.extend(with_docker_exec(
["cp {} {}".format("/" + base_path, path)],
container_name=cname))
cmds.extend(
with_docker_exec(
["cp {} {}".format("/" + base_path, path)],
container_name=cname))
return cmds
+8 -8
View File
@@ -15,14 +15,15 @@ def import_aws():
def load_aws_config():
import ray.autoscaler.aws as ray_aws
return os.path.join(os.path.dirname(
ray_aws.__file__), "example-full.yaml")
return os.path.join(os.path.dirname(ray_aws.__file__), "example-full.yaml")
def import_external():
"""Mock a normal provider importer."""
def return_it_back(config):
return config
return return_it_back, None
@@ -55,8 +56,7 @@ def load_class(path):
class_data = path.split(".")
if len(class_data) < 2:
raise ValueError(
"You need to pass a valid path like mymodule.provider_class"
)
"You need to pass a valid path like mymodule.provider_class")
module_path = ".".join(class_data[:-1])
class_str = class_data[-1]
module = importlib.import_module(module_path)
@@ -71,8 +71,8 @@ def get_node_provider(provider_config, cluster_name):
importer = NODE_PROVIDERS.get(provider_config["type"])
if importer is None:
raise NotImplementedError(
"Unsupported node provider: {}".format(provider_config["type"]))
raise NotImplementedError("Unsupported node provider: {}".format(
provider_config["type"]))
_, provider_cls = importer()
return provider_cls(provider_config, cluster_name)
@@ -82,8 +82,8 @@ def get_default_config(provider_config):
return {}
load_config = DEFAULT_CONFIGS.get(provider_config["type"])
if load_config is None:
raise NotImplementedError(
"Unsupported node provider: {}".format(provider_config["type"]))
raise NotImplementedError("Unsupported node provider: {}".format(
provider_config["type"]))
path_to_default = load_config()
with open(path_to_default) as f:
defaults = yaml.load(f)
-1
View File
@@ -1,7 +1,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""The Ray autoscaler uses tags to associate metadata with instances."""
# Tag for the name of the node
+43 -32
View File
@@ -26,10 +26,16 @@ def pretty_cmd(cmd_str):
class NodeUpdater(object):
"""A process for syncing files and running init commands on a node."""
def __init__(
self, node_id, provider_config, auth_config, cluster_name,
file_mounts, setup_cmds, runtime_hash, redirect_output=True,
process_runner=subprocess):
def __init__(self,
node_id,
provider_config,
auth_config,
cluster_name,
file_mounts,
setup_cmds,
runtime_hash,
redirect_output=True,
process_runner=subprocess):
self.daemon = True
self.process_runner = process_runner
self.provider = get_node_provider(provider_config, cluster_name)
@@ -66,13 +72,12 @@ class NodeUpdater(object):
"NodeUpdater: Error updating {}"
"See {} for remote logs.".format(error_str, self.output_name),
file=self.stdout)
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "UpdateFailed"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "UpdateFailed"})
if self.logfile is not None:
print(
"----- BEGIN REMOTE LOGS -----\n" +
open(self.logfile.name).read() +
"\n----- END REMOTE LOGS -----")
print("----- BEGIN REMOTE LOGS -----\n" + open(
self.logfile.name).read() + "\n----- END REMOTE LOGS -----"
)
raise e
self.provider.set_node_tags(
self.node_id, {
@@ -85,8 +90,8 @@ class NodeUpdater(object):
file=self.stdout)
def do_update(self):
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "WaitingForSSH"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "WaitingForSSH"})
deadline = time.time() + NODE_START_WAIT_S
# Wait for external IP
@@ -114,7 +119,8 @@ class NodeUpdater(object):
raise Exception("Node not running yet...")
self.ssh_cmd(
"uptime",
connect_timeout=5, redirect=open("/dev/null", "w"))
connect_timeout=5,
redirect=open("/dev/null", "w"))
ssh_ok = True
except Exception as e:
retry_str = str(e)
@@ -130,8 +136,8 @@ class NodeUpdater(object):
assert ssh_ok, "Unable to SSH to node"
# Rsync file mounts
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "SyncingFiles"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "SyncingFiles"})
for remote_path, local_path in self.file_mounts.items():
print(
"NodeUpdater: Syncing {} to {}...".format(
@@ -143,18 +149,20 @@ class NodeUpdater(object):
local_path += "/"
if not remote_path.endswith("/"):
remote_path += "/"
self.ssh_cmd(
"mkdir -p {}".format(os.path.dirname(remote_path)))
self.process_runner.check_call([
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
"--delete", "-avz", "{}".format(local_path),
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
], stdout=self.stdout, stderr=self.stderr)
self.ssh_cmd("mkdir -p {}".format(os.path.dirname(remote_path)))
self.process_runner.check_call(
[
"rsync", "-e", "ssh -i {} ".format(self.ssh_private_key) +
"-o ConnectTimeout=120s -o StrictHostKeyChecking=no",
"--delete", "-avz", "{}".format(local_path),
"{}@{}:{}".format(self.ssh_user, self.ssh_ip, remote_path)
],
stdout=self.stdout,
stderr=self.stderr)
# Run init commands
self.provider.set_node_tags(
self.node_id, {TAG_RAY_NODE_STATUS: "SettingUp"})
self.provider.set_node_tags(self.node_id,
{TAG_RAY_NODE_STATUS: "SettingUp"})
for cmd in self.setup_cmds:
self.ssh_cmd(cmd, verbose=True)
@@ -165,13 +173,16 @@ class NodeUpdater(object):
pretty_cmd(cmd), self.ssh_ip),
file=self.stdout)
force_interactive = "set -i && source ~/.bashrc && "
self.process_runner.check_call([
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key,
"{}@{}".format(self.ssh_user, self.ssh_ip),
"bash --login -c {}".format(pipes.quote(force_interactive + cmd))
], stdout=redirect or self.stdout, stderr=redirect or self.stderr)
self.process_runner.check_call(
[
"ssh", "-o", "ConnectTimeout={}s".format(connect_timeout),
"-o", "StrictHostKeyChecking=no",
"-i", self.ssh_private_key, "{}@{}".format(
self.ssh_user, self.ssh_ip), "bash --login -c {}".format(
pipes.quote(force_interactive + cmd))
],
stdout=redirect or self.stdout,
stderr=redirect or self.stderr)
class NodeUpdaterProcess(NodeUpdater, Process):
+27 -37
View File
@@ -25,7 +25,7 @@ OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
# Support 32 and 64 bit architecture.
assert(numbytes == 4 or numbytes == 8)
assert (numbytes == 4 or numbytes == 8)
for i in range(numbytes):
curbyte = num & 0xff
if sys.version_info >= (3, 0):
@@ -50,7 +50,6 @@ def get_next_message(pubsub_client, timeout_seconds=10):
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
@@ -192,16 +191,16 @@ class TestGlobalStateStore(unittest.TestCase):
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_object = (SubscribeToNotificationsReply.
GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i),
manager_ids[i])
self.assertEqual(
notification_object.ManagerIds(i), manager_ids[i])
data_size = 0xf1f0
p = self.redis.pubsub()
@@ -215,10 +214,9 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
check_object_notification(
get_next_message(p)["data"], b"object_id1", data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to
@@ -232,26 +230,22 @@ class TestGlobalStateStore(unittest.TestCase):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
check_object_notification(
get_next_message(p)["data"], b"object_id2", data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2",
b"manager_id3"])
check_object_notification(
get_next_message(p)["data"], b"object_id3", data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
@@ -349,8 +343,7 @@ class TestGlobalStateStore(unittest.TestCase):
# update happens, and the response is still the same task.
task_args = [task_args[0]] + task_args
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@@ -362,8 +355,7 @@ class TestGlobalStateStore(unittest.TestCase):
# task.
task_args[1] = TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@@ -375,8 +367,7 @@ class TestGlobalStateStore(unittest.TestCase):
new_task_args = task_args[:]
new_task_args[1] = TASK_STATUS_WAITING
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
"task_id", *new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
@@ -388,8 +379,7 @@ class TestGlobalStateStore(unittest.TestCase):
task_args = new_task_args
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
"task_id", *task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# If the test value is a bitmask that does not match the current value,
@@ -399,8 +389,7 @@ class TestGlobalStateStore(unittest.TestCase):
new_task_args[0] = TASK_STATUS_SCHEDULED
old_response = response
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
"task_id", *new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
@@ -409,8 +398,10 @@ class TestGlobalStateStore(unittest.TestCase):
check_task_reply(get_response, task_args[1:])
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"]
task_args = [
b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"", 0, b"task_spec"
]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
@@ -418,8 +409,7 @@ class TestGlobalStateStore(unittest.TestCase):
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), task_args[0])
self.assertEqual(notification_object.State(), task_args[1])
self.assertEqual(notification_object.LocalSchedulerId(),
task_args[2])
self.assertEqual(notification_object.LocalSchedulerId(), task_args[2])
self.assertEqual(notification_object.ExecutionDependencies(),
task_args[3])
self.assertEqual(notification_object.TaskSpec(), task_args[-1])
+36 -46
View File
@@ -30,19 +30,23 @@ def random_task_id():
BASE_SIMPLE_OBJECTS = [
0, 1, 100000, 0.0, 0.5, 0.9, 100000.1, (), [], {}, "", 990 * "h", u"",
990 * u"h"]
990 * u"h"
]
if sys.version_info < (3, 0):
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
BASE_SIMPLE_OBJECTS += [
long(0), # noqa: E501,F821
long(1), # noqa: E501,F821
long(100000), # noqa: E501,F821
long(1 << 100) # noqa: E501,F821
]
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj, ) for obj in BASE_SIMPLE_OBJECTS]
DICT_SIMPLE_OBJECTS = [{(): obj} for obj in BASE_SIMPLE_OBJECTS]
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS +
LIST_SIMPLE_OBJECTS +
TUPLE_SIMPLE_OBJECTS +
DICT_SIMPLE_OBJECTS)
SIMPLE_OBJECTS = (BASE_SIMPLE_OBJECTS + LIST_SIMPLE_OBJECTS +
TUPLE_SIMPLE_OBJECTS + DICT_SIMPLE_OBJECTS)
# Create some complex objects that cannot be serialized by value in tasks.
@@ -55,21 +59,20 @@ class Foo(object):
pass
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", lst, Foo(),
10 * [10 * [10 * [1]]]]
BASE_COMPLEX_OBJECTS = [
999 * "h", 999 * u"h", lst,
Foo(), 10 * [10 * [10 * [1]]]
]
LIST_COMPLEX_OBJECTS = [[obj] for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj,) for obj in BASE_COMPLEX_OBJECTS]
TUPLE_COMPLEX_OBJECTS = [(obj, ) for obj in BASE_COMPLEX_OBJECTS]
DICT_COMPLEX_OBJECTS = [{(): obj} for obj in BASE_COMPLEX_OBJECTS]
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
LIST_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS +
DICT_COMPLEX_OBJECTS)
COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS + LIST_COMPLEX_OBJECTS +
TUPLE_COMPLEX_OBJECTS + DICT_COMPLEX_OBJECTS)
class TestSerialization(unittest.TestCase):
def test_serialize_by_value(self):
for val in SIMPLE_OBJECTS:
@@ -79,7 +82,6 @@ class TestSerialization(unittest.TestCase):
class TestObjectID(unittest.TestCase):
def test_create_object_id(self):
random_object_id()
@@ -95,6 +97,7 @@ class TestObjectID(unittest.TestCase):
def h():
object_ids[0]
return 1
# Make sure that object IDs cannot be pickled (including functions that
# close over object IDs).
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
@@ -113,10 +116,12 @@ class TestObjectID(unittest.TestCase):
self.assertNotEqual(x1, y1)
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids1 = [
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
]
object_ids2 = [
local_scheduler.ObjectID(random_strings[i]) for i in range(256)
]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
@@ -129,7 +134,6 @@ class TestObjectID(unittest.TestCase):
class TestTask(unittest.TestCase):
def check_task(self, task, function_id, num_return_vals, args):
self.assertEqual(function_id.id(), task.function_id().id())
retrieved_args = task.arguments()
@@ -148,31 +152,17 @@ class TestTask(unittest.TestCase):
parent_id = random_task_id()
function_id = random_function_id()
object_ids = [random_object_id() for _ in range(256)]
args_list = [
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
args_list = [[], 1 * [1], 10 * [1], 100 * [1], 1000 * [1], 1 * ["a"],
10 * ["a"], 100 * ["a"], 1000 * ["a"], [
1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]
], object_ids[:1], object_ids[:2], object_ids[:3],
object_ids[:4], object_ids[:5], object_ids[:10],
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
object_ids[0], "a"
], [1, object_ids[0], "a"], [
object_ids[0], 1, object_ids[1], "a"
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(driver_id, function_id, args,
+4 -2
View File
@@ -5,5 +5,7 @@ from __future__ import print_function
from .tfutils import TensorFlowVariables
from .features import flush_redis_unsafe, flush_task_and_object_metadata_unsafe
__all__ = ["TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe"]
__all__ = [
"TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe"
]
@@ -8,6 +8,8 @@ from .core import (BLOCK_SIZE, DistArray, assemble, zeros, ones, copy, eye,
triu, tril, blockwise_dot, dot, transpose, add, subtract,
numpy_to_dist, subblocks)
__all__ = ["random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros",
"ones", "copy", "eye", "triu", "tril", "blockwise_dot", "dot",
"transpose", "add", "subtract", "numpy_to_dist", "subblocks"]
__all__ = [
"random", "linalg", "BLOCK_SIZE", "DistArray", "assemble", "zeros", "ones",
"copy", "eye", "triu", "tril", "blockwise_dot", "dot", "transpose", "add",
"subtract", "numpy_to_dist", "subblocks"
]
@@ -13,8 +13,9 @@ class DistArray(object):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
for a in self.shape]
self.num_blocks = [
int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape
]
if objectids is not None:
self.objectids = objectids
else:
@@ -56,7 +57,7 @@ class DistArray(object):
def assemble(self):
"""Assemble an array from a distributed array of object IDs."""
first_block = ray.get(self.objectids[(0,) * self.ndim])
first_block = ray.get(self.objectids[(0, ) * self.ndim])
dtype = first_block.dtype
result = np.zeros(self.shape, dtype=dtype)
for index in np.ndindex(*self.num_blocks):
@@ -85,8 +86,8 @@ def numpy_to_dist(a):
for index in np.ndindex(*result.num_blocks):
lower = DistArray.compute_block_lower(index, a.shape)
upper = DistArray.compute_block_upper(index, a.shape)
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
in zip(lower, upper)]])
result.objectids[index] = ray.put(
a[[slice(l, u) for (l, u) in zip(lower, upper)]])
return result
@@ -126,12 +127,11 @@ def eye(dim1, dim2=-1, dtype_name="float"):
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0],
block_shape[1],
dtype_name=dtype_name)
result.objectids[i, j] = ra.eye.remote(
block_shape[0], block_shape[1], dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
result.objectids[i, j] = ra.zeros.remote(
block_shape, dtype_name=dtype_name)
return result
@@ -190,8 +190,8 @@ def dot(a, b):
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
"a.shape = {} and b.shape = {}.".format(a.shape,
b.shape))
"a.shape = {} and b.shape = {}.".format(
a.shape, b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
@@ -227,8 +227,8 @@ def subblocks(a, *ranges):
"the {}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must "
"be at least 0, but the {}th range is {}."
.format(i, ranges[i]))
"be at least 0, but the {}th range is {}.".format(
i, ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must "
"be less than the relevant number of blocks, but "
@@ -240,8 +240,8 @@ def subblocks(a, *ranges):
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
result.objectids[index] = a.objectids[tuple(
[ranges[i][index[i]] for i in range(a.ndim)])]
return result
@@ -249,8 +249,8 @@ def subblocks(a, *ranges):
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, "
"but a.ndim = {}, a.shape = {}.".format(a.ndim,
a.shape))
"but a.ndim = {}, a.shape = {}.".format(
a.ndim, a.shape))
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
@@ -263,8 +263,8 @@ def transpose(a):
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same "
"shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
"shape, but x1.shape = {}, and x2.shape = {}.".format(
x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.add.remote(x1.objectids[index],
@@ -76,9 +76,10 @@ def tsqr(a):
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(
q_block_current, ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_block_current = ra.dot.remote(q_block_current,
ra.subarray.remote(
q_tree[ith_index, j], lower,
upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
@@ -196,8 +197,8 @@ def qr(a):
if a.shape[0] > a.shape[1]:
# in this case, R needs to be square
R_shape = ray.get(ra.shape.remote(R))
eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
dtype_name=result_dtype)
eye_temp = ra.eye.remote(
R_shape[1], R_shape[0], dtype_name=result_dtype)
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
@@ -220,10 +221,11 @@ def qr(a):
for i in range(len(Ts))[::-1]:
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(
Ts[i],
core.dot.remote(core.transpose.remote(y_col_block), q))))
q,
core.dot.remote(y_col_block,
core.dot.remote(
Ts[i],
core.dot.remote(
core.transpose.remote(y_col_block), q))))
return ray.get(q), r_res
@@ -8,6 +8,8 @@ from .core import (zeros, zeros_like, ones, eye, dot, vstack, hstack, subarray,
copy, tril, triu, diag, transpose, add, subtract, sum,
shape, sum_list)
__all__ = ["random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot",
"vstack", "hstack", "subarray", "copy", "tril", "triu", "diag",
"transpose", "add", "subtract", "sum", "shape", "sum_list"]
__all__ = [
"random", "linalg", "zeros", "zeros_like", "ones", "eye", "dot", "vstack",
"hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add",
"subtract", "sum", "shape", "sum_list"
]
@@ -5,10 +5,11 @@ from __future__ import print_function
import numpy as np
import ray
__all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"cholesky", "eigvals", "eigvalsh", "pinv", "slogdet", "det",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"multi_dot"]
__all__ = [
"matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "cholesky",
"eigvals", "eigvalsh", "pinv", "slogdet", "det", "svd", "eig", "eigh",
"lstsq", "norm", "qr", "cond", "matrix_rank", "multi_dot"
]
@ray.remote
+2 -2
View File
@@ -69,14 +69,14 @@ def flush_task_and_object_metadata_unsafe():
for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"):
num_object_keys_deleted += redis_client.delete(key)
print("Deleted {} object info keys from Redis.".format(
num_object_keys_deleted))
num_object_keys_deleted))
# Flush the object locations.
num_object_location_keys_deleted = 0
for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"):
num_object_location_keys_deleted += redis_client.delete(key)
print("Deleted {} object location keys from Redis.".format(
num_object_location_keys_deleted))
num_object_location_keys_deleted))
# Loop over the shards and flush all of them.
for redis_client in ray.worker.global_state.redis_clients:
+279 -192
View File
@@ -59,6 +59,7 @@ class GlobalState(object):
Attributes:
redis_client: The redis client used to query the redis server.
"""
def __init__(self):
"""Create a GlobalState object."""
# The redis server storing metadata, such as function table, client
@@ -82,7 +83,9 @@ class GlobalState(object):
raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.")
def _initialize_global_state(self, redis_ip_address, redis_port,
def _initialize_global_state(self,
redis_ip_address,
redis_port,
timeout=20):
"""Initialize the GlobalState object by connecting to Redis.
@@ -97,8 +100,8 @@ class GlobalState(object):
timeout: The maximum amount of time (in seconds) that we should
wait for the keys in Redis to be populated.
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
start_time = time.time()
@@ -118,8 +121,8 @@ class GlobalState(object):
"{}.".format(num_redis_shards))
# Attempt to get all of the Redis shards.
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
end=-1)
ip_address_ports = self.redis_client.lrange(
"RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
print("Waiting longer for RedisShards to be populated.")
time.sleep(1)
@@ -132,15 +135,15 @@ class GlobalState(object):
if time.time() - start_time >= timeout:
raise Exception("Timed out while attempting to initialize the "
"global state. num_redis_shards = {}, "
"ip_address_ports = {}"
.format(num_redis_shards, ip_address_ports))
"ip_address_ports = {}".format(
num_redis_shards, ip_address_ports))
# Get the rest of the information.
self.redis_clients = []
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
self.redis_clients.append(
redis.StrictRedis(host=shard_address, port=shard_port))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
@@ -152,8 +155,8 @@ class GlobalState(object):
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
client = self.redis_clients[key.redis_shard_hash() % len(
self.redis_clients)]
return client.execute_command(*args)
def _keys(self, pattern):
@@ -189,8 +192,9 @@ class GlobalState(object):
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
manager_ids = [
binary_to_hex(manager_id) for manager_id in object_locations
]
else:
manager_ids = None
@@ -199,11 +203,13 @@ class GlobalState(object):
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
result = {"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())}
result = {
"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())
}
return result
@@ -227,9 +233,10 @@ class GlobalState(object):
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):]
for key in object_location_keys])
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] + [
key[len(OBJECT_LOCATION_PREFIX):]
for key in object_location_keys
])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = (
@@ -254,26 +261,37 @@ class GlobalState(object):
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
0)
task_table_message = TaskReply.GetRootAsTaskReply(
task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec = ray.local_scheduler.task_from_string(task_spec)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
"ParentCounter": task_spec.parent_counter(),
"ActorID": binary_to_hex(task_spec.actor_id().id()),
"DriverID":
binary_to_hex(task_spec.driver_id().id()),
"TaskID":
binary_to_hex(task_spec.task_id().id()),
"ParentTaskID":
binary_to_hex(task_spec.parent_task_id().id()),
"ParentCounter":
task_spec.parent_counter(),
"ActorID":
binary_to_hex(task_spec.actor_id().id()),
"ActorCreationID":
binary_to_hex(task_spec.actor_creation_id().id()),
binary_to_hex(task_spec.actor_creation_id().id()),
"ActorCreationDummyObjectID":
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter": task_spec.actor_counter(),
"FunctionID": binary_to_hex(task_spec.function_id().id()),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources()}
binary_to_hex(task_spec.actor_creation_dummy_object_id().id()),
"ActorCounter":
task_spec.actor_counter(),
"FunctionID":
binary_to_hex(task_spec.function_id().id()),
"Args":
task_spec.arguments(),
"ReturnObjectIDs":
task_spec.returns(),
"RequiredResources":
task_spec.required_resources()
}
execution_dependencies_message = (
TaskExecutionDependencies.GetRootAsTaskExecutionDependencies(
@@ -282,21 +300,27 @@ class GlobalState(object):
ray.local_scheduler.ObjectID(
execution_dependencies_message.ExecutionDependencies(i))
for i in range(
execution_dependencies_message.ExecutionDependenciesLength())]
execution_dependencies_message.ExecutionDependenciesLength())
]
# TODO(rkn): The return fields ExecutionDependenciesString and
# ExecutionDependencies are redundant, so we should remove
# ExecutionDependencies. However, it is currently used in monitor.py.
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString":
task_table_message.ExecutionDependencies(),
"ExecutionDependencies": execution_dependencies,
"SpillbackCount":
task_table_message.SpillbackCount(),
"TaskSpec": task_spec_info}
return {
"State":
task_table_message.State(),
"LocalSchedulerID":
binary_to_hex(task_table_message.LocalSchedulerId()),
"ExecutionDependenciesString":
task_table_message.ExecutionDependencies(),
"ExecutionDependencies":
execution_dependencies,
"SpillbackCount":
task_table_message.SpillbackCount(),
"TaskSpec":
task_spec_info
}
def task_table(self, task_id=None):
"""Fetch and parse the task table information for one or more task IDs.
@@ -337,7 +361,8 @@ class GlobalState(object):
function_info_parsed = {
"DriverID": binary_to_hex(info[b"driver_id"]),
"Module": decode(info[b"module"]),
"Name": decode(info[b"name"])}
"Name": decode(info[b"name"])
}
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
return results
@@ -469,21 +494,17 @@ class GlobalState(object):
if start is None and end is None:
if fwd:
event_list = self.redis_client.zrange(
event_log_set,
**params)
event_log_set, **params)
else:
event_list = self.redis_client.zrevrange(
event_log_set,
**params)
event_log_set, **params)
else:
if fwd:
event_list = self.redis_client.zrangebyscore(
event_log_set,
**params)
event_log_set, **params)
else:
event_list = self.redis_client.zrevrangebyscore(
event_log_set,
**params)
event_log_set, **params)
for (event, score) in event_list:
event_dict = json.loads(event.decode())
@@ -503,11 +524,11 @@ class GlobalState(object):
task_info[task_id]["get_task_start"] = event[0]
if event[1] == "ray:get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 1):
if (event[1] == "ray:import_remote_function"
and event[2] == 1):
task_info[task_id]["import_remote_start"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 2):
if (event[1] == "ray:import_remote_function"
and event[2] == 2):
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
@@ -547,7 +568,6 @@ class GlobalState(object):
breakdowns=True,
task_dep=True,
obj_dep=True):
"""Dump task profiling information to a file.
This information can be viewed as a timeline of profiling information
@@ -604,72 +624,103 @@ class GlobalState(object):
# modify it in place since we will use the original values later.
total_info = copy.copy(task_table[task_id]["TaskSpec"])
total_info["Args"] = [
oid.hex() if isinstance(oid, ray.local_scheduler.ObjectID)
else oid for oid in task_t_info["TaskSpec"]["Args"]]
oid.hex()
if isinstance(oid, ray.local_scheduler.ObjectID) else oid
for oid in task_t_info["TaskSpec"]["Args"]
]
total_info["ReturnObjectIDs"] = [
oid.hex() for oid
in task_t_info["TaskSpec"]["ReturnObjectIDs"]]
oid.hex() for oid in task_t_info["TaskSpec"]["ReturnObjectIDs"]
]
total_info["LocalSchedulerID"] = task_t_info["LocalSchedulerID"]
total_info["get_arguments"] = (info["get_arguments_end"] -
info["get_arguments_start"])
total_info["execute"] = (info["execute_end"] -
info["execute_start"])
total_info["store_outputs"] = (info["store_outputs_end"] -
info["store_outputs_start"])
total_info["get_arguments"] = (
info["get_arguments_end"] - info["get_arguments_start"])
total_info["execute"] = (
info["execute_end"] - info["execute_start"])
total_info["store_outputs"] = (
info["store_outputs_end"] - info["store_outputs_start"])
total_info["function_name"] = info["function_name"]
total_info["worker_id"] = info["worker_id"]
parent_info = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
task_table[task_id]["TaskSpec"]["ParentTaskID"])
worker = workers[info["worker_id"]]
# The catapult trace format documentation can be found here:
# https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview # noqa: E501
if breakdowns:
if "get_arguments_end" in info:
get_args_trace = {
"cat": "get_arguments",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["get_arguments_start"]),
"ph": "X",
"name": info["function_name"] + ":get_arguments",
"args": total_info,
"dur": micros(info["get_arguments_end"] -
info["get_arguments_start"]),
"cname": "rail_idle"
"cat":
"get_arguments",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"X",
"name":
info["function_name"] + ":get_arguments",
"args":
total_info,
"dur":
micros(info["get_arguments_end"] -
info["get_arguments_start"]),
"cname":
"rail_idle"
}
full_trace.append(get_args_trace)
if "store_outputs_end" in info:
outputs_trace = {
"cat": "store_outputs",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["store_outputs_start"]),
"ph": "X",
"name": info["function_name"] + ":store_outputs",
"args": total_info,
"dur": micros(info["store_outputs_end"] -
info["store_outputs_start"]),
"cname": "thread_state_runnable"
"cat":
"store_outputs",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["store_outputs_start"]),
"ph":
"X",
"name":
info["function_name"] + ":store_outputs",
"args":
total_info,
"dur":
micros(info["store_outputs_end"] -
info["store_outputs_start"]),
"cname":
"thread_state_runnable"
}
full_trace.append(outputs_trace)
if "execute_end" in info:
execute_trace = {
"cat": "execute",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["execute_start"]),
"ph": "X",
"name": info["function_name"] + ":execute",
"args": total_info,
"dur": micros(info["execute_end"] -
info["execute_start"]),
"cname": "rail_animation"
"cat":
"execute",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["execute_start"]),
"ph":
"X",
"name":
info["function_name"] + ":execute",
"args":
total_info,
"dur":
micros(info["execute_end"] - info["execute_start"]),
"cname":
"rail_animation"
}
full_trace.append(execute_trace)
@@ -680,15 +731,20 @@ class GlobalState(object):
parent_profile = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
parent = {
"cat": "submit_task",
"pid": "Node " + parent_worker["node_ip_address"],
"tid": parent_info["worker_id"],
"ts": micros_rel(
parent_profile and
parent_profile["get_arguments_start"] or
start_time),
"ph": "s",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + parent_worker["node_ip_address"],
"tid":
parent_info["worker_id"],
"ts":
micros_rel(parent_profile
and parent_profile["get_arguments_start"]
or start_time),
"ph":
"s",
"name":
"SubmitTask",
"args": {},
"id": (parent_info["worker_id"] +
str(micros(min(parent_times))))
@@ -696,32 +752,50 @@ class GlobalState(object):
full_trace.append(parent)
task_trace = {
"cat": "submit_task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(info["get_arguments_start"]),
"ph": "f",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"f",
"name":
"SubmitTask",
"args": {},
"id": (info["worker_id"] +
str(micros(min(parent_times)))),
"bp": "e",
"cname": "olive"
"id":
(info["worker_id"] + str(micros(min(parent_times)))),
"bp":
"e",
"cname":
"olive"
}
full_trace.append(task_trace)
task = {
"cat": "task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"id": task_id,
"ts": micros_rel(info["get_arguments_start"]),
"ph": "X",
"name": info["function_name"],
"args": total_info,
"dur": micros(info["store_outputs_end"] -
info["get_arguments_start"]),
"cname": "thread_state_runnable"
"cat":
"task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"id":
task_id,
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"X",
"name":
info["function_name"],
"args":
total_info,
"dur":
micros(info["store_outputs_end"] -
info["get_arguments_start"]),
"cname":
"thread_state_runnable"
}
full_trace.append(task)
@@ -732,15 +806,20 @@ class GlobalState(object):
parent_profile = task_info.get(
task_table[task_id]["TaskSpec"]["ParentTaskID"])
parent = {
"cat": "submit_task",
"pid": "Node " + parent_worker["node_ip_address"],
"tid": parent_info["worker_id"],
"ts": micros_rel(
parent_profile and
parent_profile["get_arguments_start"] or
start_time),
"ph": "s",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + parent_worker["node_ip_address"],
"tid":
parent_info["worker_id"],
"ts":
micros_rel(parent_profile
and parent_profile["get_arguments_start"]
or start_time),
"ph":
"s",
"name":
"SubmitTask",
"args": {},
"id": (parent_info["worker_id"] +
str(micros(min(parent_times))))
@@ -748,16 +827,23 @@ class GlobalState(object):
full_trace.append(parent)
task_trace = {
"cat": "submit_task",
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(info["get_arguments_start"]),
"ph": "f",
"name": "SubmitTask",
"cat":
"submit_task",
"pid":
"Node " + worker["node_ip_address"],
"tid":
info["worker_id"],
"ts":
micros_rel(info["get_arguments_start"]),
"ph":
"f",
"name":
"SubmitTask",
"args": {},
"id": (info["worker_id"] +
str(micros(min(parent_times)))),
"bp": "e"
"id":
(info["worker_id"] + str(micros(min(parent_times)))),
"bp":
"e"
}
full_trace.append(task_trace)
@@ -775,8 +861,8 @@ class GlobalState(object):
seen_obj[arg] += 1
owner_task = self._object_table(arg)["TaskID"]
if owner_task in task_info:
owner_worker = (workers[
task_info[owner_task]["worker_id"]])
owner_worker = (workers[task_info[owner_task][
"worker_id"]])
# Adding/subtracting 2 to the time associated
# with the beginning/ending of the flow event
# is necessary to make the flow events show up
@@ -790,27 +876,35 @@ class GlobalState(object):
# duration event that it's associated with, and
# the flow event therefore always gets drawn.
owner = {
"cat": "obj_dependency",
"cat":
"obj_dependency",
"pid": ("Node " +
owner_worker["node_ip_address"]),
"tid": task_info[owner_task]["worker_id"],
"ts": micros_rel(task_info[
owner_task]["store_outputs_end"]) - 2,
"ph": "s",
"name": "ObjectDependency",
"tid":
task_info[owner_task]["worker_id"],
"ts":
micros_rel(task_info[owner_task]
["store_outputs_end"]) - 2,
"ph":
"s",
"name":
"ObjectDependency",
"args": {},
"bp": "e",
"cname": "cq_build_attempt_failed",
"id": "obj" + str(arg) + str(seen_obj[arg])
"bp":
"e",
"cname":
"cq_build_attempt_failed",
"id":
"obj" + str(arg) + str(seen_obj[arg])
}
full_trace.append(owner)
dependent = {
"cat": "obj_dependency",
"pid": "Node " + worker["node_ip_address"],
"pid": "Node " + worker["node_ip_address"],
"tid": info["worker_id"],
"ts": micros_rel(
info["get_arguments_start"]) + 2,
"ts":
micros_rel(info["get_arguments_start"]) + 2,
"ph": "f",
"name": "ObjectDependency",
"args": {},
@@ -852,14 +946,10 @@ class GlobalState(object):
"""
keys = [
"acquire_lock_start",
"acquire_lock_end",
"get_arguments_start",
"get_arguments_end",
"execute_start",
"execute_end",
"store_outputs_start",
"store_outputs_end"]
"acquire_lock_start", "acquire_lock_end", "get_arguments_start",
"get_arguments_end", "execute_start", "execute_end",
"store_outputs_start", "store_outputs_end"
]
latest_timestamp = 0
for key in keys:
@@ -877,8 +967,8 @@ class GlobalState(object):
local_schedulers = []
for ip_address, client_list in clients.items():
for client in client_list:
if (client["ClientType"] == "local_scheduler" and
not client["Deleted"]):
if (client["ClientType"] == "local_scheduler"
and not client["Deleted"]):
local_schedulers.append(client)
return local_schedulers
@@ -893,8 +983,7 @@ class GlobalState(object):
workers_data[worker_id] = {
"local_scheduler_socket":
(worker_info[b"local_scheduler_socket"]
.decode("ascii")),
(worker_info[b"local_scheduler_socket"].decode("ascii")),
"node_ip_address": (worker_info[b"node_ip_address"]
.decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
@@ -921,9 +1010,10 @@ class GlobalState(object):
"class_id": binary_to_hex(info[b"class_id"]),
"driver_id": binary_to_hex(info[b"driver_id"]),
"local_scheduler_id":
binary_to_hex(info[b"local_scheduler_id"]),
binary_to_hex(info[b"local_scheduler_id"]),
"num_gpus": int(info[b"num_gpus"]),
"removed": decode(info[b"removed"]) == "True"}
"removed": decode(info[b"removed"]) == "True"
}
return actor_info
def _job_length(self):
@@ -932,21 +1022,16 @@ class GlobalState(object):
overall_largest = 0
num_tasks = 0
for event_log_set in event_log_sets:
fwd_range = self.redis_client.zrange(event_log_set,
start=0,
end=0,
withscores=True)
fwd_range = self.redis_client.zrange(
event_log_set, start=0, end=0, withscores=True)
overall_smallest = min(overall_smallest, fwd_range[0][1])
rev_range = self.redis_client.zrevrange(event_log_set,
start=0,
end=0,
withscores=True)
rev_range = self.redis_client.zrevrange(
event_log_set, start=0, end=0, withscores=True)
overall_largest = max(overall_largest, rev_range[0][1])
num_tasks += self.redis_client.zcount(event_log_set,
min=0,
max=time.time())
num_tasks += self.redis_client.zcount(
event_log_set, min=0, max=time.time())
if num_tasks is 0:
return 0, 0, 0
return overall_smallest, overall_largest, num_tasks
@@ -966,8 +1051,10 @@ class GlobalState(object):
for local_scheduler in local_schedulers:
for key, value in local_scheduler.items():
if key not in ["ClientType", "Deleted", "DBClientID",
"AuxAddress", "LocalSchedulerSocketName"]:
if key not in [
"ClientType", "Deleted", "DBClientID", "AuxAddress",
"LocalSchedulerSocketName"
]:
resources[key] += value
return dict(resources)
+37 -22
View File
@@ -27,6 +27,7 @@ class TensorFlowVariables(object):
placeholders (Dict[str, tf.placeholders]): Placeholders for weights.
assignment_nodes (Dict[str, tf.Tensor]): Nodes that assign weights.
"""
def __init__(self, loss, sess=None, input_variables=None):
"""Creates TensorFlowVariables containing extracted variables.
@@ -74,8 +75,10 @@ class TensorFlowVariables(object):
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
variable_list = [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]
variable_list = [
v for v in tf.global_variables()
if v.op.node_def.name in variable_names
]
if input_variables is not None:
variable_list += input_variables
for v in variable_list:
@@ -86,9 +89,10 @@ class TensorFlowVariables(object):
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.placeholders[k] = tf.placeholder(
var.value().dtype,
var.get_shape().as_list(),
name="Placeholder_" + k)
self.assignment_nodes[k] = var.assign(self.placeholders[k])
def set_session(self, sess):
@@ -105,8 +109,9 @@ class TensorFlowVariables(object):
Returns:
The length of all flattened variables concatenated.
"""
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
return sum([
np.prod(v.get_shape().as_list()) for v in self.variables.values()
])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
@@ -122,8 +127,10 @@ class TensorFlowVariables(object):
1D Array containing the flattened weights.
"""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
return np.concatenate([
v.eval(session=self.sess).flatten()
for v in self.variables.values()
])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array.
@@ -138,10 +145,12 @@ class TensorFlowVariables(object):
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k] for k, v
in self.variables.items()]
self.sess.run(list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
placeholders = [
self.placeholders[k] for k, v in self.variables.items()
]
self.sess.run(
list(self.assignment_nodes.values()),
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns a dictionary containing the weights of the network.
@@ -150,8 +159,10 @@ class TensorFlowVariables(object):
Dictionary mapping variable names to their weights.
"""
self._check_sess()
return {k: v.eval(session=self.sess) for k, v
in self.variables.items()}
return {
k: v.eval(session=self.sess)
for k, v in self.variables.items()
}
def set_weights(self, new_weights):
"""Sets the weights to new_weights.
@@ -165,15 +176,19 @@ class TensorFlowVariables(object):
weights.
"""
self._check_sess()
assign_list = [self.assignment_nodes[name]
for name in new_weights.keys()
if name in self.assignment_nodes]
assign_list = [
self.assignment_nodes[name] for name in new_weights.keys()
if name in self.assignment_nodes
]
assert assign_list, ("No variables in the input matched those in the "
"network. Possible cause: Two networks were "
"defined in the same TensorFlow graph. To fix "
"this, place each network definition in its own "
"tf.Graph.")
self.sess.run(assign_list,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
self.sess.run(
assign_list,
feed_dict={
self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders
})
+200 -215
View File
@@ -29,9 +29,9 @@ class _EventRecursionContextManager(object):
total_time_value = "% total time"
total_tasks_value = "% total tasks"
# Function that returns instances of sliders and handles associated events.
def get_sliders(update):
# Start_box value indicates the desired start point of queried window.
start_box = widgets.FloatText(
@@ -60,18 +60,14 @@ def get_sliders(update):
# Indicates the number of tasks that the user wants to be returned. Is
# disabled when the breakdown_opt value is set to total_time_value.
num_tasks_box = widgets.IntText(
description="Num Tasks:",
disabled=False
)
num_tasks_box = widgets.IntText(description="Num Tasks:", disabled=False)
# Dropdown bar that lets the user choose between modifying % of total
# time or total number of tasks.
breakdown_opt = widgets.Dropdown(
options=[total_time_value, total_tasks_value],
value=total_tasks_value,
description="Selection Options:"
)
description="Selection Options:")
# Display box for layout.
total_time_box = widgets.VBox([start_box, end_box])
@@ -105,9 +101,9 @@ def get_sliders(update):
if event == INIT_EVENT:
if breakdown_opt.value == total_tasks_value:
num_tasks_box.value = -min(10000, num_tasks)
range_slider.value = (int(100 -
(100. * -num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (int(
100 - (100. * -num_tasks_box.value) / num_tasks),
100)
else:
low, high = map(lambda x: x / 100., range_slider.value)
start_box.value = round(diff * low, 2)
@@ -120,8 +116,8 @@ def get_sliders(update):
elif start_box.value < 0:
start_box.value = 0
low, high = range_slider.value
range_slider.value = (int((start_box.value * 100.) /
diff), high)
range_slider.value = (int((start_box.value * 100.) / diff),
high)
# Event was triggered by a change in the end_box value.
elif event["owner"] == end_box:
@@ -130,8 +126,8 @@ def get_sliders(update):
elif end_box.value > diff:
end_box.value = diff
low, high = range_slider.value
range_slider.value = (low, int((end_box.value * 100.) /
diff))
range_slider.value = (low,
int((end_box.value * 100.) / diff))
# Event was triggered by a change in the breakdown options
# toggle.
@@ -145,9 +141,9 @@ def get_sliders(update):
# Make CSS display go back to the default settings.
num_tasks_box.layout.display = None
num_tasks_box.value = min(10000, num_tasks)
range_slider.value = (int(100 -
(100. * num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (int(
100 - (100. * num_tasks_box.value) / num_tasks),
100)
else:
start_box.disabled = False
end_box.disabled = False
@@ -156,10 +152,9 @@ def get_sliders(update):
# Make CSS display go back to the default settings.
total_time_box.layout.display = None
num_tasks_box.layout.display = 'none'
range_slider.value = (int((start_box.value * 100.) /
diff),
int((end_box.value * 100.) /
diff))
range_slider.value = (
int((start_box.value * 100.) / diff),
int((end_box.value * 100.) / diff))
# Event was triggered by a change in the range_slider
# value.
@@ -170,8 +165,8 @@ def get_sliders(update):
new_low, new_high = event["new"]
if old_low != new_low:
range_slider.value = (new_low, 100)
num_tasks_box.value = (-(100. - new_low) /
100. * num_tasks)
num_tasks_box.value = (
-(100. - new_low) / 100. * num_tasks)
else:
range_slider.value = (0, new_high)
num_tasks_box.value = new_high / 100. * num_tasks
@@ -183,14 +178,12 @@ def get_sliders(update):
# value.
elif event["owner"] == num_tasks_box:
if num_tasks_box.value > 0:
range_slider.value = (0, int(100 *
float(num_tasks_box.value) /
num_tasks))
range_slider.value = (
0, int(
100 * float(num_tasks_box.value) / num_tasks))
elif num_tasks_box.value < 0:
range_slider.value = (100 +
int(100 *
float(num_tasks_box.value) /
num_tasks), 100)
range_slider.value = (100 + int(
100 * float(num_tasks_box.value) / num_tasks), 100)
if not update:
return
@@ -205,23 +198,20 @@ def get_sliders(update):
# box values.
# (Querying based on the % total amount of time.)
if breakdown_opt.value == total_time_value:
tasks = _truncated_task_profiles(start=(smallest +
diff * low),
end=(smallest +
diff * high))
tasks = _truncated_task_profiles(
start=(smallest + diff * low),
end=(smallest + diff * high))
# (Querying based on % of total number of tasks that were
# run.)
elif breakdown_opt.value == total_tasks_value:
if range_slider.value[0] == 0:
tasks = _truncated_task_profiles(num_tasks=(int(
num_tasks * high)),
fwd=True)
tasks = _truncated_task_profiles(
num_tasks=(int(num_tasks * high)), fwd=True)
else:
tasks = _truncated_task_profiles(num_tasks=(int(
num_tasks *
(high - low))),
fwd=False)
tasks = _truncated_task_profiles(
num_tasks=(int(num_tasks * (high - low))),
fwd=False)
update(smallest, largest, num_tasks, tasks)
@@ -237,8 +227,8 @@ def get_sliders(update):
update_wrapper(INIT_EVENT)
# Display sliders and search boxes
display(breakdown_opt, widgets.HBox([range_slider, total_time_box,
num_tasks_box]))
display(breakdown_opt,
widgets.HBox([range_slider, total_time_box, num_tasks_box]))
# Return the sliders and text boxes
return start_box, end_box, range_slider, breakdown_opt
@@ -249,8 +239,7 @@ def object_search_bar():
value="",
placeholder="Object ID",
description="Search for an object:",
disabled=False
)
disabled=False)
display(object_search)
def handle_submit(sender):
@@ -265,8 +254,7 @@ def task_search_bar():
value="",
placeholder="Task ID",
description="Search for a task:",
disabled=False
)
disabled=False)
display(task_search)
def handle_submit(sender):
@@ -284,14 +272,12 @@ MAX_TASKS_TO_VISUALIZE = 10000
def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
    """Fetch task profiles, never requesting more than the visualization cap.

    Args:
        start: Forwarded unchanged to ray.global_state.task_profiles.
        end: Forwarded unchanged to ray.global_state.task_profiles.
        num_tasks: Number of tasks requested by the caller. If None, or if
            it exceeds MAX_TASKS_TO_VISUALIZE, it is replaced by
            MAX_TASKS_TO_VISUALIZE and a warning is printed.
        fwd: Forwarded unchanged to ray.global_state.task_profiles.

    Returns:
        The result of ray.global_state.task_profiles for the (possibly
        capped) request.
    """
    if num_tasks is None:
        # No explicit count was given, so fall back to the cap and tell the
        # user that the result may be truncated.
        num_tasks = MAX_TASKS_TO_VISUALIZE
        print("Warning: at most {} tasks will be fetched within this "
              "time range.".format(MAX_TASKS_TO_VISUALIZE))
    elif num_tasks > MAX_TASKS_TO_VISUALIZE:
        # Report the originally requested count before clamping it.
        print("Warning: too many tasks to visualize, "
              "fetching only the first {} of {}.".format(
                  MAX_TASKS_TO_VISUALIZE, num_tasks))
        num_tasks = MAX_TASKS_TO_VISUALIZE
    return ray.global_state.task_profiles(num_tasks, start, end, fwd)
@@ -299,9 +285,8 @@ def _truncated_task_profiles(start=None, end=None, num_tasks=None, fwd=True):
# Helper function that guarantees unique and writeable temp files.
# Prevents clashes in task trace files when multiple notebooks are running.
def _get_temp_file_path(**kwargs):
temp_file = tempfile.NamedTemporaryFile(delete=False,
dir=os.getcwd(),
**kwargs)
temp_file = tempfile.NamedTemporaryFile(
delete=False, dir=os.getcwd(), **kwargs)
temp_file_path = temp_file.name
temp_file.close()
return os.path.relpath(temp_file_path)
@@ -319,22 +304,16 @@ def task_timeline():
disabled=False,
)
obj_dep = widgets.Checkbox(
value=True,
disabled=False,
layout=widgets.Layout(width='20px')
)
value=True, disabled=False, layout=widgets.Layout(width='20px'))
task_dep = widgets.Checkbox(
value=True,
disabled=False,
layout=widgets.Layout(width='20px')
)
value=True, disabled=False, layout=widgets.Layout(width='20px'))
# Labels to bypass width limitation for descriptions.
label_tasks = widgets.Label(value='Task submissions',
layout=widgets.Layout(width='110px'))
label_objects = widgets.Label(value='Object dependencies',
layout=widgets.Layout(width='130px'))
label_options = widgets.Label(value='View options:',
layout=widgets.Layout(width='100px'))
label_tasks = widgets.Label(
value='Task submissions', layout=widgets.Layout(width='110px'))
label_objects = widgets.Label(
value='Object dependencies', layout=widgets.Layout(width='130px'))
label_options = widgets.Label(
value='View options:', layout=widgets.Layout(width='100px'))
start_box, end_box, range_slider, time_opt = get_sliders(False)
display(widgets.HBox([task_dep, label_tasks, obj_dep, label_objects]))
display(widgets.HBox([label_options, breakdown_opt]))
@@ -344,8 +323,9 @@ def task_timeline():
# current working directory if it is not present.
if not os.path.exists("trace_viewer_full.html"):
shutil.copy(
os.path.join(os.path.dirname(os.path.abspath(__file__)),
"../core/src/catapult_files/trace_viewer_full.html"),
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../core/src/catapult_files/trace_viewer_full.html"),
"trace_viewer_full.html")
def handle_submit(sender):
@@ -357,8 +337,8 @@ def task_timeline():
elif breakdown_opt.value == breakdown_task:
breakdown = True
else:
raise ValueError(
"Unexpected breakdown value '{}'".format(breakdown_opt.value))
raise ValueError("Unexpected breakdown value '{}'".format(
breakdown_opt.value))
low, high = map(lambda x: x / 100., range_slider.value)
@@ -366,30 +346,28 @@ def task_timeline():
diff = largest - smallest
if time_opt.value == total_time_value:
tasks = _truncated_task_profiles(start=smallest + diff * low,
end=smallest + diff * high)
tasks = _truncated_task_profiles(
start=smallest + diff * low, end=smallest + diff * high)
elif time_opt.value == total_tasks_value:
if range_slider.value[0] == 0:
tasks = _truncated_task_profiles(num_tasks=int(
num_tasks * high),
fwd=True)
tasks = _truncated_task_profiles(
num_tasks=int(num_tasks * high), fwd=True)
else:
tasks = _truncated_task_profiles(num_tasks=int(
num_tasks * (high - low)),
fwd=False)
tasks = _truncated_task_profiles(
num_tasks=int(num_tasks * (high - low)), fwd=False)
else:
raise ValueError("Unexpected time value '{}'".format(
time_opt.value))
time_opt.value))
# Write trace to a JSON file
print("Collected profiles for {} tasks.".format(len(tasks)))
print(
"Dumping task profile data to {}, "
"this might take a while...".format(json_tmp))
ray.global_state.dump_catapult_trace(json_tmp,
tasks,
breakdowns=breakdown,
obj_dep=obj_dep.value,
task_dep=task_dep.value)
print("Dumping task profile data to {}, "
"this might take a while...".format(json_tmp))
ray.global_state.dump_catapult_trace(
json_tmp,
tasks,
breakdowns=breakdown,
obj_dep=obj_dep.value,
task_dep=task_dep.value)
print("Opening html file in browser...")
trace_viewer_path = os.path.join(
@@ -415,9 +393,8 @@ def task_timeline():
# Display the task trace within the Jupyter notebook
clear_output(wait=True)
print(
"To view fullscreen, open chrome://tracing in Google Chrome "
"and load `{}`".format(json_tmp))
print("To view fullscreen, open chrome://tracing in Google Chrome "
"and load `{}`".format(json_tmp))
display(IFrame(html_file_path, 900, 800))
path_input.on_click(handle_submit)
@@ -432,36 +409,41 @@ def task_completion_time_distribution():
output_notebook(resources=CDN)
# Create the Bokeh plot
p = figure(title="Task Completion Time Distribution",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=(0, 1),
y_range=(0, 1))
p = figure(
title="Task Completion Time Distribution",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=(0, 1),
y_range=(0, 1))
# Create the data source that the plot pulls from
source = ColumnDataSource(data={
"top": [],
"left": [],
"right": []
})
source = ColumnDataSource(data={"top": [], "left": [], "right": []})
# Plot the histogram rectangles
p.quad(top="top", bottom=0, left="left", right="right", source=source,
fill_color="#B3B3B3", line_color="#033649")
p.quad(
top="top",
bottom=0,
left="left",
right="right",
source=source,
fill_color="#B3B3B3",
line_color="#033649")
# Label the plot axes
p.xaxis.axis_label = "Duration in seconds"
p.yaxis.axis_label = "Number of tasks"
handle = show(gridplot(p, ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"), notebook_handle=True)
handle = show(
gridplot(
p,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"),
notebook_handle=True)
# Function to update the plot
def task_completion_time_update(abs_earliest,
abs_latest,
abs_num_tasks,
def task_completion_time_update(abs_earliest, abs_latest, abs_num_tasks,
tasks):
if len(tasks) == 0:
return
@@ -469,8 +451,8 @@ def task_completion_time_distribution():
# Create the distribution to plot
distr = []
for task_id, data in tasks.items():
distr.append(data["store_outputs_end"] -
data["get_arguments_start"])
distr.append(
data["store_outputs_end"] - data["get_arguments_start"])
# Create a histogram from the distribution
top, bin_edges = np.histogram(distr, bins="auto")
@@ -480,8 +462,8 @@ def task_completion_time_distribution():
source.data = {"top": top, "left": left, "right": right}
# Set the x and y ranges
x_range = (min(left) if len(left) else 0,
max(right) if len(right) else 1)
x_range = (min(left) if len(left) else 0, max(right)
if len(right) else 1)
y_range = (0, max(top) + 1 if len(top) else 1)
x_range = helpers._get_range(x_range)
@@ -517,8 +499,7 @@ def compute_utilizations(abs_earliest,
latest_time = 0
for task_id, data in tasks.items():
latest_time = max((latest_time, data["store_outputs_end"]))
earliest_time = min((earliest_time,
data["get_arguments_start"]))
earliest_time = min((earliest_time, data["get_arguments_start"]))
# Add some epsilon to latest_time to ensure that the end time of the
# last task falls __within__ a bucket, and not on the edge
@@ -533,37 +514,37 @@ def compute_utilizations(abs_earliest,
task_start_time = data["get_arguments_start"]
task_end_time = data["store_outputs_end"]
start_bucket = int((task_start_time - earliest_time) /
bucket_time_length)
end_bucket = int((task_end_time - earliest_time) /
bucket_time_length)
start_bucket = int(
(task_start_time - earliest_time) / bucket_time_length)
end_bucket = int((task_end_time - earliest_time) / bucket_time_length)
# Walk over each time bucket that this task intersects, adding the
# amount of time that the task intersects within each bucket
for bucket_idx in range(start_bucket, end_bucket + 1):
bucket_start_time = ((earliest_time + bucket_idx) *
bucket_time_length)
bucket_end_time = ((earliest_time + (bucket_idx + 1)) *
bucket_time_length)
bucket_start_time = ((
earliest_time + bucket_idx) * bucket_time_length)
bucket_end_time = ((earliest_time +
(bucket_idx + 1)) * bucket_time_length)
task_start_time_within_bucket = max(task_start_time,
bucket_start_time)
task_end_time_within_bucket = min(task_end_time,
bucket_end_time)
task_cpu_time_within_bucket = (task_end_time_within_bucket -
task_start_time_within_bucket)
task_end_time_within_bucket = min(task_end_time, bucket_end_time)
task_cpu_time_within_bucket = (
task_end_time_within_bucket - task_start_time_within_bucket)
if bucket_idx > -1 and bucket_idx < num_buckets:
cpu_time[bucket_idx] += task_cpu_time_within_bucket
# Cpu_utilization is the average cpu utilization of the bucket, which
# is just cpu_time divided by bucket_time_length.
cpu_utilization = list(map(lambda x: x / float(bucket_time_length),
cpu_time))
cpu_utilization = list(
map(lambda x: x / float(bucket_time_length), cpu_time))
# Generate histogram bucket edges. Subtract out abs_earliest to get
# relative time.
all_edges = [earliest_time - abs_earliest + i * bucket_time_length
for i in range(num_buckets + 1)]
all_edges = [
earliest_time - abs_earliest + i * bucket_time_length
for i in range(num_buckets + 1)
]
# Left edges are all but the rightmost edge, right edges are all but
# the leftmost edge.
left_edges = all_edges[:-1]
@@ -591,54 +572,53 @@ def cpu_usage():
# Update the plot based on the sliders
def plot_utilization():
# Create the Bokeh plot
time_series_fig = figure(title="CPU Utilization",
tools=["save", "hover", "wheel_zoom",
"box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=[0, 1],
y_range=[0, 1])
time_series_fig = figure(
title="CPU Utilization",
tools=["save", "hover", "wheel_zoom", "box_zoom", "pan"],
background_fill_color="#FFFFFF",
x_range=[0, 1],
y_range=[0, 1])
# Create the data source that the plot will pull from
time_series_source = ColumnDataSource(data=dict(
left=[],
right=[],
top=[]
))
time_series_source = ColumnDataSource(
data=dict(left=[], right=[], top=[]))
# Plot the rectangles representing the distribution
time_series_fig.quad(left="left",
right="right",
top="top",
bottom=0,
source=time_series_source,
fill_color="#B3B3B3",
line_color="#033649")
time_series_fig.quad(
left="left",
right="right",
top="top",
bottom=0,
source=time_series_source,
fill_color="#B3B3B3",
line_color="#033649")
# Label the plot axes
time_series_fig.xaxis.axis_label = "Time in seconds"
time_series_fig.yaxis.axis_label = "Number of CPUs used"
handle = show(gridplot(time_series_fig,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"), notebook_handle=True)
handle = show(
gridplot(
time_series_fig,
ncols=1,
plot_width=500,
plot_height=500,
toolbar_location="below"),
notebook_handle=True)
def update_plot(abs_earliest, abs_latest, abs_num_tasks, tasks):
num_buckets = 100
left, right, top = compute_utilizations(abs_earliest,
abs_latest,
abs_num_tasks,
tasks,
num_buckets)
left, right, top = compute_utilizations(
abs_earliest, abs_latest, abs_num_tasks, tasks, num_buckets)
time_series_source.data = {"left": left,
"right": right,
"top": top}
time_series_source.data = {
"left": left,
"right": right,
"top": top
}
x_range = (max(0, min(left))
if len(left) else 0,
max(right) if len(right) else 1)
x_range = (max(0, min(left)) if len(left) else 0, max(right)
if len(right) else 1)
y_range = (0, max(top) + 1 if len(top) else 1)
# Define the axis ranges
@@ -654,6 +634,7 @@ def cpu_usage():
push_notebook(handle=handle)
get_sliders(update_plot)
plot_utilization()
@@ -672,33 +653,32 @@ def cluster_usage():
output_notebook(resources=CDN)
# Initial values
source = ColumnDataSource(data={"node_ip_address": ['127.0.0.1'],
"time": ['0.5'],
"num_tasks": ['1'],
"length": [1]})
source = ColumnDataSource(
data={
"node_ip_address": ['127.0.0.1'],
"time": ['0.5'],
"num_tasks": ['1'],
"length": [1]
})
# Define the color schema
colors = ["#75968f",
"#a5bab7",
"#c9d9d3",
"#e2e2e2",
"#dfccce",
"#ddb7b1",
"#cc7878",
"#933b41",
"#550b1d"]
colors = [
"#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1",
"#cc7878", "#933b41", "#550b1d"
]
mapper = LinearColorMapper(palette=colors, low=0, high=2)
TOOLS = "hover, save, xpan, box_zoom, reset, xwheel_zoom"
# Create the plot
p = figure(title="Cluster Usage",
y_range=list(set(source.data['node_ip_address'])),
x_axis_location="above",
plot_width=900,
plot_height=500,
tools=TOOLS,
toolbar_location='below')
p = figure(
title="Cluster Usage",
y_range=list(set(source.data['node_ip_address'])),
x_axis_location="above",
plot_width=900,
plot_height=500,
tools=TOOLS,
toolbar_location='below')
# Format the plot axes
p.grid.grid_line_color = None
@@ -709,26 +689,33 @@ def cluster_usage():
p.xaxis.major_label_orientation = np.pi / 3
# Plot rectangles
p.rect(x="time", y="node_ip_address", width="length", height=1,
source=source,
fill_color={"field": "num_tasks", "transform": mapper},
line_color=None)
p.rect(
x="time",
y="node_ip_address",
width="length",
height=1,
source=source,
fill_color={
"field": "num_tasks",
"transform": mapper
},
line_color=None)
# Add legend to the side of the plot
color_bar = ColorBar(color_mapper=mapper,
major_label_text_font_size="8pt",
ticker=BasicTicker(desired_num_ticks=len(colors)),
label_standoff=6,
border_line_color=None,
location=(0, 0))
color_bar = ColorBar(
color_mapper=mapper,
major_label_text_font_size="8pt",
ticker=BasicTicker(desired_num_ticks=len(colors)),
label_standoff=6,
border_line_color=None,
location=(0, 0))
p.add_layout(color_bar, "right")
# Define hover tool
p.select_one(HoverTool).tooltips = [
("Node IP Address", "@node_ip_address"),
("Number of tasks running", "@num_tasks"),
("Time", "@time")
]
p.select_one(HoverTool).tooltips = [("Node IP Address",
"@node_ip_address"),
("Number of tasks running",
"@num_tasks"), ("Time", "@time")]
# Define the axis labels
p.xaxis.axis_label = "Time in seconds"
@@ -764,12 +751,8 @@ def cluster_usage():
num_tasks = []
for node_ip, task_dict in node_to_tasks.items():
left, right, top = compute_utilizations(earliest,
latest,
abs_num_tasks,
task_dict,
100,
True)
left, right, top = compute_utilizations(
earliest, latest, abs_num_tasks, task_dict, 100, True)
for (l, r, t) in zip(left, right, top):
nodes.append(node_ip)
times.append((l + r) / 2)
@@ -783,10 +766,12 @@ def cluster_usage():
mapper.high = max(max(num_tasks), 1)
# Update plot with new data based on slider and text box values
source.data = {"node_ip_address": nodes,
"time": times,
"num_tasks": num_tasks,
"length": lengths}
source.data = {
"node_ip_address": nodes,
"time": times,
"num_tasks": num_tasks,
"length": lengths
}
push_notebook(handle=handle)
@@ -7,9 +7,12 @@ import subprocess
import time
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
def start_global_scheduler(redis_address,
node_ip_address,
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a global scheduler process.
Args:
@@ -33,21 +36,24 @@ def start_global_scheduler(redis_address, node_ip_address,
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
command = [
global_scheduler_executable, "-r", redis_address, "-h", node_ip_address
]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
+39 -33
View File
@@ -56,7 +56,6 @@ def new_port():
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
self.node_ip_address = "127.0.0.1"
@@ -164,17 +163,19 @@ class TestGlobalScheduler(unittest.TestCase):
return db_client_id
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0)
task1 = local_scheduler.Task(
random_driver_id(), random_function_id(), [random_object_id()], 0,
random_task_id(), 0)
self.assertEqual(task1.required_resources(), {"CPU": 1})
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_OBJECT_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
0, 0, [], {"CPU": 1, "GPU": 2})
task2 = local_scheduler.Task(
random_driver_id(), random_function_id(), [random_object_id()], 0,
random_task_id(), 0, local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_OBJECT_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID),
local_scheduler.ObjectID(NIL_ACTOR_ID), 0, 0, [], {
"CPU": 1,
"GPU": 2
})
self.assertEqual(task2.required_resources(), {"CPU": 1, "GPU": 2})
def test_redis_only_single_task(self):
@@ -189,7 +190,7 @@ class TestGlobalScheduler(unittest.TestCase):
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
assert (db_client_id is not None)
@unittest.skipIf(
os.environ.get('RAY_USE_NEW_GCS', False),
@@ -227,9 +228,10 @@ class TestGlobalScheduler(unittest.TestCase):
if len(task_entries) == 1:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
self.assertTrue(task_status in [
state.TASK_STATUS_WAITING, state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED
])
if task_status == state.TASK_STATUS_QUEUED:
break
else:
@@ -258,17 +260,14 @@ class TestGlobalScheduler(unittest.TestCase):
data_size = np.random.randint(1 << 12)
metadata_size = np.random.randint(1 << 9)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to
# yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(
random_driver_id(),
random_function_id(),
random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep.binary())],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
@@ -281,12 +280,18 @@ class TestGlobalScheduler(unittest.TestCase):
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
task_statuses = [
task_entry["State"]
for task_entry in task_entries.values()
]
self.assertTrue(
all([
status in [
state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED
] for status in task_statuses
]))
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(
state.TASK_STATUS_SCHEDULED)
@@ -294,12 +299,13 @@ class TestGlobalScheduler(unittest.TestCase):
state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, "
"tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done,
num_retries))
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
"tasks queued = {}, retries left = {}".format(
len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([
status == state.TASK_STATUS_QUEUED
for status in task_statuses
]):
# We're done, so pass.
break
num_retries -= 1
+5 -3
View File
@@ -7,6 +7,8 @@ from ray.core.src.local_scheduler.liblocal_scheduler_library import (
task_to_string, _config, common_error)
from .local_scheduler_services import start_local_scheduler
__all__ = ["Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler",
"_config", "common_error"]
__all__ = [
"Task", "LocalSchedulerClient", "ObjectID", "check_simple_value",
"task_from_string", "task_to_string", "start_local_scheduler", "_config",
"common_error"
]
@@ -68,15 +68,15 @@ def start_local_scheduler(plasma_store_name,
"provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
local_scheduler_executable = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
"-p", plasma_store_name,
"-h", node_ip_address,
"-n", str(num_workers)]
command = [
local_scheduler_executable, "-s", local_scheduler_name, "-p",
plasma_store_name, "-h", node_ip_address, "-n",
str(num_workers)
]
if plasma_manager_name is not None:
command += ["-m", plasma_manager_name]
if worker_path is not None:
@@ -88,14 +88,11 @@ def start_local_scheduler(plasma_store_name,
"--object-store-name={} "
"--object-store-manager-name={} "
"--local-scheduler-name={} "
"--redis-address={}"
.format(sys.executable,
worker_path,
node_ip_address,
plasma_store_name,
plasma_manager_name,
local_scheduler_name,
redis_address))
"--redis-address={}".format(
sys.executable, worker_path,
node_ip_address, plasma_store_name,
plasma_manager_name, local_scheduler_name,
redis_address))
command += ["-w", start_worker_command]
if redis_address is not None:
command += ["-r", redis_address]
@@ -104,27 +101,31 @@ def start_local_scheduler(plasma_store_name,
if static_resources is not None:
resource_argument = ""
for resource_name, resource_quantity in static_resources.items():
assert (isinstance(resource_quantity, int) or
isinstance(resource_quantity, float))
resource_argument = ",".join(
[resource_name + "," + str(resource_quantity)
for resource_name, resource_quantity in static_resources.items()])
assert (isinstance(resource_quantity, int)
or isinstance(resource_quantity, float))
resource_argument = ",".join([
resource_name + "," + str(resource_quantity)
for resource_name, resource_quantity in static_resources.items()
])
else:
resource_argument = "CPU,{}".format(psutil.cpu_count())
command += ["-c", resource_argument]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
+13 -29
View File
@@ -37,7 +37,6 @@ def random_function_id():
class TestLocalSchedulerClient(unittest.TestCase):
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p1 = plasma.start_plasma_store()
@@ -74,34 +73,17 @@ class TestLocalSchedulerClient(unittest.TestCase):
self.plasma_client.create(pa.plasma.ObjectID(object_id.id()), 0)
self.plasma_client.seal(pa.plasma.ObjectID(object_id.id()))
# Define some arguments to use for the tasks.
args_list = [
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
args_list = [[], [{}], [()], 1 * [1], 10 * [1], 100 * [1], 1000 * [1],
1 * ["a"], 10 * ["a"], 100 * ["a"], 1000 * ["a"], [
1, 1.3, 1 << 100, "hi", u"hi", [1, 2]
], object_ids[:1], object_ids[:2], object_ids[:3],
object_ids[:4], object_ids[:5], object_ids[:10],
object_ids[:100], object_ids[:256], [1, object_ids[0]], [
object_ids[0], "a"
], [1, object_ids[0], "a"], [
object_ids[0], 1, object_ids[1], "a"
], object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
@@ -146,6 +128,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Sleep to give the thread time to call get_task.
@@ -170,6 +153,7 @@ class TestLocalSchedulerClient(unittest.TestCase):
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
+21 -14
View File
@@ -26,11 +26,12 @@ class LogMonitor(object):
log_file_handles: A dictionary mapping the name of a log file to a file
handle for that file.
"""
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
self.log_files = {}
self.log_file_handles = {}
self.files_to_ignore = set()
@@ -38,9 +39,8 @@ class LogMonitor(object):
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
new_log_filenames = self.redis_client.lrange("LOG_FILENAMES:{}".format(
self.node_ip_address), num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
@@ -78,8 +78,8 @@ class LogMonitor(object):
# Try to open this file for the first time.
else:
try:
self.log_file_handles[log_filename] = open(log_filename,
"r")
self.log_file_handles[log_filename] = open(
log_filename, "r")
except IOError as e:
if e.errno == os.errno.EMFILE:
print("Warning: Ignoring {} because there are too "
@@ -106,13 +106,20 @@ class LogMonitor(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument("--redis-address", required=True, type=str,
help="The address to use for Redis.")
parser.add_argument("--node-ip-address", required=True, type=str,
help="The IP address of the node this process is on.")
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="The address to use for Redis.")
parser.add_argument(
"--node-ip-address",
required=True,
type=str,
help="The IP address of the node this process is on.")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)
+9 -10
View File
@@ -100,8 +100,8 @@ class Monitor(object):
self.local_scheduler_id_to_ip_map = dict()
self.load_metrics = LoadMetrics()
if autoscaling_config:
self.autoscaler = StandardAutoscaler(
autoscaling_config, self.load_metrics)
self.autoscaler = StandardAutoscaler(autoscaling_config,
self.load_metrics)
else:
self.autoscaler = None
@@ -160,11 +160,9 @@ class Monitor(object):
# task as lost.
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE",
hex_to_binary(task_id),
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID,
task["ExecutionDependenciesString"],
task["SpillbackCount"])
task["ExecutionDependenciesString"], task["SpillbackCount"])
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
@@ -428,8 +426,8 @@ class Monitor(object):
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info(
"Driver {} has been removed.".format(binary_to_hex(driver_id)))
log.info("Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._clean_up_entries_for_driver(driver_id)
@@ -571,8 +569,9 @@ class Monitor(object):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"monitor to connect to."))
parser = argparse.ArgumentParser(
description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument(
"--redis-address",
required=True,
+3 -2
View File
@@ -5,5 +5,6 @@ from __future__ import print_function
from ray.plasma.plasma import (start_plasma_store, start_plasma_manager,
DEFAULT_PLASMA_STORE_MEMORY)
__all__ = ["start_plasma_store", "start_plasma_manager",
"DEFAULT_PLASMA_STORE_MEMORY"]
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]
+65 -46
View File
@@ -8,13 +8,13 @@ import subprocess
import sys
import time
__all__ = [
"start_plasma_store", "start_plasma_manager", "DEFAULT_PLASMA_STORE_MEMORY"
]
__all__ = ["start_plasma_store", "start_plasma_manager",
"DEFAULT_PLASMA_STORE_MEMORY"]
PLASMA_WAIT_TIMEOUT = 2**30
PLASMA_WAIT_TIMEOUT = 2 ** 30
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
DEFAULT_PLASMA_STORE_MEMORY = 10**9
def random_name():
@@ -22,9 +22,12 @@ def random_name():
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None,
plasma_directory=None, huge_pages=False):
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None,
plasma_directory=None,
huge_pages=False):
"""Start a plasma store process.
Args:
@@ -48,8 +51,8 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
if huge_pages and not (sys.platform == "linux" or
sys.platform == "linux2"):
if huge_pages and not (sys.platform == "linux"
or sys.platform == "linux2"):
raise Exception("The huge_pages argument is only supported on "
"Linux.")
@@ -57,29 +60,33 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
raise Exception("If huge_pages is True, then the "
"plasma_directory argument must be provided.")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
plasma_store_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
command = [
plasma_store_executable, "-s", plasma_store_name, "-m",
str(plasma_store_memory)
]
if plasma_directory is not None:
command += ["-d", plasma_directory]
if huge_pages:
command += ["-h"]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@@ -91,10 +98,14 @@ def new_port():
return random.randint(10000, 65535)
def start_plasma_manager(store_name, redis_address,
node_ip_address="127.0.0.1", plasma_manager_port=None,
num_retries=20, use_valgrind=False,
run_profiler=False, stdout_file=None,
def start_plasma_manager(store_name,
redis_address,
node_ip_address="127.0.0.1",
plasma_manager_port=None,
num_retries=20,
use_valgrind=False,
run_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a plasma manager and return the ports it listens on.
@@ -133,27 +144,35 @@ def start_plasma_manager(store_name, redis_address,
while counter < num_retries:
if counter > 0:
print("Plasma manager failed to start, retrying now.")
command = [plasma_manager_executable,
"-s", store_name,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address,
]
command = [
plasma_manager_executable,
"-s",
store_name,
"-m",
plasma_manager_name,
"-h",
node_ip_address,
"-p",
str(plasma_manager_port),
"-r",
redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
process = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
command),
stdout=stdout_file, stderr=stderr_file)
process = subprocess.Popen(
(["valgrind", "--tool=callgrind"] + command),
stdout=stdout_file,
stderr=stderr_file)
else:
process = subprocess.Popen(command, stdout=stdout_file,
stderr=stderr_file)
process = subprocess.Popen(
command, stdout=stdout_file, stderr=stderr_file)
# This sleep is critical. If the plasma_manager fails to start because
# the port is already in use, then we need it to fail within 0.1
# seconds.
+143 -107
View File
@@ -16,8 +16,8 @@ import unittest
# The ray import must come before the pyarrow import because ray modifies the
# python path so that the right version of pyarrow is found.
import ray
from ray.plasma.utils import (random_object_id,
create_object_with_id, create_object)
from ray.plasma.utils import (random_object_id, create_object_with_id,
create_object)
from ray import services
import pyarrow as pa
import pyarrow.plasma as plasma
@@ -30,8 +30,12 @@ def random_name():
return str(random.randint(0, 99999999))
def assert_get_object_equal(unit_test, client1, client2, object_id,
memory_buffer=None, metadata=None):
def assert_get_object_equal(unit_test,
client1,
client2,
object_id,
memory_buffer=None,
metadata=None):
client1_buff = client1.get_buffers([object_id])[0]
client2_buff = client2.get_buffers([object_id])[0]
client1_metadata = client1.get_metadata([object_id])[0]
@@ -39,27 +43,33 @@ def assert_get_object_equal(unit_test, client1, client2, object_id,
unit_test.assertEqual(len(client1_buff), len(client2_buff))
unit_test.assertEqual(len(client1_metadata), len(client2_metadata))
# Check that the buffers from the two clients are the same.
assert_equal(np.frombuffer(client1_buff, dtype="uint8"),
np.frombuffer(client2_buff, dtype="uint8"))
assert_equal(
np.frombuffer(client1_buff, dtype="uint8"),
np.frombuffer(client2_buff, dtype="uint8"))
# Check that the metadata buffers from the two clients are the same.
assert_equal(np.frombuffer(client1_metadata, dtype="uint8"),
np.frombuffer(client2_metadata, dtype="uint8"))
assert_equal(
np.frombuffer(client1_metadata, dtype="uint8"),
np.frombuffer(client2_metadata, dtype="uint8"))
# If a reference buffer was provided, check that it is the same as well.
if memory_buffer is not None:
assert_equal(np.frombuffer(memory_buffer, dtype="uint8"),
np.frombuffer(client1_buff, dtype="uint8"))
assert_equal(
np.frombuffer(memory_buffer, dtype="uint8"),
np.frombuffer(client1_buff, dtype="uint8"))
# If reference metadata was provided, check that it is the same as well.
if metadata is not None:
assert_equal(np.frombuffer(metadata, dtype="uint8"),
np.frombuffer(client1_metadata, dtype="uint8"))
assert_equal(
np.frombuffer(metadata, dtype="uint8"),
np.frombuffer(client1_metadata, dtype="uint8"))
# Default value of start_plasma_store's plasma_store_memory argument, which is
# passed to the store executable's "-m" flag (presumably a size in bytes --
# TODO confirm against the plasma_store CLI).
DEFAULT_PLASMA_STORE_MEMORY = 10**9
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
use_valgrind=False,
use_profiler=False,
stdout_file=None,
stderr_file=None):
"""Start a plasma store process.
Args:
use_valgrind (bool): True if the plasma store should be started inside
@@ -78,21 +88,25 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(pa.__path__[0], "plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
command = [
plasma_store_executable, "-s", plasma_store_name, "-m",
str(plasma_store_memory)
]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
[
"valgrind", "--track-origins=yes", "--leak-check=full",
"--show-leak-kinds=all", "--leak-check-heuristics=stdstring",
"--error-exitcode=1"
] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
pid = subprocess.Popen(
["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file,
stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@@ -104,13 +118,10 @@ def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
class TestPlasmaManager(unittest.TestCase):
def setUp(self):
# Start two PlasmaStores.
store_name1, self.p2 = start_plasma_store(
use_valgrind=USE_VALGRIND)
store_name2, self.p3 = start_plasma_store(
use_valgrind=USE_VALGRIND)
store_name1, self.p2 = start_plasma_store(use_valgrind=USE_VALGRIND)
store_name2, self.p3 = start_plasma_store(use_valgrind=USE_VALGRIND)
# Start a Redis server.
redis_address, _ = services.start_redis("127.0.0.1")
# Start two PlasmaManagers.
@@ -152,9 +163,8 @@ class TestPlasmaManager(unittest.TestCase):
def test_fetch(self):
for _ in range(10):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
self.client1.fetch([object_id1])
self.assertEqual(self.client1.contains(object_id1), True)
self.assertEqual(self.client2.contains(object_id1), False)
@@ -164,18 +174,20 @@ class TestPlasmaManager(unittest.TestCase):
while not self.client2.contains(object_id1):
self.client2.fetch([object_id1])
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
# Test that we can call fetch on object IDs that don't exist yet.
object_id2 = random_object_id()
self.client1.fetch([object_id2])
self.assertEqual(self.client1.contains(object_id2), False)
memory_buffer2, metadata2 = create_object_with_id(self.client2,
object_id2,
2000, 2000)
memory_buffer2, metadata2 = create_object_with_id(
self.client2, object_id2, 2000, 2000)
# # Check that the object has been fetched.
# self.assertEqual(self.client1.contains(object_id2), True)
# Compare the two buffers.
@@ -190,68 +202,88 @@ class TestPlasmaManager(unittest.TestCase):
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
memory_buffer3, metadata3 = create_object_with_id(self.client1,
object_id3,
2000, 2000)
memory_buffer3, metadata3 = create_object_with_id(
self.client1, object_id3, 2000, 2000)
for _ in range(10):
self.client1.fetch([object_id3])
self.client2.fetch([object_id3])
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id3):
self.client2.fetch([object_id3])
assert_get_object_equal(self, self.client1, self.client2, object_id3,
memory_buffer=memory_buffer3,
metadata=metadata3)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id3,
memory_buffer=memory_buffer3,
metadata=metadata3)
def test_fetch_multiple(self):
for _ in range(20):
# Create two objects and a third fake one that doesn't exist.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
missing_object_id = random_object_id()
object_id2, memory_buffer2, metadata2 = create_object(self.client1,
2000,
2000)
object_id2, memory_buffer2, metadata2 = create_object(
self.client1, 2000, 2000)
object_ids = [object_id1, missing_object_id, object_id2]
# Fetch the objects from the other plasma store. The second object
# ID should timeout since it does not exist.
# TODO(rkn): Right now we must wait for the object table to be
# updated.
while ((not self.client2.contains(object_id1)) or
(not self.client2.contains(object_id2))):
while ((not self.client2.contains(object_id1))
or (not self.client2.contains(object_id2))):
self.client2.fetch(object_ids)
# Compare the buffers of the objects that do exist.
assert_get_object_equal(self, self.client1, self.client2,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(self, self.client1, self.client2,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Fetch in the other direction. The fake object still does not
# exist.
self.client1.fetch(object_ids)
assert_get_object_equal(self, self.client2, self.client1,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(self, self.client2, self.client1,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
# Check that we can call fetch with duplicated object IDs.
object_id3 = random_object_id()
self.client1.fetch([object_id3, object_id3])
object_id4, memory_buffer4, metadata4 = create_object(self.client1,
2000,
2000)
object_id4, memory_buffer4, metadata4 = create_object(
self.client1, 2000, 2000)
time.sleep(0.1)
# TODO(rkn): Right now we must wait for the object table to be updated.
while not self.client2.contains(object_id4):
self.client2.fetch([object_id3, object_id3, object_id4,
object_id4])
assert_get_object_equal(self, self.client2, self.client1, object_id4,
memory_buffer=memory_buffer4,
metadata=metadata4)
self.client2.fetch(
[object_id3, object_id3, object_id4, object_id4])
assert_get_object_equal(
self,
self.client2,
self.client1,
object_id4,
memory_buffer=memory_buffer4,
metadata=metadata4)
def test_wait(self):
# Test timeout.
@@ -263,8 +295,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id1 = random_object_id()
self.client1.create(obj_id1, 1000)
self.client1.seal(obj_id1)
ready, waiting = self.client1.wait([obj_id1], timeout=100,
num_returns=1)
ready, waiting = self.client1.wait(
[obj_id1], timeout=100, num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(waiting, [])
@@ -273,8 +305,8 @@ class TestPlasmaManager(unittest.TestCase):
obj_id2 = random_object_id()
self.client1.create(obj_id2, 1000)
# Don't seal.
ready, waiting = self.client1.wait([obj_id2, obj_id1], timeout=100,
num_returns=1)
ready, waiting = self.client1.wait(
[obj_id2, obj_id1], timeout=100, num_returns=1)
self.assertEqual(set(ready), set([obj_id1]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -287,8 +319,8 @@ class TestPlasmaManager(unittest.TestCase):
t = threading.Timer(0.1, finish)
t.start()
ready, waiting = self.client1.wait([obj_id3, obj_id2, obj_id1],
timeout=1000, num_returns=2)
ready, waiting = self.client1.wait(
[obj_id3, obj_id2, obj_id1], timeout=1000, num_returns=2)
self.assertEqual(set(ready), set([obj_id1, obj_id3]))
self.assertEqual(set(waiting), set([obj_id2]))
@@ -319,26 +351,26 @@ class TestPlasmaManager(unittest.TestCase):
waiting = object_ids
retrieved = []
for i in range(1, n + 1):
ready, waiting = self.client1.wait(waiting, timeout=1000,
num_returns=i)
ready, waiting = self.client1.wait(
waiting, timeout=1000, num_returns=i)
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client1.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
ready, waiting = self.client1.wait(
object_ids, timeout=1000, num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
# Try waiting for all of the object IDs on the second client.
waiting = object_ids
retrieved = []
for i in range(1, n + 1):
ready, waiting = self.client2.wait(waiting, timeout=1000,
num_returns=i)
ready, waiting = self.client2.wait(
waiting, timeout=1000, num_returns=i)
self.assertEqual(len(ready), i)
retrieved += ready
self.assertEqual(set(retrieved), set(object_ids))
ready, waiting = self.client2.wait(object_ids, timeout=1000,
num_returns=len(object_ids))
ready, waiting = self.client2.wait(
object_ids, timeout=1000, num_returns=len(object_ids))
self.assertEqual(set(ready), set(object_ids))
self.assertEqual(waiting, [])
@@ -363,9 +395,8 @@ class TestPlasmaManager(unittest.TestCase):
num_attempts = 100
for _ in range(100):
# Create an object.
object_id1, memory_buffer1, metadata1 = create_object(self.client1,
2000,
2000)
object_id1, memory_buffer1, metadata1 = create_object(
self.client1, 2000, 2000)
# Transfer the buffer to the other Plasma store. There is a
# race condition on the create and transfer of the object, so keep
# trying until the object appears on the second Plasma store.
@@ -379,9 +410,13 @@ class TestPlasmaManager(unittest.TestCase):
del buff
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id1, memory_buffer=memory_buffer1,
metadata=metadata1)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id1,
memory_buffer=memory_buffer1,
metadata=metadata1)
# # Transfer the buffer again.
# self.client1.transfer("127.0.0.1", self.port2, object_id1)
# # Compare the two buffers.
@@ -391,8 +426,8 @@ class TestPlasmaManager(unittest.TestCase):
# metadata=metadata1)
# Create an object.
object_id2, memory_buffer2, metadata2 = create_object(self.client2,
20000, 20000)
object_id2, memory_buffer2, metadata2 = create_object(
self.client2, 20000, 20000)
# Transfer the buffer to the other Plasma store. There is a
# race condition on the create and transfer of the object, so keep
# trying until the object appears on the second Plasma store.
@@ -406,9 +441,13 @@ class TestPlasmaManager(unittest.TestCase):
del buff
# Compare the two buffers.
assert_get_object_equal(self, self.client1, self.client2,
object_id2, memory_buffer=memory_buffer2,
metadata=metadata2)
assert_get_object_equal(
self,
self.client1,
self.client2,
object_id2,
memory_buffer=memory_buffer2,
metadata=metadata2)
def test_illegal_functionality(self):
# Create an object id string.
@@ -437,7 +476,6 @@ class TestPlasmaManager(unittest.TestCase):
class TestPlasmaManagerRecovery(unittest.TestCase):
def setUp(self):
# Start a Plasma store.
self.store_name, self.p2 = start_plasma_store(
@@ -446,9 +484,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
self.redis_address, _ = services.start_redis("127.0.0.1")
# Start a PlasmaManager.
manager_name, self.p3, self.port1 = ray.plasma.start_plasma_manager(
self.store_name,
self.redis_address,
use_valgrind=USE_VALGRIND)
self.store_name, self.redis_address, use_valgrind=USE_VALGRIND)
# Connect a PlasmaClient.
self.client = plasma.connect(self.store_name, manager_name, 64)
@@ -501,8 +537,8 @@ class TestPlasmaManagerRecovery(unittest.TestCase):
client2 = plasma.connect(self.store_name, manager_name, 64)
ready, waiting = [], object_ids
while True:
ready, waiting = client2.wait(object_ids, num_returns=num_objects,
timeout=0)
ready, waiting = client2.wait(
object_ids, num_returns=num_objects, timeout=0)
if len(ready) == len(object_ids):
break
+8 -6
View File
@@ -18,8 +18,8 @@ def generate_metadata(length):
metadata_buffer[0] = random.randint(0, 255)
metadata_buffer[-1] = random.randint(0, 255)
for _ in range(100):
metadata_buffer[random.randint(0, length - 1)] = (
random.randint(0, 255))
metadata_buffer[random.randint(0, length - 1)] = (random.randint(
0, 255))
return metadata_buffer
@@ -32,7 +32,10 @@ def write_to_data_buffer(buff, length):
array[random.randint(0, length - 1)] = random.randint(0, 255)
def create_object_with_id(client, object_id, data_size, metadata_size,
def create_object_with_id(client,
object_id,
data_size,
metadata_size,
seal=True):
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
@@ -44,7 +47,6 @@ def create_object_with_id(client, object_id, data_size, metadata_size,
def create_object(client, data_size, metadata_size, seal=True):
    """Create an object under a freshly generated random object ID.

    This is a thin wrapper: it draws an ID from random_object_id() and
    delegates the actual creation to create_object_with_id().

    Args:
        client: Client object forwarded to create_object_with_id (presumably
            a plasma client -- confirm against callers).
        data_size: Forwarded unchanged to create_object_with_id.
        metadata_size: Forwarded unchanged to create_object_with_id.
        seal: Forwarded as the ``seal`` keyword argument. Defaults to True.

    Returns:
        A tuple ``(object_id, memory_buffer, metadata)`` where the last two
        items are whatever create_object_with_id returned.
    """
    object_id = random_object_id()
    memory_buffer, metadata = create_object_with_id(
        client, object_id, data_size, metadata_size, seal=seal)
    return object_id, memory_buffer, metadata
-2
View File
@@ -1,10 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Ray constants used in the Python code."""
# Abort autoscaling if more than this number of errors are encountered. This
# is a safety feature to prevent e.g. runaway node launches.
AUTOSCALER_MAX_NUM_FAILURES = 5
+172 -81
View File
@@ -7,8 +7,8 @@ import json
import subprocess
import ray.services as services
from ray.autoscaler.commands import (
create_or_update_cluster, teardown_cluster, get_head_node_ip)
from ray.autoscaler.commands import (create_or_update_cluster,
teardown_cluster, get_head_node_ip)
def check_no_existing_redis_clients(node_ip_address, redis_client):
@@ -41,58 +41,116 @@ def cli():
@click.command()
@click.option("--node-ip-address", required=False, type=str,
help="the IP address of this node")
@click.option("--redis-address", required=False, type=str,
help="the address to use for connecting to Redis")
@click.option("--redis-port", required=False, type=str,
help="the port to use for starting Redis")
@click.option("--num-redis-shards", required=False, type=int,
help=("the number of additional Redis shards to use in "
"addition to the primary Redis shard"))
@click.option("--redis-max-clients", required=False, type=int,
help=("If provided, attempt to configure Redis with this "
"maximum number of clients."))
@click.option("--redis-shard-ports", required=False, type=str,
help="the port to use for the Redis shards other than the "
"primary Redis shard")
@click.option("--object-manager-port", required=False, type=int,
help="the port to use for starting the object manager")
@click.option("--object-store-memory", required=False, type=int,
help="the maximum amount of memory (in bytes) to allow the "
"object store to use")
@click.option("--num-workers", required=False, type=int,
help=("The initial number of workers to start on this node, "
"note that the local scheduler may start additional "
"workers. If you wish to control the total number of "
"concurent tasks, then use --resources instead and "
"specify the CPU field."))
@click.option("--num-cpus", required=False, type=int,
help="the number of CPUs on this node")
@click.option("--num-gpus", required=False, type=int,
help="the number of GPUs on this node")
@click.option("--resources", required=False, default="{}", type=str,
help="a JSON serialized dictionary mapping resource name to "
"resource quantity")
@click.option("--head", is_flag=True, default=False,
help="provide this argument for the head node")
@click.option("--no-ui", is_flag=True, default=False,
help="provide this argument if the UI should not be started")
@click.option("--block", is_flag=True, default=False,
help="provide this argument to block forever in this command")
@click.option("--plasma-directory", required=False, type=str,
help="object store directory for memory mapped files")
@click.option("--huge-pages", is_flag=True, default=False,
help="enable support for huge pages in the object store")
@click.option("--autoscaling-config", required=False, type=str,
help="the file that contains the autoscaling config")
@click.option("--use-raylet", is_flag=True, default=False,
help="use the raylet code path, this is not supported yet")
@click.option(
"--node-ip-address",
required=False,
type=str,
help="the IP address of this node")
@click.option(
"--redis-address",
required=False,
type=str,
help="the address to use for connecting to Redis")
@click.option(
"--redis-port",
required=False,
type=str,
help="the port to use for starting Redis")
@click.option(
"--num-redis-shards",
required=False,
type=int,
help=("the number of additional Redis shards to use in "
"addition to the primary Redis shard"))
@click.option(
"--redis-max-clients",
required=False,
type=int,
help=("If provided, attempt to configure Redis with this "
"maximum number of clients."))
@click.option(
"--redis-shard-ports",
required=False,
type=str,
help="the port to use for the Redis shards other than the "
"primary Redis shard")
@click.option(
"--object-manager-port",
required=False,
type=int,
help="the port to use for starting the object manager")
@click.option(
"--object-store-memory",
required=False,
type=int,
help="the maximum amount of memory (in bytes) to allow the "
"object store to use")
@click.option(
"--num-workers",
required=False,
type=int,
help=("The initial number of workers to start on this node, "
"note that the local scheduler may start additional "
"workers. If you wish to control the total number of "
"concurent tasks, then use --resources instead and "
"specify the CPU field."))
@click.option(
"--num-cpus",
required=False,
type=int,
help="the number of CPUs on this node")
@click.option(
"--num-gpus",
required=False,
type=int,
help="the number of GPUs on this node")
@click.option(
"--resources",
required=False,
default="{}",
type=str,
help="a JSON serialized dictionary mapping resource name to "
"resource quantity")
@click.option(
"--head",
is_flag=True,
default=False,
help="provide this argument for the head node")
@click.option(
"--no-ui",
is_flag=True,
default=False,
help="provide this argument if the UI should not be started")
@click.option(
"--block",
is_flag=True,
default=False,
help="provide this argument to block forever in this command")
@click.option(
"--plasma-directory",
required=False,
type=str,
help="object store directory for memory mapped files")
@click.option(
"--huge-pages",
is_flag=True,
default=False,
help="enable support for huge pages in the object store")
@click.option(
"--autoscaling-config",
required=False,
type=str,
help="the file that contains the autoscaling config")
@click.option(
"--use-raylet",
is_flag=True,
default=False,
help="use the raylet code path, this is not supported yet")
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, redis_shard_ports, object_manager_port,
object_store_memory, num_workers, num_cpus, num_gpus, resources,
head, no_ui, block, plasma_directory, huge_pages,
autoscaling_config, use_raylet):
head, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
use_raylet):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
node_ip_address = services.address_to_ip(node_ip_address)
@@ -245,33 +303,54 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
@click.command()
def stop():
    """Kill the Ray-related processes running on this machine.

    Every step shells out to ``killall`` or ``kill``. The exit status of each
    command is ignored, so a step that finds nothing to kill does not abort
    the remaining steps.
    """
    # Each command is handed to the shell as a single string. With
    # shell=True a plain string is the documented form; wrapping it in a
    # one-element list only happens to work on POSIX.
    shell_commands = (
        # Kill the scheduler, object store and raylet binaries by name.
        "killall global_scheduler plasma_store plasma_manager "
        "local_scheduler raylet",
        # Find the PID of the monitor process and kill it.
        "kill $(ps aux | grep monitor.py | grep -v grep | "
        "awk '{ print $2 }') 2> /dev/null",
        # Find the PID of the Redis process and kill it.
        "kill $(ps aux | grep redis-server | grep -v grep | "
        "awk '{ print $2 }') 2> /dev/null",
        # Find the PIDs of the worker processes and kill them.
        "kill -9 $(ps aux | grep default_worker.py | "
        "grep -v grep | awk '{ print $2 }') 2> /dev/null",
        # Find the PID of the Ray log monitor process and kill it.
        "kill $(ps aux | grep log_monitor.py | grep -v grep | "
        "awk '{ print $2 }') 2> /dev/null",
    )
    for shell_command in shell_commands:
        subprocess.call(shell_command, shell=True)

    # Find the PID of the jupyter process and kill it.
    try:
        from notebook.notebookapp import list_running_servers
        pids = [
            str(server["pid"]) for server in list_running_servers()
            if "/tmp/raylogs" in server["notebook_dir"]
        ]
        # Only invoke kill when a matching notebook server was found; an
        # argument-less "kill" would do nothing but print a usage error.
        if pids:
            subprocess.call(
                "kill {} 2> /dev/null".format(" ".join(pids)), shell=True)
    except ImportError:
        # The notebook package is optional; without it there is nothing to
        # look up.
        pass
@@ -279,29 +358,41 @@ def stop():
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
    "--no-restart",
    is_flag=True,
    default=False,
    help=("Whether to skip restarting Ray services during the update. "
          "This avoids interrupting running jobs."))
@click.option(
    "--min-workers",
    required=False,
    type=int,
    help="Override the configured min worker node count for the cluster.")
@click.option(
    "--max-workers",
    required=False,
    type=int,
    help="Override the configured max worker node count for the cluster.")
@click.option(
    "--yes",
    "-y",
    is_flag=True,
    default=False,
    help="Don't ask for confirmation.")
def create_or_update(cluster_config_file, min_workers, max_workers, no_restart,
                     yes):
    """CLI entry point that forwards its options to create_or_update_cluster.

    Args:
        cluster_config_file: Value of the required positional argument.
        min_workers: Value of --min-workers, or None when not given.
        max_workers: Value of --max-workers, or None when not given.
        no_restart: True when the --no-restart flag was passed.
        yes: True when --yes / -y was passed.
    """
    create_or_update_cluster(cluster_config_file, min_workers, max_workers,
                             no_restart, yes)
@click.command()
@click.argument("cluster_config_file", required=True, type=str)
@click.option(
    "--yes",
    "-y",
    is_flag=True,
    default=False,
    help="Don't ask for confirmation.")
def teardown(cluster_config_file, yes):
    """CLI entry point that forwards its options to teardown_cluster.

    Args:
        cluster_config_file: Value of the required positional argument.
        yes: True when --yes / -y was passed.
    """
    teardown_cluster(cluster_config_file, yes)
+237 -200
View File
@@ -41,16 +41,12 @@ PROCESS_TYPE_WEB_UI = "web_ui"
# important because it determines the order in which these processes will be
# terminated when Ray exits, and certain orders will cause errors to be logged
# to the screen.
# Maps each process type to the list of processes of that type that have been
# started (every list begins empty). The entry order is deliberate and must be
# kept exactly as written here.
all_processes = OrderedDict(
    (process_type, []) for process_type in (
        PROCESS_TYPE_MONITOR,
        PROCESS_TYPE_LOG_MONITOR,
        PROCESS_TYPE_WORKER,
        PROCESS_TYPE_RAYLET,
        PROCESS_TYPE_LOCAL_SCHEDULER,
        PROCESS_TYPE_PLASMA_MANAGER,
        PROCESS_TYPE_PLASMA_STORE,
        PROCESS_TYPE_GLOBAL_SCHEDULER,
        PROCESS_TYPE_REDIS_SERVER,
        PROCESS_TYPE_WEB_UI,
    ))
# True if processes are run in the valgrind profiler.
RUN_RAYLET_PROFILER = False
@@ -82,17 +78,15 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"core/src/ray/raylet/raylet_monitor")
# Absolute path of the raylet binary, resolved relative to the directory that
# contains this module.
RAYLET_EXECUTABLE = os.path.join(
    os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet")
# ObjectStoreAddress tuples contain all information necessary to connect to an
# object store. The fields are:
# - name: The socket name for the object store
# - manager_name: The socket name for the object store manager
# - manager_port: The Internet port that the object store manager listens on
# Record type with three positional fields, in this order: name, manager_name,
# manager_port.
ObjectStoreAddress = namedtuple(
    "ObjectStoreAddress", ("name", "manager_name", "manager_port"))
def address(ip_address, port):
@@ -133,8 +127,10 @@ def kill_process(p):
if p.poll() is not None:
# The process has already terminated.
return True
if any([RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER]):
if any([
RUN_RAYLET_PROFILER, RUN_LOCAL_SCHEDULER_PROFILER,
RUN_PLASMA_MANAGER_PROFILER, RUN_PLASMA_STORE_PROFILER
]):
# Give process signal to write profiler data.
os.kill(p.pid, signal.SIGINT)
# Wait for profiling data to be written.
@@ -260,8 +256,8 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
for log_file in log_files:
if log_file is not None:
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
# The name of the key storing the list of log filenames for this IP
# address.
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
@@ -304,8 +300,8 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
while counter < num_retries:
try:
# Run some random command and see if it worked.
print("Waiting for redis server at {}:{} to respond..."
.format(redis_ip_address, redis_port))
print("Waiting for redis server at {}:{} to respond...".format(
redis_ip_address, redis_port))
redis_client.client_list()
except redis.ConnectionError as e:
# Wait a little bit.
@@ -427,17 +423,19 @@ def start_credis(node_ip_address,
"""
components = ["credis_master", "credis_head", "credis_tail"]
modules = [CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE,
CREDIS_MEMBER_MODULE]
modules = [
CREDIS_MASTER_MODULE, CREDIS_MEMBER_MODULE, CREDIS_MEMBER_MODULE
]
ports = []
for i, component in enumerate(components):
stdout_file, stderr_file = new_log_files(
component, redirect_output)
stdout_file, stderr_file = new_log_files(component, redirect_output)
new_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
stdout_file=stdout_file, stderr_file=stderr_file,
node_ip_address=node_ip_address,
port=port,
stdout_file=stdout_file,
stderr_file=stderr_file,
cleanup=cleanup,
module=modules[i],
executable=CREDIS_EXECUTABLE)
@@ -456,8 +454,7 @@ def start_credis(node_ip_address,
# Register credis master in redis
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_client.set("credis_address", credis_address)
return credis_address
@@ -509,9 +506,11 @@ def start_redis(node_ip_address,
"number of Redis shards.")
assigned_port, _ = start_redis_instance(
node_ip_address=node_ip_address, port=port,
node_ip_address=node_ip_address,
port=port,
redis_max_clients=redis_max_clients,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
if port is not None:
assert assigned_port == port
@@ -540,7 +539,8 @@ def start_redis(node_ip_address,
node_ip_address=node_ip_address,
port=redis_shard_ports[i],
redis_max_clients=redis_max_clients,
stdout_file=redis_stdout_file, stderr_file=redis_stderr_file,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup)
if redis_shard_ports[i] is not None:
assert redis_shard_port == redis_shard_ports[i]
@@ -601,11 +601,13 @@ def start_redis_instance(node_ip_address="127.0.0.1",
while counter < num_retries:
if counter > 0:
print("Redis failed to start, retrying now.")
p = subprocess.Popen([executable,
"--port", str(port),
"--loglevel", "warning",
"--loadmodule", module],
stdout=stdout_file, stderr=stderr_file)
p = subprocess.Popen(
[
executable, "--port",
str(port), "--loglevel", "warning", "--loadmodule", module
],
stdout=stdout_file,
stderr=stderr_file)
time.sleep(0.1)
# Check if Redis successfully started (or at least if the executable
# did not exit within 0.1 seconds).
@@ -652,8 +654,8 @@ def start_redis_instance(node_ip_address="127.0.0.1",
# Increase the hard and soft limits for the redis client pubsub buffer to
# 128MB. This is a hack to make it less likely for pubsub messages to be
# dropped and for pubsub connections to therefore be killed.
cur_config = (redis_client.config_get("client-output-buffer-limit")
["client-output-buffer-limit"])
cur_config = (redis_client.config_get("client-output-buffer-limit")[
"client-output-buffer-limit"])
cur_config_list = cur_config.split()
assert len(cur_config_list) == 12
cur_config_list[8:] = ["pubsub", "134217728", "134217728", "60"]
@@ -662,13 +664,17 @@ def start_redis_instance(node_ip_address="127.0.0.1",
# Put a time stamp in Redis to indicate when it was started.
redis_client.set("redis_start_time", time.time())
# Record the log files in Redis.
record_log_files_in_redis(address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
return port, p
def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=cleanup):
def start_log_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=cleanup):
"""Start a log monitor process.
Args:
@@ -684,20 +690,25 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None,
Python process that imported services exits.
"""
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"log_monitor.py")
p = subprocess.Popen([sys.executable, "-u", log_monitor_filepath,
"--redis-address", redis_address,
"--node-ip-address", node_ip_address],
stdout=stdout_file, stderr=stderr_file)
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
p = subprocess.Popen(
[
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
],
stdout=stdout_file,
stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
def start_global_scheduler(redis_address, node_ip_address,
stdout_file=None, stderr_file=None, cleanup=True):
def start_global_scheduler(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Start a global scheduler process.
Args:
@@ -712,10 +723,11 @@ def start_global_scheduler(redis_address, node_ip_address,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
"""
p = global_scheduler.start_global_scheduler(redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
p = global_scheduler.start_global_scheduler(
redis_address,
node_ip_address,
stdout_file=stdout_file,
stderr_file=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
@@ -737,8 +749,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
"""
new_env = os.environ.copy()
notebook_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"WebUI.ipynb")
os.path.dirname(os.path.abspath(__file__)), "WebUI.ipynb")
# We copy the notebook file so that the original doesn't get modified by
# the user.
random_ui_id = random.randint(0, 100000)
@@ -759,19 +770,23 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
# We generate the token used for authentication ourselves to avoid
# querying the jupyter server.
token = binascii.hexlify(os.urandom(24)).decode("ascii")
command = ["jupyter", "notebook", "--no-browser",
"--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
"--NotebookApp.open_browser=False",
"--NotebookApp.token={}".format(token)]
command = [
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
"--NotebookApp.open_browser=False",
"--NotebookApp.token={}".format(token)
]
# If the user is root, add the --allow-root flag.
if os.geteuid() == 0:
command.append("--allow-root")
try:
ui_process = subprocess.Popen(command, env=new_env,
cwd=new_notebook_directory,
stdout=stdout_file, stderr=stderr_file)
ui_process = subprocess.Popen(
command,
env=new_env,
cwd=new_notebook_directory,
stdout=stdout_file,
stderr=stderr_file)
except Exception:
print("Failed to start the UI, you may need to run "
"'pip install jupyter'.")
@@ -836,8 +851,8 @@ def start_local_scheduler(redis_address,
# Check that the number of GPUs that the local scheduler wants doesn't
# exceed the amount allowed by CUDA_VISIBLE_DEVICES.
if ("GPU" in resources and gpu_ids is not None and
resources["GPU"] > len(gpu_ids)):
if ("GPU" in resources and gpu_ids is not None
and resources["GPU"] > len(gpu_ids)):
raise Exception("Attempting to start local scheduler with {} GPUs, "
"but CUDA_VISIBLE_DEVICES contains {}.".format(
resources["GPU"], gpu_ids))
@@ -906,21 +921,14 @@ def start_raylet(redis_address,
"--node-ip-address={} "
"--object-store-name={} "
"--raylet-name={} "
"--redis-address={}"
.format(sys.executable,
worker_path,
node_ip_address,
plasma_store_name,
raylet_name,
redis_address))
"--redis-address={}".format(
sys.executable, worker_path, node_ip_address,
plasma_store_name, raylet_name, redis_address))
command = [RAYLET_EXECUTABLE,
raylet_name,
plasma_store_name,
node_ip_address,
gcs_ip_address,
gcs_port,
start_worker_command]
command = [
RAYLET_EXECUTABLE, raylet_name, plasma_store_name, node_ip_address,
gcs_ip_address, gcs_port, start_worker_command
]
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
@@ -931,12 +939,18 @@ def start_raylet(redis_address,
return raylet_name
def start_objstore(node_ip_address, redis_address,
object_manager_port=None, store_stdout_file=None,
store_stderr_file=None, manager_stdout_file=None,
manager_stderr_file=None, objstore_memory=None,
cleanup=True, plasma_directory=None,
huge_pages=False, use_raylet=False):
def start_objstore(node_ip_address,
redis_address,
object_manager_port=None,
store_stdout_file=None,
store_stderr_file=None,
manager_stdout_file=None,
manager_stderr_file=None,
objstore_memory=None,
cleanup=True,
plasma_directory=None,
huge_pages=False,
use_raylet=False):
"""This method starts an object store process.
Args:
@@ -1013,24 +1027,24 @@ def start_objstore(node_ip_address, redis_address,
if object_manager_port is not None:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
node_ip_address=node_ip_address,
num_retries=1,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
plasma_store_name,
redis_address,
plasma_manager_port=object_manager_port,
node_ip_address=node_ip_address,
num_retries=1,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
assert plasma_manager_port == object_manager_port
else:
(plasma_manager_name, p2,
plasma_manager_port) = ray.plasma.start_plasma_manager(
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
plasma_store_name,
redis_address,
node_ip_address=node_ip_address,
run_profiler=RUN_PLASMA_MANAGER_PROFILER,
stdout_file=manager_stdout_file,
stderr_file=manager_stderr_file)
else:
plasma_manager_port = None
plasma_manager_name = None
@@ -1049,9 +1063,15 @@ def start_objstore(node_ip_address, redis_address,
plasma_manager_port)
def start_worker(node_ip_address, object_store_name, object_store_manager_name,
local_scheduler_name, redis_address, worker_path,
stdout_file=None, stderr_file=None, cleanup=True):
def start_worker(node_ip_address,
object_store_name,
object_store_manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""This method starts a worker process.
Args:
@@ -1072,14 +1092,14 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
Python process that imported services exits. This is True by
default.
"""
command = [sys.executable,
"-u",
worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--object-store-manager-name=" + object_store_manager_name,
"--local-scheduler-name=" + local_scheduler_name,
"--redis-address=" + str(redis_address)]
command = [
sys.executable, "-u", worker_path,
"--node-ip-address=" + node_ip_address,
"--object-store-name=" + object_store_name,
"--object-store-manager-name=" + object_store_manager_name,
"--local-scheduler-name=" + local_scheduler_name,
"--redis-address=" + str(redis_address)
]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_WORKER].append(p)
@@ -1087,8 +1107,12 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
[stdout_file, stderr_file])
def start_monitor(redis_address, node_ip_address, stdout_file=None,
stderr_file=None, cleanup=True, autoscaling_config=None):
def start_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True,
autoscaling_config=None):
"""Run a process to monitor the other processes.
Args:
@@ -1105,12 +1129,12 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
default.
autoscaling_config: path to autoscaling config file.
"""
monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"monitor.py")
command = [sys.executable,
"-u",
monitor_path,
"--redis-address=" + str(redis_address)]
monitor_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
command = [
sys.executable, "-u", monitor_path,
"--redis-address=" + str(redis_address)
]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
@@ -1120,8 +1144,10 @@ def start_monitor(redis_address, node_ip_address, stdout_file=None,
[stdout_file, stderr_file])
def start_raylet_monitor(redis_address, stdout_file=None,
stderr_file=None, cleanup=True):
def start_raylet_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
"""Run a process to monitor the other processes.
Args:
@@ -1136,9 +1162,7 @@ def start_raylet_monitor(redis_address, stdout_file=None,
default.
"""
gcs_ip_address, gcs_port = redis_address.split(":")
command = [RAYLET_MONITOR_EXECUTABLE,
gcs_ip_address,
gcs_port]
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_MONITOR].append(p)
@@ -1238,16 +1262,17 @@ def start_ray_processes(address_info=None,
workers_per_local_scheduler = []
for resource_dict in resources:
cpus = resource_dict.get("CPU")
workers_per_local_scheduler.append(cpus if cpus is not None
else psutil.cpu_count())
workers_per_local_scheduler.append(cpus if cpus is not None else
psutil.cpu_count())
if address_info is None:
address_info = {}
address_info["node_ip_address"] = node_ip_address
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
worker_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"workers/default_worker.py")
# Start Redis if there isn't already an instance running. TODO(rkn): We are
# suppressing the output of Redis because on Linux it prints a bunch of
@@ -1257,7 +1282,8 @@ def start_ray_processes(address_info=None,
redis_shards = address_info.get("redis_shards", [])
if redis_address is None:
redis_address, redis_shards = start_redis(
node_ip_address, port=redis_port,
node_ip_address,
port=redis_port,
redis_shard_ports=redis_shard_ports,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
@@ -1274,23 +1300,25 @@ def start_ray_processes(address_info=None,
# Start monitoring the processes.
monitor_stdout_file, monitor_stderr_file = new_log_files(
"monitor", redirect_output)
start_monitor(redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
start_monitor(
redis_address,
node_ip_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
if use_raylet:
start_raylet_monitor(redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
address_info["redis_shards"] = redis_shards
@@ -1299,21 +1327,23 @@ def start_ray_processes(address_info=None,
if include_log_monitor:
log_monitor_stdout_file, log_monitor_stderr_file = new_log_files(
"log_monitor", redirect_output=True)
start_log_monitor(redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
start_log_monitor(
redis_address,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
# Start the global scheduler, if necessary.
if include_global_scheduler and not use_raylet:
global_scheduler_stdout_file, global_scheduler_stderr_file = (
new_log_files("global_scheduler", redirect_output))
start_global_scheduler(redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
start_global_scheduler(
redis_address,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
# Initialize with existing services.
if "object_store_addresses" not in address_info:
@@ -1324,9 +1354,8 @@ def start_ray_processes(address_info=None,
local_scheduler_socket_names = address_info["local_scheduler_socket_names"]
# Get the ports to use for the object managers if any are provided.
object_manager_ports = (address_info["object_manager_ports"]
if "object_manager_ports" in address_info
else None)
object_manager_ports = (address_info["object_manager_ports"] if
"object_manager_ports" in address_info else None)
if not isinstance(object_manager_ports, list):
object_manager_ports = num_local_schedulers * [object_manager_ports]
assert len(object_manager_ports) == num_local_schedulers
@@ -1347,7 +1376,8 @@ def start_ray_processes(address_info=None,
manager_stdout_file=plasma_manager_stdout_file,
manager_stderr_file=plasma_manager_stderr_file,
objstore_memory=object_store_memory,
cleanup=cleanup, plasma_directory=plasma_directory,
cleanup=cleanup,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet)
object_store_addresses.append(object_store_address)
@@ -1355,8 +1385,8 @@ def start_ray_processes(address_info=None,
# Start any local schedulers that do not yet exist.
if not use_raylet:
for i in range(len(local_scheduler_socket_names),
num_local_schedulers):
for i in range(
len(local_scheduler_socket_names), num_local_schedulers):
# Connect the local scheduler to the object store at the same
# index.
object_store_address = object_store_addresses[i]
@@ -1374,8 +1404,9 @@ def start_ray_processes(address_info=None,
# redirect the worker output, then we cannot redirect the local
# scheduler output.
local_scheduler_stdout_file, local_scheduler_stderr_file = (
new_log_files("local_scheduler_{}".format(i),
redirect_output=redirect_worker_output))
new_log_files(
"local_scheduler_{}".format(i),
redirect_output=redirect_worker_output))
local_scheduler_name = start_local_scheduler(
redis_address,
node_ip_address,
@@ -1398,17 +1429,18 @@ def start_ray_processes(address_info=None,
else:
# Start the raylet. TODO(rkn): Modify this to allow starting
# multiple raylets on the same machine.
raylet_stdout_file, raylet_stderr_file = (
new_log_files("raylet_{}".format(i),
redirect_output=redirect_output))
address_info["raylet_socket_names"] = [start_raylet(
redis_address,
node_ip_address,
object_store_addresses[i].name,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=cleanup)]
raylet_stdout_file, raylet_stderr_file = (new_log_files(
"raylet_{}".format(i), redirect_output=redirect_output))
address_info["raylet_socket_names"] = [
start_raylet(
redis_address,
node_ip_address,
object_store_addresses[i].name,
worker_path,
stdout_file=None,
stderr_file=None,
cleanup=cleanup)
]
if not use_raylet:
# Start any workers that the local scheduler has not already started.
@@ -1419,28 +1451,30 @@ def start_ray_processes(address_info=None,
for j in range(num_local_scheduler_workers):
worker_stdout_file, worker_stderr_file = new_log_files(
"worker_{}_{}".format(i, j), redirect_output)
start_worker(node_ip_address,
object_store_address.name,
object_store_address.manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
start_worker(
node_ip_address,
object_store_address.name,
object_store_address.manager_name,
local_scheduler_name,
redis_address,
worker_path,
stdout_file=worker_stdout_file,
stderr_file=worker_stderr_file,
cleanup=cleanup)
workers_per_local_scheduler[i] -= 1
# Make sure that we've started all the workers.
assert(sum(workers_per_local_scheduler) == 0)
assert (sum(workers_per_local_scheduler) == 0)
# Try to start the web UI.
if include_webui:
ui_stdout_file, ui_stderr_file = new_log_files(
"webui", redirect_output=True)
address_info["webui_url"] = start_ui(redis_address,
stdout_file=ui_stdout_file,
stderr_file=ui_stderr_file,
cleanup=cleanup)
address_info["webui_url"] = start_ui(
redis_address,
stdout_file=ui_stdout_file,
stderr_file=ui_stderr_file,
cleanup=cleanup)
else:
address_info["webui_url"] = ""
# Return the addresses of the relevant processes.
@@ -1500,21 +1534,24 @@ def start_ray_node(node_ip_address,
A dictionary of the address information for the processes that were
started.
"""
address_info = {"redis_address": redis_address,
"object_manager_ports": object_manager_ports}
return start_ray_processes(address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
redirect_worker_output=redirect_worker_output,
redirect_output=redirect_output,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages)
address_info = {
"redis_address": redis_address,
"object_manager_ports": object_manager_ports
}
return start_ray_processes(
address_info=address_info,
node_ip_address=node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
redirect_worker_output=redirect_worker_output,
redirect_output=redirect_output,
resources=resources,
plasma_directory=plasma_directory,
huge_pages=huge_pages)
def start_ray_head(address_info=None,
+19 -17
View File
@@ -7,11 +7,10 @@ import funcsigs
from ray.utils import is_cython
FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
"arg_defaults",
"arg_is_positionals",
"keyword_names",
"function_name"])
FunctionSignature = namedtuple("FunctionSignature", [
"arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
"function_name"
])
"""This class is used to represent a function signature.
Attributes:
@@ -49,13 +48,16 @@ def get_signature_params(func):
# The first condition for Cython functions, the latter for Cython instance
# methods
if is_cython(func):
attrs = ["__code__", "__annotations__",
"__defaults__", "__kwdefaults__"]
attrs = [
"__code__", "__annotations__", "__defaults__", "__kwdefaults__"
]
if all([hasattr(func, attr) for attr in attrs]):
original_func = func
def func(): return
def func():
return
for attr in attrs:
setattr(func, attr, getattr(original_func, attr))
else:
@@ -130,8 +132,8 @@ def extract_signature(func, ignore_first=False):
if ignore_first:
if len(sig_params) == 0:
raise Exception("Methods must take a 'self' argument, but the "
"method '{}' does not have one."
.format(func.__name__))
"method '{}' does not have one.".format(
func.__name__))
sig_params = sig_params[1:]
# Extract the names of the keyword arguments.
@@ -183,8 +185,8 @@ def extend_args(function_signature, args, kwargs):
for keyword_name in kwargs:
if keyword_name not in keyword_names:
raise Exception("The name '{}' is not a valid keyword argument "
"for the function '{}'."
.format(keyword_name, function_name))
"for the function '{}'.".format(
keyword_name, function_name))
# Fill in the remaining arguments.
zipped_info = list(zip(arg_names, arg_defaults,
@@ -201,12 +203,12 @@ def extend_args(function_signature, args, kwargs):
# can be omitted.
if not is_positional:
raise Exception("No value was provided for the argument "
"'{}' for the function '{}'."
.format(keyword_name, function_name))
"'{}' for the function '{}'.".format(
keyword_name, function_name))
too_many_arguments = (len(args) > len(arg_names) and
(len(arg_is_positionals) == 0 or
not arg_is_positionals[-1]))
too_many_arguments = (len(args) > len(arg_names)
and (len(arg_is_positionals) == 0
or not arg_is_positionals[-1]))
if too_many_arguments:
raise Exception("Too many arguments were passed to the function '{}'"
.format(function_name))
+9
View File
@@ -13,6 +13,7 @@ import numpy as np
def handle_int(a, b):
    """Increment both arguments by one.

    Args:
        a: First value.
        b: Second value.

    Returns:
        A 2-tuple ``(a + 1, b + 1)``.
    """
    a_plus_one = a + 1
    b_plus_one = b + 1
    return a_plus_one, b_plus_one
# Test timing
@@ -25,6 +26,7 @@ def empty_function():
def trivial_function():
    """Return the constant 1."""
    return 1
# Test keyword arguments
@@ -42,6 +44,7 @@ def keyword_fct2(a="hello", b="world"):
def keyword_fct3(a, b, c="hello", d="world"):
    """Format the four arguments as one space-separated string.

    Args:
        a: First value.
        b: Second value.
        c: Third value. Defaults to "hello".
        d: Fourth value. Defaults to "world".

    Returns:
        The four values formatted with ``str.format`` and joined by single
        spaces, in the order a, b, c, d.
    """
    template = "{} {} {} {}"
    return template.format(a, b, c, d)
# Test variable numbers of arguments
@@ -56,17 +59,21 @@ def varargs_fct2(a, *b):
# Record whether applying ray.remote to a function that takes **kwargs raises.
# kwargs_exception_thrown ends up True if the decoration raised an Exception
# and False if the definition completed normally.
try:

    @ray.remote
    def kwargs_throw_exception(**c):
        return ()

    kwargs_exception_thrown = False
except Exception:
    kwargs_exception_thrown = True
# Record whether applying ray.remote to a function that has both a defaulted
# parameter (b="hi") and *args raises. varargs_and_kwargs_exception_thrown
# ends up True if the decoration raised an Exception and False otherwise.
try:

    @ray.remote
    def varargs_and_kwargs_throw_exception(a, b="hi", *c):
        return "{} {} {}".format(a, b, c)

    varargs_and_kwargs_exception_thrown = False
except Exception:
    varargs_and_kwargs_exception_thrown = True
@@ -88,6 +95,7 @@ def throw_exception_fct2():
def throw_exception_fct3(x):
    """Unconditionally raise an Exception; the argument ``x`` is not used.

    Args:
        x: Ignored.

    Raises:
        Exception: Always, with the message
            "Test function 3 intentionally failed."
    """
    message = "Test function 3 intentionally failed."
    raise Exception(message)
# test Python mode
@@ -101,6 +109,7 @@ def python_mode_g(x):
x[0] = 1
return x
# test no return values
+2 -2
View File
@@ -48,8 +48,8 @@ def _wait_for_nodes_to_join(num_nodes, timeout=20):
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes,
num_nodes))
"expecting {} nodes.".format(
num_ready_nodes, num_nodes))
time.sleep(0.1)
# If we get here then we timed out.
+2 -9
View File
@@ -9,14 +9,7 @@ from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
from ray.tune.variant_generator import grid_search
__all__ = [
"Trainable",
"TrainingResult",
"TuneError",
"grid_search",
"register_env",
"register_trainable",
"run_experiments",
"Experiment"
"Trainable", "TrainingResult", "TuneError", "grid_search", "register_env",
"register_trainable", "run_experiments", "Experiment"
]
+22 -20
View File
@@ -35,10 +35,13 @@ class AsyncHyperBandScheduler(FIFOScheduler):
halving rate, specified by the reduction factor.
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=100,
grace_period=10, reduction_factor=3, brackets=3):
def __init__(self,
time_attr='training_iteration',
reward_attr='episode_reward_mean',
max_t=100,
grace_period=10,
reduction_factor=3,
brackets=3):
assert max_t > 0, "Max (time_attr) not valid!"
assert max_t >= grace_period, "grace_period must be <= max_t!"
assert grace_period > 0, "grace_period must be positive!"
@@ -51,8 +54,10 @@ class AsyncHyperBandScheduler(FIFOScheduler):
self._trial_info = {} # Stores Trial -> Bracket
# Tracks state for new trial add
self._brackets = [_Bracket(
grace_period, max_t, reduction_factor, s) for s in range(brackets)]
self._brackets = [
_Bracket(grace_period, max_t, reduction_factor, s)
for s in range(brackets)
]
self._counter = 0 # for
self._num_stopped = 0
self._reward_attr = reward_attr
@@ -60,7 +65,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
def on_trial_add(self, trial_runner, trial):
sizes = np.array([len(b._rungs) for b in self._brackets])
probs = np.e ** (sizes - sizes.max())
probs = np.e**(sizes - sizes.max())
normalized = probs / probs.sum()
idx = np.random.choice(len(self._brackets), p=normalized)
self._trial_info[trial.trial_id] = self._brackets[idx]
@@ -71,28 +76,23 @@ class AsyncHyperBandScheduler(FIFOScheduler):
action = TrialScheduler.STOP
else:
bracket = self._trial_info[trial.trial_id]
action = bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
action = bracket.on_result(trial, getattr(result, self._time_attr),
getattr(result, self._reward_attr))
if action == TrialScheduler.STOP:
self._num_stopped += 1
return action
def on_trial_complete(self, trial_runner, trial, result):
bracket = self._trial_info[trial.trial_id]
bracket.on_result(
trial,
getattr(result, self._time_attr),
getattr(result, self._reward_attr))
bracket.on_result(trial, getattr(result, self._time_attr),
getattr(result, self._reward_attr))
del self._trial_info[trial.trial_id]
def on_trial_remove(self, trial_runner, trial):
    """Drop the bookkeeping entry kept for ``trial``.

    Removes the ``trial.trial_id`` key from ``self._trial_info``.

    Args:
        trial_runner: Not used by this method.
        trial: Object whose ``trial_id`` identifies the entry to remove.

    Raises:
        KeyError: If ``self._trial_info`` has no entry for
            ``trial.trial_id``.
    """
    self._trial_info.pop(trial.trial_id)
def debug_string(self):
out = "Using AsyncHyperBand: num_stopped={}".format(
self._num_stopped)
out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
return out
@@ -111,6 +111,7 @@ class _Bracket():
>>> b.on_result(trial3, 1, 1) # STOP
>>> b.cutoff(b._rungs[0][1]) == 2.0
"""
def __init__(self, min_t, max_t, reduction_factor, s):
self.rf = reduction_factor
MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
@@ -140,9 +141,10 @@ class _Bracket():
return action
def debug_str(self):
iters = " | ".join(
["Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
for milestone, recorded in self._rungs])
iters = " | ".join([
"Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
for milestone, recorded in self._rungs
])
return "Bracket: " + iters
+46 -21
View File
@@ -2,7 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
@@ -24,8 +23,8 @@ def json_to_resources(data):
"Unknown resource type {}, must be one of {}".format(
k, Resources._fields))
return Resources(
data.get("cpu", 1), data.get("gpu", 0),
data.get("extra_cpu", 0), data.get("extra_gpu", 0))
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
data.get("extra_gpu", 0))
def resources_to_json(resources):
@@ -50,59 +49,85 @@ def make_parser(**kwargs):
# Note: keep this in sync with rllib/train.py
parser.add_argument(
"--run", default=None, type=str,
"--run",
default=None,
type=str,
help="The algorithm or model to train. This may refer to the name "
"of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
"user-defined trainable function or class registered in the "
"tune registry.")
parser.add_argument(
"--stop", default="{}", type=json.loads,
"--stop",
default="{}",
type=json.loads,
help="The stopping criteria, specified in JSON. The keys may be any "
"field in TrainingResult, e.g. "
"'{\"time_total_s\": 600, \"timesteps_total\": 100000}' to stop "
"after 600 seconds or 100k timesteps, whichever is reached first.")
parser.add_argument(
"--config", default="{}", type=json.loads,
"--config",
default="{}",
type=json.loads,
help="Algorithm-specific configuration (e.g. env, hyperparams), "
"specified in JSON.")
parser.add_argument(
"--resources", help="Deprecated, use --trial-resources.",
type=lambda v: _tune_error(
"The `resources` argument is no longer supported. "
"Use `trial_resources` or --trial-resources instead."))
"--resources",
help="Deprecated, use --trial-resources.",
type=lambda v: _tune_error("The `resources` argument is no longer "
"supported. Use `trial_resources` or "
"--trial-resources instead."))
parser.add_argument(
"--trial-resources", default='{"cpu": 1}', type=json_to_resources,
"--trial-resources",
default='{"cpu": 1}',
type=json_to_resources,
help="Machine resources to allocate per trial, e.g. "
"'{\"cpu\": 64, \"gpu\": 8}'. Note that GPUs will not be assigned "
"unless you specify them here.")
parser.add_argument(
"--repeat", default=1, type=int,
"--repeat",
default=1,
type=int,
help="Number of times to repeat each trial.")
parser.add_argument(
"--local-dir", default=DEFAULT_RESULTS_DIR, type=str,
"--local-dir",
default=DEFAULT_RESULTS_DIR,
type=str,
help="Local dir to save training results to. Defaults to '{}'.".format(
DEFAULT_RESULTS_DIR))
parser.add_argument(
"--upload-dir", default="", type=str,
"--upload-dir",
default="",
type=str,
help="Optional URI to sync training results to (e.g. s3://bucket).")
parser.add_argument(
"--checkpoint-freq", default=0, type=int,
"--checkpoint-freq",
default=0,
type=int,
help="How many training iterations between checkpoints. "
"A value of 0 (default) disables checkpointing.")
parser.add_argument(
"--max-failures", default=3, type=int,
"--max-failures",
default=3,
type=int,
help="Try to recover a trial from its last checkpoint at least this "
"many times. Only applies if checkpointing is enabled.")
parser.add_argument(
"--scheduler", default="FIFO", type=str,
"--scheduler",
default="FIFO",
type=str,
help="FIFO (default), MedianStopping, AsyncHyperBand,"
"HyperBand, or HyperOpt.")
"HyperBand, or HyperOpt.")
parser.add_argument(
"--scheduler-config", default="{}", type=json.loads,
"--scheduler-config",
default="{}",
type=json.loads,
help="Config options to pass to the scheduler.")
# Note: this currently only makes sense when running a single trial
parser.add_argument("--restore", default=None, type=str,
help="If specified, restore from this checkpoint.")
parser.add_argument(
"--restore",
default=None,
type=str,
help="If specified, restore from this checkpoint.")
return parser
@@ -60,18 +60,27 @@ if __name__ == "__main__":
# `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
ahb = AsyncHyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
grace_period=5, max_t=100)
time_attr="timesteps_total",
reward_attr="episode_reward_mean",
grace_period=5,
max_t=100)
run_experiments({
"asynchyperband_test": {
"run": "my_class",
"stop": {"training_iteration": 1 if args.smoke_test else 99999},
"repeat": 20,
"trial_resources": {"cpu": 1, "gpu": 0},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
}, scheduler=ahb)
run_experiments(
{
"asynchyperband_test": {
"run": "my_class",
"stop": {
"training_iteration": 1 if args.smoke_test else 99999
},
"repeat": 20,
"trial_resources": {
"cpu": 1,
"gpu": 0
},
"config": {
"width": lambda spec: 10 + int(90 * random.random()),
"height": lambda spec: int(100 * random.random()),
},
}
},
scheduler=ahb)
@@ -59,7 +59,8 @@ if __name__ == "__main__":
# Hyperband early stopping, configured with `episode_reward_mean` as the
# objective and `timesteps_total` as the time unit.
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="episode_reward_mean",
time_attr="timesteps_total",
reward_attr="episode_reward_mean",
max_t=100)
exp = Experiment(
+11 -5
View File
@@ -12,8 +12,8 @@ def easy_objective(config, reporter):
time.sleep(0.2)
reporter(
timesteps_total=1,
episode_reward_mean=-((config["height"]-14) ** 2
+ abs(config["width"]-3)))
episode_reward_mean=-(
(config["height"] - 14)**2 + abs(config["width"] - 3)))
time.sleep(0.2)
@@ -34,12 +34,18 @@ if __name__ == '__main__':
'height': hp.uniform('height', -100, 100),
}
config = {"my_exp": {
config = {
"my_exp": {
"run": "exp",
"repeat": 5 if args.smoke_test else 1000,
"stop": {"training_iteration": 1},
"stop": {
"training_iteration": 1
},
"config": {
"space": space}}}
"space": space
}
}
}
hpo_sched = HyperOptScheduler()
run_experiments(config, verbose=False, scheduler=hpo_sched)
+27 -15
View File
@@ -42,8 +42,11 @@ class MyTrainableClass(Trainable):
def _save(self, checkpoint_dir):
path = os.path.join(checkpoint_dir, "checkpoint")
with open(path, "w") as f:
f.write(json.dumps(
{"timestep": self.timestep, "value": self.current_value}))
f.write(
json.dumps({
"timestep": self.timestep,
"value": self.current_value
}))
return path
def _restore(self, checkpoint_path):
@@ -63,7 +66,8 @@ if __name__ == "__main__":
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration", reward_attr="episode_reward_mean",
time_attr="training_iteration",
reward_attr="episode_reward_mean",
perturbation_interval=10,
hyperparam_mutations={
# Allow for scaling-based perturbations, with a uniform backing
@@ -74,15 +78,23 @@ if __name__ == "__main__":
})
# Try to find the best factor 1 and factor 2
run_experiments({
"pbt_test": {
"run": "my_class",
"stop": {"training_iteration": 2 if args.smoke_test else 99999},
"repeat": 10,
"trial_resources": {"cpu": 1, "gpu": 0},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
}, scheduler=pbt, verbose=False)
run_experiments(
{
"pbt_test": {
"run": "my_class",
"stop": {
"training_iteration": 2 if args.smoke_test else 99999
},
"repeat": 10,
"trial_resources": {
"cpu": 1,
"gpu": 0
},
"config": {
"factor_1": 4.0,
"factor_2": 1.0,
},
}
},
scheduler=pbt,
verbose=False)
+36 -22
View File
@@ -1,5 +1,4 @@
#!/usr/bin/env python
"""Example of using PBT with RLlib.
Note that this requires a cluster with at least 8 GPUs in order for all trials
@@ -30,7 +29,8 @@ if __name__ == "__main__":
return config
pbt = PopulationBasedTraining(
time_attr="time_total_s", reward_attr="episode_reward_mean",
time_attr="time_total_s",
reward_attr="episode_reward_mean",
perturbation_interval=120,
resample_probability=0.25,
# Specifies the mutations of these hyperparams
@@ -45,26 +45,40 @@ if __name__ == "__main__":
custom_explore_fn=explore)
ray.init()
run_experiments({
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"trial_resources": {"cpu": 4, "gpu": 1},
"config": {
"kl_coeff": 1.0,
"num_workers": 8,
"devices": ["/gpu:0"],
"model": {"free_log_std": True},
# These params are tuned from a fixed starting value.
"lambda": 0.95,
"clip_param": 0.2,
"sgd_stepsize": 1e-4,
# These params start off randomly drawn from a set.
"num_sgd_iter": lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize": lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
run_experiments(
{
"pbt_humanoid_test": {
"run": "PPO",
"env": "Humanoid-v1",
"repeat": 8,
"trial_resources": {
"cpu": 4,
"gpu": 1
},
"config": {
"kl_coeff":
1.0,
"num_workers":
8,
"devices": ["/gpu:0"],
"model": {
"free_log_std": True
},
# These params are tuned from a fixed starting value.
"lambda":
0.95,
"clip_param":
0.2,
"sgd_stepsize":
1e-4,
# These params start off randomly drawn from a set.
"num_sgd_iter":
lambda spec: random.choice([10, 20, 30]),
"sgd_batchsize":
lambda spec: random.choice([128, 512, 2048]),
"timesteps_per_batch":
lambda spec: random.choice([10000, 20000, 40000])
},
},
},
}, scheduler=pbt)
scheduler=pbt)
@@ -29,12 +29,10 @@ from ray.tune import Trainable
from ray.tune import TrainingResult
from ray.tune.pbt import PopulationBasedTraining
num_classes = 10
class Cifar10Model(Trainable):
def _read_data(self):
# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
@@ -54,27 +52,51 @@ class Cifar10Model(Trainable):
x = Input(shape=(32, 32, 3))
y = x
y = Convolution2D(
filters=64, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=64,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=64, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=64,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Convolution2D(
filters=128, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=128,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=128, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=128,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Convolution2D(
filters=256, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=256,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = Convolution2D(
filters=256, kernel_size=3, strides=1, padding="same",
activation="relu", kernel_initializer="he_normal")(y)
filters=256,
kernel_size=3,
strides=1,
padding="same",
activation="relu",
kernel_initializer="he_normal")(y)
y = MaxPooling2D(pool_size=2, strides=2, padding="same")(y)
y = Flatten()(y)
@@ -91,9 +113,10 @@ class Cifar10Model(Trainable):
model = self._build_model(x_train.shape[1:])
opt = tf.keras.optimizers.Adadelta()
model.compile(loss="categorical_crossentropy",
optimizer=opt,
metrics=["accuracy"])
model.compile(
loss="categorical_crossentropy",
optimizer=opt,
metrics=["accuracy"])
self.model = model
def _train(self):
@@ -134,8 +157,7 @@ class Cifar10Model(Trainable):
# loss, accuracy
_, accuracy = self.model.evaluate(x_test, y_test, verbose=0)
return TrainingResult(timesteps_this_iter=10,
mean_accuracy=accuracy)
return TrainingResult(timesteps_this_iter=10, mean_accuracy=accuracy)
def _save(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
@@ -154,15 +176,17 @@ class Cifar10Model(Trainable):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--smoke-test",
action="store_true",
help="Finish quickly for testing")
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
register_trainable("train_cifar10", Cifar10Model)
train_spec = {
"run": "train_cifar10",
"trial_resources": {"cpu": 1, "gpu": 1},
"trial_resources": {
"cpu": 1,
"gpu": 1
},
"stop": {
"mean_accuracy": 0.80,
"timesteps_total": 300,
@@ -170,7 +194,7 @@ if __name__ == "__main__":
"config": {
"epochs": 1,
"batch_size": 64,
"lr": grid_search([10 ** -4, 10 ** -5]),
"lr": grid_search([10**-4, 10**-5]),
"decay": lambda spec: spec.config.lr / 100.0,
"dropout": grid_search([0.25, 0.5]),
},
@@ -178,17 +202,17 @@ if __name__ == "__main__":
}
if args.smoke_test:
train_spec["config"]["lr"] = 10 ** -4
train_spec["config"]["lr"] = 10**-4
train_spec["config"]["dropout"] = 0.5
ray.init()
pbt = PopulationBasedTraining(
time_attr="timesteps_total", reward_attr="mean_accuracy",
time_attr="timesteps_total",
reward_attr="mean_accuracy",
perturbation_interval=10,
hyperparam_mutations={
"dropout": lambda _: np.random.uniform(0, 1),
})
run_experiments({"pbt_cifar10": train_spec},
scheduler=pbt)
run_experiments({"pbt_cifar10": train_spec}, scheduler=pbt)
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
@@ -42,7 +41,7 @@ import tensorflow as tf
FLAGS = None
status_reporter = None # used to report training status back to Ray
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def deepnn(x):
@@ -90,7 +89,7 @@ def deepnn(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@@ -173,7 +172,10 @@ def main(_):
batch = mnist.train.next_batch(50)
if i % 10 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0})
x: batch[0],
y_: batch[1],
keep_prob: 1.0
})
# !!! Report status to ray.tune !!!
if status_reporter:
@@ -181,11 +183,17 @@ def main(_):
timesteps_total=i, mean_accuracy=train_accuracy)
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
train_step.run(feed_dict={
x: batch[0],
y_: batch[1],
keep_prob: 0.5
})
print('test accuracy %g' % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
x: mnist.test.images,
y_: mnist.test.labels,
keep_prob: 1.0
}))
# !!! Entrypoint for ray.tune !!!
@@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
activation_fn = getattr(tf.nn, config['activation'])
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
'--data_dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@@ -213,8 +223,8 @@ if __name__ == '__main__':
'run': 'train_mnist',
'repeat': 10,
'stop': {
'mean_accuracy': 0.99,
'timesteps_total': 600,
'mean_accuracy': 0.99,
'timesteps_total': 600,
},
'config': {
'activation': grid_search(['relu', 'elu', 'tanh']),
@@ -228,8 +238,12 @@ if __name__ == '__main__':
ray.init()
from ray.tune.async_hyperband import AsyncHyperBandScheduler
run_experiments({'tune_mnist_test': mnist_spec},
scheduler=AsyncHyperBandScheduler(
time_attr="timesteps_total",
reward_attr="mean_accuracy",
max_t=600,))
run_experiments(
{
'tune_mnist_test': mnist_spec
},
scheduler=AsyncHyperBandScheduler(
time_attr="timesteps_total",
reward_attr="mean_accuracy",
max_t=600,
))
+20 -10
View File
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
@@ -42,7 +41,7 @@ import tensorflow as tf
FLAGS = None
status_reporter = None # used to report training status back to Ray
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def deepnn(x):
@@ -90,7 +89,7 @@ def deepnn(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@@ -173,7 +172,10 @@ def main(_):
batch = mnist.train.next_batch(50)
if i % 10 == 0:
train_accuracy = accuracy.eval(feed_dict={
x: batch[0], y_: batch[1], keep_prob: 1.0})
x: batch[0],
y_: batch[1],
keep_prob: 1.0
})
# !!! Report status to ray.tune !!!
if status_reporter:
@@ -181,11 +183,17 @@ def main(_):
timesteps_total=i, mean_accuracy=train_accuracy)
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(
feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
train_step.run(feed_dict={
x: batch[0],
y_: batch[1],
keep_prob: 0.5
})
print('test accuracy %g' % accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
x: mnist.test.images,
y_: mnist.test.labels,
keep_prob: 1.0
}))
# !!! Entrypoint for ray.tune !!!
@@ -195,7 +203,9 @@ def train(config={'activation': 'relu'}, reporter=None):
activation_fn = getattr(tf.nn, config['activation'])
parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
'--data_dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@@ -212,8 +222,8 @@ if __name__ == '__main__':
mnist_spec = {
'run': 'train_mnist',
'stop': {
'mean_accuracy': 0.99,
'time_total_s': 600,
'mean_accuracy': 0.99,
'time_total_s': 600,
},
'config': {
'activation': grid_search(['relu', 'elu', 'tanh']),
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
@@ -39,7 +38,7 @@ from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
activation_fn = None # e.g. tf.nn.relu
activation_fn = None # e.g. tf.nn.relu
def setupCNN(x):
@@ -85,7 +84,7 @@ def setupCNN(x):
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = activation_fn(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout - controls the complexity of the model, prevents co-adaptation of
@@ -182,14 +181,18 @@ class TrainMNIST(Trainable):
self.sess.run(
self.train_step,
feed_dict={
self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5
self.x: batch[0],
self.y_: batch[1],
self.keep_prob: 0.5
})
batch = self.mnist.train.next_batch(50)
train_accuracy = self.sess.run(
self.accuracy,
feed_dict={
self.x: batch[0], self.y_: batch[1], self.keep_prob: 1.0
self.x: batch[0],
self.y_: batch[1],
self.keep_prob: 1.0
})
self.iterations += 1
@@ -215,11 +218,11 @@ if __name__ == '__main__':
mnist_spec = {
'run': 'my_class',
'stop': {
'mean_accuracy': 0.99,
'time_total_s': 600,
'mean_accuracy': 0.99,
'time_total_s': 600,
},
'config': {
'learning_rate': lambda spec: 10 ** np.random.uniform(-5, -3),
'learning_rate': lambda spec: 10**np.random.uniform(-5, -3),
'activation': grid_search(['relu', 'elu', 'tanh']),
},
"repeat": 10,
@@ -231,8 +234,6 @@ if __name__ == '__main__':
ray.init()
hyperband = HyperBandScheduler(
time_attr="timesteps_total", reward_attr="mean_accuracy",
max_t=100)
time_attr="timesteps_total", reward_attr="mean_accuracy", max_t=100)
run_experiments(
{'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
run_experiments({'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
+16 -4
View File
@@ -35,14 +35,26 @@ class Experiment(object):
checkpoint at least this many times. Only applies if
checkpointing is enabled. Defaults to 3.
"""
def __init__(self, name, run, stop=None, config=None,
trial_resources=None, repeat=1, local_dir=None,
upload_dir="", checkpoint_freq=0, max_failures=3):
def __init__(self,
name,
run,
stop=None,
config=None,
trial_resources=None,
repeat=1,
local_dir=None,
upload_dir="",
checkpoint_freq=0,
max_failures=3):
spec = {
"run": run,
"stop": stop or {},
"config": config or {},
"trial_resources": trial_resources or {"cpu": 1, "gpu": 0},
"trial_resources": trial_resources or {
"cpu": 1,
"gpu": 0
},
"repeat": repeat,
"local_dir": local_dir or DEFAULT_RESULTS_DIR,
"upload_dir": upload_dir,
+4 -5
View File
@@ -91,8 +91,8 @@ class FunctionRunner(Trainable):
for k in self._default_config:
if k in scrubbed_config:
del scrubbed_config[k]
self._runner = _RunnerThread(
entrypoint, scrubbed_config, self._status_reporter)
self._runner = _RunnerThread(entrypoint, scrubbed_config,
self._status_reporter)
self._start_time = time.time()
self._last_reported_timestep = 0
self._runner.start()
@@ -104,9 +104,8 @@ class FunctionRunner(Trainable):
def _train(self):
time.sleep(
self.config.get(
"script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
self.config.get("script_min_iter_time_s",
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
time.sleep(1)
+11 -9
View File
@@ -102,9 +102,8 @@ class HyperOptScheduler(FIFOScheduler):
self._hpopt_trials.refresh()
# Get new suggestion from
new_trials = self.algo(
new_ids, self.domain, self._hpopt_trials,
self.rstate.randint(2 ** 31 - 1))
new_trials = self.algo(new_ids, self.domain, self._hpopt_trials,
self.rstate.randint(2**31 - 1))
self._hpopt_trials.insert_trial_docs(new_trials)
self._hpopt_trials.refresh()
new_trial = new_trials[0]
@@ -112,8 +111,11 @@ class HyperOptScheduler(FIFOScheduler):
suggested_config = hpo.base.spec_from_misc(new_trial["misc"])
new_cfg.update(suggested_config)
kv_str = "_".join(["{}={}".format(k, str(v)[:5])
for k, v in sorted(suggested_config.items())])
kv_str = "_".join([
"{}={}".format(k,
str(v)[:5])
for k, v in sorted(suggested_config.items())
])
experiment_tag = "{}_{}".format(new_trial_id, kv_str)
# Keep this consistent with tune.variant_generator
@@ -166,8 +168,7 @@ class HyperOptScheduler(FIFOScheduler):
del self._tune_to_hp[trial]
def _to_hyperopt_result(self, result):
return {"loss": -getattr(result, self._reward_attr),
"status": "ok"}
return {"loss": -getattr(result, self._reward_attr), "status": "ok"}
def _get_hyperopt_trial(self, tid):
return [t for t in self._hpopt_trials.trials if t["tid"] == tid][0]
@@ -183,8 +184,9 @@ class HyperOptScheduler(FIFOScheduler):
experiments and trials left to run. If self._max_concurrent is None,
scheduler will add new trial if there is none that are pending.
"""
pending = [t for t in trial_runner.get_trials()
if t.status == Trial.PENDING]
pending = [
t for t in trial_runner.get_trials() if t.status == Trial.PENDING
]
if self._num_trials_left <= 0:
return
if self._max_concurrent is None:
+24 -22
View File
@@ -66,9 +66,10 @@ class HyperBandScheduler(FIFOScheduler):
mentioned in the original HyperBand paper.
"""
def __init__(
self, time_attr='training_iteration',
reward_attr='episode_reward_mean', max_t=81):
def __init__(self,
time_attr='training_iteration',
reward_attr='episode_reward_mean',
max_t=81):
assert max_t > 0, "Max (time_attr) not valid!"
FIFOScheduler.__init__(self)
self._eta = 3
@@ -78,13 +79,12 @@ class HyperBandScheduler(FIFOScheduler):
self._get_n0 = lambda s: int(
np.ceil(self._s_max_1/(s+1) * self._eta**s))
# bracket initial iterations
self._get_r0 = lambda s: int((max_t*self._eta**(-s)))
self._get_r0 = lambda s: int((max_t * self._eta**(-s)))
self._hyperbands = [[]] # list of hyperband iterations
self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
# Tracks state for new trial add
self._state = {"bracket": None,
"band_idx": 0}
self._state = {"bracket": None, "band_idx": 0}
self._num_stopped = 0
self._reward_attr = reward_attr
self._time_attr = time_attr
@@ -116,9 +116,9 @@ class HyperBandScheduler(FIFOScheduler):
cur_bracket = None
else:
retry = False
cur_bracket = Bracket(
self._time_attr, self._get_n0(s), self._get_r0(s),
self._max_t_attr, self._eta, s)
cur_bracket = Bracket(self._time_attr, self._get_n0(s),
self._get_r0(s), self._max_t_attr,
self._eta, s)
cur_band.append(cur_bracket)
self._state["bracket"] = cur_bracket
@@ -217,11 +217,11 @@ class HyperBandScheduler(FIFOScheduler):
"""
for hyperband in self._hyperbands:
for bracket in sorted(hyperband,
key=lambda b: b.completion_percentage()):
for bracket in sorted(
hyperband, key=lambda b: b.completion_percentage()):
for trial in bracket.current_trials():
if (trial.status == Trial.PENDING and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PENDING
and trial_runner.has_resources(trial.resources)):
return trial
return None
@@ -258,6 +258,7 @@ class Bracket():
Also keeps track of progress to ensure good scheduling.
"""
def __init__(self, time_attr, max_trials, init_t_attr, max_t_attr, eta, s):
self._live_trials = {} # maps trial -> current result
self._all_trials = []
@@ -287,8 +288,9 @@ class Bracket():
"""Checks if all iterations have completed.
TODO(rliaw): also check that `t.iterations == self._r`"""
return all(self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
return all(
self._get_result_time(result) >= self._cumul_r
for result in self._live_trials.values())
def finished(self):
return self._halves == 0 and self.cur_iter_done()
@@ -379,7 +381,7 @@ class Bracket():
def _calculate_total_work(self, n, r, s):
work = 0
cumulative_r = r
for i in range(s+1):
for i in range(s + 1):
work += int(n) * int(r)
n /= self._eta
n = int(np.ceil(n))
@@ -389,11 +391,11 @@ class Bracket():
def __repr__(self):
status = ", ".join([
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._cumul_r),
"completed={:.1%}".format(self.completion_percentage())
])
"Max Size (n)={}".format(self._n), "Milestone (r)={}".format(
self._cumul_r), "completed={:.1%}".format(
self.completion_percentage())
])
counts = collections.Counter([t.status for t in self._all_trials])
trial_statuses = ", ".join(sorted(
["{}: {}".format(k, v) for k, v in counts.items()]))
trial_statuses = ", ".join(
sorted(["{}: {}".format(k, v) for k, v in counts.items()]))
return "Bracket({}): {{{}}} ".format(status, trial_statuses)
+9 -13
View File
@@ -13,7 +13,6 @@ from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
# Map from (logdir, remote_dir) -> syncer
_syncers = {}
@@ -69,9 +68,8 @@ class _LogSyncer(object):
def sync_now(self, force=False):
self.last_sync_time = time.time()
if not self.worker_ip:
print(
"Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
print("Worker ip unknown, skipping log sync for {}".format(
self.local_dir))
return
if self.worker_ip == self.local_ip:
@@ -80,23 +78,21 @@ class _LogSyncer(object):
ssh_key = get_ssh_key()
ssh_user = get_ssh_user()
if ssh_key is None or ssh_user is None:
print(
"Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
print("Error: log sync requires cluster to be setup with "
"`ray create_or_update`.")
return
if not distutils.spawn.find_executable("rsync"):
print("Error: log sync requires rsync to be installed.")
return
worker_to_local_sync_cmd = (
("""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
worker_to_local_sync_cmd = ((
"""rsync -avz -e "ssh -i '{}' -o ConnectTimeout=120s """
"""-o StrictHostKeyChecking=no" '{}@{}:{}/' '{}/'""").format(
ssh_key, ssh_user, self.worker_ip,
pipes.quote(self.local_dir), pipes.quote(self.local_dir)))
if self.remote_dir:
local_to_remote_sync_cmd = (
"aws s3 sync '{}' '{}'".format(
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
local_to_remote_sync_cmd = ("aws s3 sync '{}' '{}'".format(
pipes.quote(self.local_dir), pipes.quote(self.remote_dir)))
else:
local_to_remote_sync_cmd = None
+8 -8
View File
@@ -110,9 +110,9 @@ def to_tf_values(result, path):
for attr, value in result.items():
if value is not None:
if type(value) in [int, float]:
values.append(tf.Summary.Value(
tag="/".join(path + [attr]),
simple_value=value))
values.append(
tf.Summary.Value(
tag="/".join(path + [attr]), simple_value=value))
elif type(value) is dict:
values.extend(to_tf_values(value, path + [attr]))
return values
@@ -125,8 +125,8 @@ class _TFLogger(Logger):
def on_result(self, result):
tmp = result._asdict()
for k in [
"config", "pid", "timestamp", "time_total_s",
"timesteps_total"]:
"config", "pid", "timestamp", "time_total_s", "timesteps_total"
]:
del tmp[k] # not useful to tf log these
values = to_tf_values(tmp, ["ray", "tune"])
train_stats = tf.Summary(value=values)
@@ -165,9 +165,9 @@ class _CustomEncoder(json.JSONEncoder):
return repr(o) if not np.isnan(o) else nan_str
_iterencode = json.encoder._make_iterencode(
None, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)
None, self.default, _encoder, self.indent, floatstr,
self.key_separator, self.item_separator, self.sort_keys,
self.skipkeys, _one_shot)
return _iterencode(o, 0)
def default(self, value):
+11 -7
View File
@@ -32,10 +32,13 @@ class MedianStoppingRule(FIFOScheduler):
time a trial reports. Defaults to True.
"""
def __init__(
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
grace_period=60.0, min_samples_required=3,
hard_stop=True, verbose=True):
def __init__(self,
time_attr="time_total_s",
reward_attr="episode_reward_mean",
grace_period=60.0,
min_samples_required=3,
hard_stop=True,
verbose=True):
FIFOScheduler.__init__(self)
self._stopped_trials = set()
self._completed_trials = set()
@@ -103,9 +106,10 @@ class MedianStoppingRule(FIFOScheduler):
results = self._results[trial]
# TODO(ekl) we could do interpolation to be more precise, but for now
# assume len(results) is large and the time diffs are roughly equal
return np.mean(
[getattr(r, self._reward_attr)
for r in results if getattr(r, self._time_attr) <= t_max])
return np.mean([
getattr(r, self._reward_attr) for r in results
if getattr(r, self._time_attr) <= t_max
])
def _best_result(self, trial):
results = self._results[trial]
+26 -26
View File
@@ -11,7 +11,6 @@ from ray.tune.trial import Trial
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.tune.variant_generator import _format_vars
# Parameters are transferred from the top PBT_QUANTILE fraction of trials to
# the bottom PBT_QUANTILE fraction.
PBT_QUANTILE = 0.25
@@ -27,9 +26,8 @@ class PBTTrialState(object):
self.last_perturbation_time = 0
def __repr__(self):
return str((
self.last_score, self.last_checkpoint,
self.last_perturbation_time))
return str((self.last_score, self.last_checkpoint,
self.last_perturbation_time))
def explore(config, mutations, resample_probability, custom_explore_fn):
@@ -51,12 +49,13 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
config[key] not in distribution:
new_config[key] = random.choice(distribution)
elif random.random() > 0.5:
new_config[key] = distribution[
max(0, distribution.index(config[key]) - 1)]
new_config[key] = distribution[max(
0,
distribution.index(config[key]) - 1)]
else:
new_config[key] = distribution[
min(len(distribution) - 1,
distribution.index(config[key]) + 1)]
new_config[key] = distribution[min(
len(distribution) - 1,
distribution.index(config[key]) + 1)]
else:
if random.random() < resample_probability:
new_config[key] = distribution()
@@ -70,8 +69,8 @@ def explore(config, mutations, resample_probability, custom_explore_fn):
new_config = custom_explore_fn(new_config)
assert new_config is not None, \
"Custom explore fn failed to return new config"
print(
"[explore] perturbed config from {} -> {}".format(config, new_config))
print("[explore] perturbed config from {} -> {}".format(
config, new_config))
return new_config
@@ -148,10 +147,13 @@ class PopulationBasedTraining(FIFOScheduler):
>>> run_experiments({...}, scheduler=pbt)
"""
def __init__(
self, time_attr="time_total_s", reward_attr="episode_reward_mean",
perturbation_interval=60.0, hyperparam_mutations={},
resample_probability=0.25, custom_explore_fn=None):
def __init__(self,
time_attr="time_total_s",
reward_attr="episode_reward_mean",
perturbation_interval=60.0,
hyperparam_mutations={},
resample_probability=0.25,
custom_explore_fn=None):
if not hyperparam_mutations and not custom_explore_fn:
raise TuneError(
"You must specify at least one of `hyperparam_mutations` or "
@@ -209,14 +211,13 @@ class PopulationBasedTraining(FIFOScheduler):
if not new_state.last_checkpoint:
print("[pbt] warn: no checkpoint for trial, skip exploit", trial)
return
new_config = explore(
trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability, self._custom_explore_fn)
print(
"[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
new_config = explore(trial_to_clone.config, self._hyperparam_mutations,
self._resample_probability,
self._custom_explore_fn)
print("[exploit] transferring weights from trial "
"{} (score {}) -> {} (score {})".format(
trial_to_clone, new_state.last_score, trial,
trial_state.last_score))
# TODO(ekl) restarting the trial is expensive. We should implement a
# lighter way reset() method that can alter the trial config.
trial.stop(stop_logger=False)
@@ -242,9 +243,8 @@ class PopulationBasedTraining(FIFOScheduler):
if len(trials) <= 1:
return [], []
else:
return (
trials[:int(math.ceil(len(trials)*PBT_QUANTILE))],
trials[int(math.floor(-len(trials)*PBT_QUANTILE)):])
return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))],
trials[int(math.floor(-len(trials) * PBT_QUANTILE)):])
def choose_trial_to_run(self, trial_runner):
"""Ensures all trials get fair share of time (as defined by time_attr).
+5 -5
View File
@@ -14,7 +14,8 @@ ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
KNOWN_CATEGORIES = [
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR]
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
]
def register_trainable(name, trainable):
@@ -32,8 +33,8 @@ def register_trainable(name, trainable):
if isinstance(trainable, FunctionType):
trainable = wrap_function(trainable)
if not issubclass(trainable, Trainable):
raise TypeError(
"Second argument must be convertable to Trainable", trainable)
raise TypeError("Second argument must be convertable to Trainable",
trainable)
_default_registry.register(TRAINABLE_CLASS, name, trainable)
@@ -46,8 +47,7 @@ def register_env(name, env_creator):
"""
if not isinstance(env_creator, FunctionType):
raise TypeError(
"Second argument must be a function.", env_creator)
raise TypeError("Second argument must be a function.", env_creator)
_default_registry.register(ENV_CREATOR, name, env_creator)
+54 -51
View File
@@ -4,8 +4,6 @@ from __future__ import print_function
from collections import namedtuple
import os
"""
When using ray.tune with custom training scripts, you must periodically report
training status back to Ray by calling reporter(result).
@@ -18,73 +16,78 @@ In RLlib, the supplied algorithms fill in TrainingResult for you.
# Where ray.tune writes result files by default
DEFAULT_RESULTS_DIR = os.path.expanduser("~/ray_results")
TrainingResult = namedtuple(
"TrainingResult",
[
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
TrainingResult = namedtuple("TrainingResult", [
# (Required) Accumulated timesteps for this entire experiment.
"timesteps_total",
# (Optional) If training is terminated.
"done",
# (Optional) If training is terminated.
"done",
# (Optional) Custom metadata to report for this iteration.
"info",
# (Optional) Custom metadata to report for this iteration.
"info",
# (Optional) The mean episode reward if applicable.
"episode_reward_mean",
# (Optional) The mean episode reward if applicable.
"episode_reward_mean",
# (Optional) The mean episode length if applicable.
"episode_len_mean",
# (Optional) The mean episode length if applicable.
"episode_len_mean",
# (Optional) The number of episodes total.
"episodes_total",
# (Optional) The number of episodes total.
"episodes_total",
# (Optional) The current training accuracy if applicable.
"mean_accuracy",
# (Optional) The current training accuracy if applicable.
"mean_accuracy",
# (Optional) The current validation accuracy if applicable.
"mean_validation_accuracy",
# (Optional) The current validation accuracy if applicable.
"mean_validation_accuracy",
# (Optional) The current training loss if applicable.
"mean_loss",
# (Optional) The current training loss if applicable.
"mean_loss",
# (Auto-filled) The negated current training loss.
"neg_mean_loss",
# (Auto-filled) The negated current training loss.
"neg_mean_loss",
# (Auto-filled) Unique string identifier for this experiment.
# This id is preserved across checkpoint / restore calls.
"experiment_id",
# (Auto-filled) Unique string identifier for this experiment. This id is
# preserved across checkpoint / restore calls.
"experiment_id",
# (Auto-filled) The index of this training iteration,
# e.g. call to train().
"training_iteration",
# (Auto-filled) The index of this training iteration, e.g. call to train().
"training_iteration",
# (Auto-filled) Number of timesteps in the simulator
# in this iteration.
"timesteps_this_iter",
# (Auto-filled) Number of timesteps in the simulator in this iteration.
"timesteps_this_iter",
# (Auto-filled) Time in seconds this iteration took to run. This may
# be overridden in order to override the system-computed
# time difference.
"time_this_iter_s",
# (Auto-filled) Time in seconds this iteration took to run. This may be
# overridden in order to override the system-computed time difference.
"time_this_iter_s",
# (Auto-filled) Accumulated time in seconds for this entire experiment.
"time_total_s",
# (Auto-filled) Accumulated time in seconds for this entire experiment.
"time_total_s",
# (Auto-filled) The pid of the training process.
"pid",
# (Auto-filled) The pid of the training process.
"pid",
# (Auto-filled) A formatted date of when the result was processed.
"date",
# (Auto-filled) A formatted date of when the result was processed.
"date",
# (Auto-filled) A UNIX timestamp of when the result was processed.
"timestamp",
# (Auto-filled) A UNIX timestamp of when the result was processed.
"timestamp",
# (Auto-filled) The hostname of the machine hosting the
# training process.
"hostname",
# (Auto-filled) The hostname of the machine hosting the training process.
"hostname",
# (Auto-filled) The node ip of the machine hosting the
# training process.
"node_ip",
# (Auto-filled) The node ip of the machine hosting the training process.
"node_ip",
# (Auto-filled) The current hyperparameter configuration.
"config",
])
# (Auto-filled) The current hyperparameter configuration.
"config",
])
TrainingResult.__new__.__defaults__ = (None,) * len(TrainingResult._fields)
TrainingResult.__new__.__defaults__ = (None, ) * len(TrainingResult._fields)
+3 -1
View File
@@ -20,7 +20,9 @@ if __name__ == "__main__":
run_experiments({
"test": {
"run": "my_class",
"stop": {"training_iteration": 1}
"stop": {
"training_iteration": 1
}
}
})
assert 'ray.rllib' not in sys.modules, "RLlib should not be imported"
+185 -122
View File
@@ -60,163 +60,209 @@ class TrainableFunctionApiTest(unittest.TestCase):
def testRewriteEnv(self):
def train(config, reporter):
reporter(timesteps_total=1)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"env": "CartPole-v0",
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"env": "CartPole-v0",
}
})
self.assertEqual(trial.config["env"], "CartPole-v0")
def testConfigPurity(self):
def train(config, reporter):
assert config == {"a": "b"}, config
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"config": {"a": "b"},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"a": "b"
},
}
})
def testLogdir(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {"a": "b"},
}})
run_experiments({
"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a": "b"
},
}
})
def testLongFilename(self):
def train(config, reporter):
assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd()
reporter(timesteps_total=1)
register_trainable("f1", train)
run_experiments({"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}})
run_experiments({
"foo": {
"run": "f1",
"local_dir": "/tmp/logdir",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
},
}
})
def testBadParams(self):
def f():
run_experiments({"foo": {}})
self.assertRaises(TuneError, f)
def testBadParams2(self):
def f():
run_experiments({"foo": {
"run": "asdf",
"bah": "this param is not allowed",
}})
run_experiments({
"foo": {
"run": "asdf",
"bah": "this param is not allowed",
}
})
self.assertRaises(TuneError, f)
def testBadParams3(self):
def f():
run_experiments({"foo": {
"run": grid_search("invalid grid search"),
}})
run_experiments({
"foo": {
"run": grid_search("invalid grid search"),
}
})
self.assertRaises(TuneError, f)
def testBadParams4(self):
def f():
run_experiments({"foo": {
"run": "asdf",
}})
run_experiments({
"foo": {
"run": "asdf",
}
})
self.assertRaises(TuneError, f)
def testBadParams5(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"stop": {"asdf": 1}
}})
run_experiments({"foo": {"run": "PPO", "stop": {"asdf": 1}}})
self.assertRaises(TuneError, f)
def testBadParams6(self):
def f():
run_experiments({"foo": {
"run": "PPO",
"trial_resources": {"asdf": 1}
}})
run_experiments({
"foo": {
"run": "PPO",
"trial_resources": {
"asdf": 1
}
}
})
self.assertRaises(TuneError, f)
def testBadReturn(self):
def train(config, reporter):
reporter()
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertRaises(TuneError, f)
def testEarlyReturn(self):
def train(config, reporter):
reporter(timesteps_total=100, done=True)
time.sleep(99999)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testAbruptReturn(self):
def train(config, reporter):
reporter(timesteps_total=100)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testErrorReturn(self):
def train(config, reporter):
raise Exception("uh oh")
register_trainable("f1", train)
def f():
run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertRaises(TuneError, f)
def testSuccess(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
[trial] = run_experiments({"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}})
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 99)
class RunExperimentTest(unittest.TestCase):
def setUp(self):
ray.init()
@@ -228,6 +274,7 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
trials = run_experiments({
"foo": {
@@ -251,13 +298,14 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
[trial] = run_experiments(exp1)
self.assertEqual(trial.status, Trial.TERMINATED)
@@ -267,20 +315,21 @@ class RunExperimentTest(unittest.TestCase):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)
register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "foo",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
exp2 = Experiment(**{
"name": "bar",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
"name": "bar",
"run": "f1",
"config": {
"script_min_iter_time_s": 0
}
})
trials = run_experiments([exp1, exp2])
for trial in trials:
@@ -306,9 +355,8 @@ class VariantGeneratorTest(unittest.TestCase):
self.assertEqual(trials[0].trainable_name, "PPO")
self.assertEqual(trials[0].experiment_tag, "0")
self.assertEqual(trials[0].max_failures, 5)
self.assertEqual(
trials[0].local_dir,
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
self.assertEqual(trials[0].local_dir,
os.path.join(DEFAULT_RESULTS_DIR, "tune-pong"))
self.assertEqual(trials[1].experiment_tag, "1")
def testEval(self):
@@ -392,11 +440,13 @@ class VariantGeneratorTest(unittest.TestCase):
trials = generate_trials({
"run": "PPO",
"config": {
"x": grid_search([
"x":
grid_search([
lambda spec: spec.config.y * 100,
lambda spec: spec.config.y * 200
]),
"y": lambda spec: 1,
"y":
lambda spec: 1,
},
})
trials = list(trials)
@@ -406,12 +456,13 @@ class VariantGeneratorTest(unittest.TestCase):
def testRecursiveDep(self):
try:
list(generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
}))
list(
generate_trials({
"run": "PPO",
"config": {
"foo": lambda spec: spec.config.foo,
},
}))
except RecursiveDependencyError as e:
assert "`foo` recursively depends on" in str(e), e
else:
@@ -442,12 +493,15 @@ class TrialRunnerTest(unittest.TestCase):
register_trainable("f1", train)
experiments = {"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40},
}}
experiments = {
"foo": {
"run": "f1",
"config": {
"a" * 50: lambda spec: 5.0 / 7,
"b" * 50: lambda spec: "long" * 40
},
}
}
for name, spec in experiments.items():
for trial in generate_trials(spec, name):
@@ -468,12 +522,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=0, extra_cpu=3, extra_gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -489,12 +543,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -518,12 +572,12 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"stopping_criterion": {
"training_iteration": 5
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -547,13 +601,13 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
_default_registry.register(TRAINABLE_CLASS, "asdf", None)
trials = [
Trial("asdf", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("asdf", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
@@ -644,7 +698,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 1},
"stopping_criterion": {
"training_iteration": 1
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@@ -675,7 +731,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"stopping_criterion": {
"training_iteration": 2
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@@ -692,7 +750,9 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=1, num_gpus=1)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 2},
"stopping_criterion": {
"training_iteration": 2
},
"resources": Resources(cpu=1, gpu=1),
}
runner.add_trial(Trial("__fake", **kwargs))
@@ -721,14 +781,17 @@ class TrialRunnerTest(unittest.TestCase):
ray.init(num_cpus=4, num_gpus=2)
runner = TrialRunner()
kwargs = {
"stopping_criterion": {"training_iteration": 5},
"stopping_criterion": {
"training_iteration": 5
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
Trial("__fake", **kwargs)
]
for t in trials:
runner.add_trial(t)
runner.step()
+69 -76
View File
@@ -19,9 +19,8 @@ _register_all()
def result(t, rew):
return TrainingResult(time_total_s=t,
episode_reward_mean=rew,
training_iteration=int(t))
return TrainingResult(
time_total_s=t, episode_reward_mean=rew, training_iteration=int(t))
class EarlyStoppingSuite(unittest.TestCase):
@@ -76,8 +75,7 @@ class EarlyStoppingSuite(unittest.TestCase):
rule.on_trial_result(None, t3, result(2, 10)),
TrialScheduler.CONTINUE)
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
def testMedianStoppingMinSamples(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=2)
@@ -89,8 +87,7 @@ class EarlyStoppingSuite(unittest.TestCase):
TrialScheduler.CONTINUE)
rule.on_trial_complete(None, t2, result(10, 1000))
self.assertEqual(
rule.on_trial_result(None, t3, result(3, 10)),
TrialScheduler.STOP)
rule.on_trial_result(None, t3, result(3, 10)), TrialScheduler.STOP)
def testMedianStoppingUsesMedian(self):
rule = MedianStoppingRule(grace_period=0, min_samples_required=1)
@@ -124,8 +121,10 @@ class EarlyStoppingSuite(unittest.TestCase):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
rule = MedianStoppingRule(
grace_period=0, min_samples_required=1,
time_attr='training_iteration', reward_attr='neg_mean_loss')
grace_period=0,
min_samples_required=1,
time_attr='training_iteration',
reward_attr='neg_mean_loss')
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
for i in range(10):
@@ -185,7 +184,6 @@ class _MockTrialRunner():
class HyperbandSuite(unittest.TestCase):
def schedulerSetup(self, num_trials):
"""Setup a scheduler and Runner with max Iter = 9
@@ -206,7 +204,10 @@ class HyperbandSuite(unittest.TestCase):
"""Default statistics for HyperBand"""
sched = HyperBandScheduler()
res = {
str(s): {"n": sched._get_n0(s), "r": sched._get_r0(s)}
str(s): {
"n": sched._get_n0(s),
"r": sched._get_r0(s)
}
for s in range(sched._s_max_1)
}
res["max_trials"] = sum(v["n"] for v in res.values())
@@ -298,8 +299,8 @@ class HyperbandSuite(unittest.TestCase):
# Provides results from 0 to 8 in order, keeping last one running
for i, trl in enumerate(trials):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
if i < current_length - 1:
self.assertEqual(action, TrialScheduler.PAUSE)
mock_runner.process_action(trl, action)
@@ -321,8 +322,8 @@ class HyperbandSuite(unittest.TestCase):
# # Provides result in reverse order, killing the last one
cur_units = stats[str(1)]["r"]
for i, trl in reversed(list(enumerate(big_bracket.current_trials()))):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
@@ -338,8 +339,8 @@ class HyperbandSuite(unittest.TestCase):
# # Provides result in reverse order, killing the last one
cur_units = stats[str(0)]["r"]
for i, trl in enumerate(big_bracket.current_trials()):
action = sched.on_trial_result(
mock_runner, trl, result(cur_units, i))
action = sched.on_trial_result(mock_runner, trl,
result(cur_units, i))
mock_runner.process_action(trl, action)
self.assertEqual(action, TrialScheduler.STOP)
@@ -354,14 +355,12 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_error(mock_runner, t3)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
@@ -371,13 +370,14 @@ class HyperbandSuite(unittest.TestCase):
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
sched.on_trial_result(mock_runner, t, result(
stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_error(mock_runner, trials[-1])
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
self.assertEqual(
len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testTrialEndedEarly(self):
"""Check successive halving happened even when one trial failed"""
@@ -390,14 +390,12 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t3, result(1, 12))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(
mock_runner, t1, result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(
mock_runner, t2, result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial failed"""
@@ -407,13 +405,14 @@ class HyperbandSuite(unittest.TestCase):
trials = sched._state["bracket"].current_trials()
for t in trials[:-1]:
mock_runner._launch_trial(t)
sched.on_trial_result(
mock_runner, t, result(stats[str(1)]["r"], 10))
sched.on_trial_result(mock_runner, t, result(
stats[str(1)]["r"], 10))
mock_runner._launch_trial(trials[-1])
sched.on_trial_complete(mock_runner, trials[-1], result(100, 12))
self.assertEqual(len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
self.assertEqual(
len(sched._state["bracket"].current_trials()),
self.downscale(stats[str(1)]["n"], sched))
def testAddAfterHalving(self):
stats = self.default_statistics()
@@ -426,8 +425,8 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
for i, t in enumerate(bracket_trials):
action = sched.on_trial_result(
mock_runner, t, result(init_units, i))
action = sched.on_trial_result(mock_runner, t, result(
init_units, i))
self.assertEqual(action, TrialScheduler.CONTINUE)
t = Trial("__fake")
sched.on_trial_add(None, t)
@@ -435,13 +434,13 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
# Make sure that newly added trial gets fair computation (not just 1)
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t,
result(init_units, 12)))
new_units = init_units + int(init_units * sched._eta)
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t,
result(new_units, 12)))
def testAlternateMetrics(self):
"""Checking that alternate metrics will pass."""
@@ -539,7 +538,6 @@ class _MockTrial(Trial):
class PopulationBasedTestingSuite(unittest.TestCase):
def basicSetup(self, resample_prob=0.0, explore=None):
pbt = PopulationBasedTraining(
time_attr="training_iteration",
@@ -554,9 +552,12 @@ class PopulationBasedTestingSuite(unittest.TestCase):
runner = _MockTrialRunner(pbt)
for i in range(5):
trial = _MockTrial(
i,
{"id_factor": i, "float_factor": 2.0, "const_factor": 3,
"int_factor": 10})
i, {
"id_factor": i,
"float_factor": 2.0,
"const_factor": 3,
"int_factor": 10
})
runner.add_trial(trial)
trial.status = Trial.RUNNING
self.assertEqual(
@@ -570,27 +571,23 @@ class PopulationBasedTestingSuite(unittest.TestCase):
trials = runner.get_trials()
# no checkpoint: haven't hit next perturbation interval yet
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 0)
# checkpoint: both past interval and upper quantile
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, 200)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [200, 50, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 1)
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(30, 201)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [200, 201, 100, 150, 200])
self.assertEqual(pbt._num_checkpoints, 2)
# not upper quantile any more
@@ -608,8 +605,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(15, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 50, 100, 150, 200])
self.assertTrue("@perturbed" not in trials[0].experiment_tag)
self.assertEqual(pbt._num_perturbations, 0)
@@ -617,8 +613,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[0], result(20, -100)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [-100, 50, 100, 150, 200])
self.assertTrue("@perturbed" in trials[0].experiment_tag)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertEqual(pbt._num_perturbations, 1)
@@ -627,8 +622,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[2], result(20, 40)),
TrialScheduler.CONTINUE)
self.assertEqual(
pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt.last_scores(trials), [-100, 50, 40, 150, 200])
self.assertEqual(pbt._num_perturbations, 2)
self.assertIn(trials[0].restored_checkpoint, ["trial_3", "trial_4"])
self.assertTrue("@perturbed" in trials[2].experiment_tag)
@@ -662,7 +656,6 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(trials[0].config["const_factor"], 3)
def testPerturbationValues(self):
def assertProduces(fn, values):
random.seed(0)
seen = set()
@@ -712,8 +705,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
self.assertEqual(
pbt.on_trial_result(runner, trials[1], result(20, 1000)),
TrialScheduler.PAUSE)
self.assertEqual(
pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.last_scores(trials), [0, 1000, 100, 150, 200])
self.assertEqual(pbt.choose_trial_to_run(runner), trials[0])
def testSchedulesMostBehindTrialToRun(self):
@@ -748,6 +740,7 @@ class PopulationBasedTestingSuite(unittest.TestCase):
new_config["id_factor"] = 42
new_config["float_factor"] = 43
return new_config
pbt, runner = self.basicSetup(resample_prob=0.0, explore=explore)
trials = runner.get_trials()
self.assertEqual(
@@ -774,8 +767,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
return t1, t2
def testAsyncHBOnComplete(self):
scheduler = AsyncHyperBandScheduler(
max_t=10, brackets=1)
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=1)
t1, t2 = self.basicSetup(scheduler)
t3 = Trial("PPO")
scheduler.on_trial_add(None, t3)
@@ -803,8 +795,7 @@ class AsyncHyperBandSuite(unittest.TestCase):
TrialScheduler.STOP)
def testAsyncHBAllCompletes(self):
scheduler = AsyncHyperBandScheduler(
max_t=10, brackets=10)
scheduler = AsyncHyperBandScheduler(max_t=10, brackets=10)
trials = [Trial("PPO") for i in range(10)]
for t in trials:
scheduler.on_trial_add(None, t)
@@ -834,8 +825,10 @@ class AsyncHyperBandSuite(unittest.TestCase):
return TrainingResult(training_iteration=t, neg_mean_loss=rew)
scheduler = AsyncHyperBandScheduler(
grace_period=1, time_attr='training_iteration',
reward_attr='neg_mean_loss', brackets=1)
grace_period=1,
time_attr='training_iteration',
reward_attr='neg_mean_loss',
brackets=1)
t1 = Trial("PPO") # mean is 450, max 900, t_max=10
t2 = Trial("PPO") # mean is 450, max 450, t_max=5
scheduler.on_trial_add(None, t1)
+8 -7
View File
@@ -30,16 +30,15 @@ class TuneServerSuite(unittest.TestCase):
def basicSetup(self):
ray.init(num_cpus=4, num_gpus=1)
port = get_valid_port()
self.runner = TrialRunner(
launch_web_server=True, server_port=port)
self.runner = TrialRunner(launch_web_server=True, server_port=port)
runner = self.runner
kwargs = {
"stopping_criterion": {"training_iteration": 3},
"stopping_criterion": {
"training_iteration": 3
},
"resources": Resources(cpu=1, gpu=1),
}
trials = [
Trial("__fake", **kwargs),
Trial("__fake", **kwargs)]
trials = [Trial("__fake", **kwargs), Trial("__fake", **kwargs)]
for t in trials:
runner.add_trial(t)
client = TuneClient("localhost:{}".format(port))
@@ -61,7 +60,9 @@ class TuneServerSuite(unittest.TestCase):
runner.step()
spec = {
"run": "__fake",
"stop": {"training_iteration": 3},
"stop": {
"training_iteration": 3
},
"trial_resources": dict(cpu=1, gpu=1),
}
client.add_trial("test", spec)
+10 -8
View File
@@ -114,8 +114,8 @@ class Trainable(object):
time_this_iter = time.time() - start
if result.timesteps_this_iter is None:
raise TuneError(
"Must specify timesteps_this_iter in result", result)
raise TuneError("Must specify timesteps_this_iter in result",
result)
self._time_total += time_this_iter
self._timesteps_total += result.timesteps_this_iter
@@ -159,10 +159,10 @@ class Trainable(object):
"""
checkpoint_path = self._save(checkpoint_dir or self.logdir)
pickle.dump(
[self._experiment_id, self._iteration, self._timesteps_total,
self._time_total],
open(checkpoint_path + ".tune_metadata", "wb"))
pickle.dump([
self._experiment_id, self._iteration, self._timesteps_total,
self._time_total
], open(checkpoint_path + ".tune_metadata", "wb"))
return checkpoint_path
def save_to_object(self):
@@ -186,8 +186,10 @@ class Trainable(object):
out = io.BytesIO()
with gzip.GzipFile(fileobj=out, mode="wb") as f:
compressed = pickle.dumps({
"checkpoint_name": os.path.basename(checkpoint_prefix),
"data": data,
"checkpoint_name":
os.path.basename(checkpoint_prefix),
"data":
data,
})
if len(compressed) > 10e6: # getting pretty large
print("Checkpoint size is {} bytes".format(len(compressed)))
+37 -32
View File
@@ -42,12 +42,12 @@ class Resources(
__slots__ = ()
def __new__(cls, cpu, gpu, extra_cpu=0, extra_gpu=0):
return super(Resources, cls).__new__(
cls, cpu, gpu, extra_cpu, extra_gpu)
return super(Resources, cls).__new__(cls, cpu, gpu, extra_cpu,
extra_gpu)
def summary_string(self):
return "{} CPUs, {} GPUs".format(
self.cpu + self.extra_cpu, self.gpu + self.extra_gpu)
return "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu,
self.gpu + self.extra_gpu)
def cpu_total(self):
return self.cpu + self.extra_cpu
@@ -77,11 +77,17 @@ class Trial(object):
TERMINATED = "TERMINATED"
ERROR = "ERROR"
def __init__(
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
experiment_tag="", resources=Resources(cpu=1, gpu=0),
stopping_criterion=None, checkpoint_freq=0,
restore_path=None, upload_dir=None, max_failures=0):
def __init__(self,
trainable_name,
config=None,
local_dir=DEFAULT_RESULTS_DIR,
experiment_tag="",
resources=Resources(cpu=1, gpu=0),
stopping_criterion=None,
checkpoint_freq=0,
restore_path=None,
upload_dir=None,
max_failures=0):
"""Initialize a new trial.
The args here take the same meaning as the command line flags defined
@@ -166,19 +172,20 @@ class Trial(object):
try:
if error_msg and self.logdir:
self.num_failures += 1
error_file = os.path.join(
self.logdir, "error_{}.txt".format(date_str()))
error_file = os.path.join(self.logdir, "error_{}.txt".format(
date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
self.error_file = error_file
if self.runner:
stop_tasks = []
stop_tasks.append(self.runner.stop.remote())
stop_tasks.append(self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
stop_tasks.append(
self.runner.__ray_terminate__.remote(
self.runner._ray_actor_id.id()))
# TODO(ekl) seems like wait hangs when killing actors
_, unfinished = ray.wait(
stop_tasks, num_returns=2, timeout=250)
stop_tasks, num_returns=2, timeout=250)
except Exception:
print("Error stopping runner:", traceback.format_exc())
self.status = Trial.ERROR
@@ -252,12 +259,12 @@ class Trial(object):
return '{} pid={}'.format(hostname, pid)
pieces = [
'{} [{}]'.format(
self._status_string(),
location_string(
self.last_result.hostname, self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)),
'{} ts'.format(int(self.last_result.timesteps_total))]
'{} [{}]'.format(self._status_string(),
location_string(self.last_result.hostname,
self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format(
int(self.last_result.timesteps_total))
]
if self.last_result.episode_reward_mean is not None:
pieces.append('{} rew'.format(
@@ -274,10 +281,8 @@ class Trial(object):
return ', '.join(pieces)
def _status_string(self):
return "{}{}".format(
self.status,
", {} failures: {}".format(self.num_failures, self.error_file)
if self.error_file else "")
return "{}{}".format(self.status, ", {} failures: {}".format(
self.num_failures, self.error_file) if self.error_file else "")
def has_checkpoint(self):
return self._checkpoint_path is not None or \
@@ -335,9 +340,8 @@ class Trial(object):
def update_last_result(self, result, terminate=False):
if terminate:
result = result._replace(done=True)
if self.verbose and (
terminate or
time.time() - self.last_debug > DEBUG_PRINT_INTERVAL):
if self.verbose and (terminate or time.time() - self.last_debug >
DEBUG_PRINT_INTERVAL):
print("TrainingResult for {}:".format(self))
print(" {}".format(pretty_print(result).replace("\n", "\n ")))
self.last_debug = time.time()
@@ -358,8 +362,8 @@ class Trial(object):
prefix="{}_{}".format(
str(self)[:MAX_LEN_IDENTIFIER], date_str()),
dir=self.local_dir)
self.result_logger = UnifiedLogger(
self.config, self.logdir, self.upload_dir)
self.result_logger = UnifiedLogger(self.config, self.logdir,
self.upload_dir)
remote_logdir = self.logdir
def logger_creator(config):
@@ -372,7 +376,8 @@ class Trial(object):
# Logging for trials is handled centrally by TrialRunner, so
# configure the remote runner to use a noop-logger.
self.runner = cls.remote(
config=self.config, registry=ray.tune.registry.get_registry(),
config=self.config,
registry=ray.tune.registry.get_registry(),
logger_creator=logger_creator)
def set_verbose(self, verbose):
@@ -387,8 +392,8 @@ class Trial(object):
def __str__(self):
"""Combines ``env`` with ``trainable_name`` and ``experiment_tag``."""
if "env" in self.config:
identifier = "{}_{}".format(
self.trainable_name, self.config["env"])
identifier = "{}_{}".format(self.trainable_name,
self.config["env"])
else:
identifier = self.trainable_name
if self.experiment_tag:
+26 -30
View File
@@ -13,7 +13,6 @@ from ray.tune.web_server import TuneServer
from ray.tune.trial import Trial, Resources
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
MAX_DEBUG_TRIALS = 20
@@ -39,8 +38,11 @@ class TrialRunner(object):
misleading benchmark results.
"""
def __init__(self, scheduler=None, launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT, verbose=True):
def __init__(self,
scheduler=None,
launch_web_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
"""Initializes a new TrialRunner.
Args:
@@ -73,9 +75,8 @@ class TrialRunner(object):
"""Returns whether all trials have finished running."""
if self._total_time > self._global_time_limit:
print(
"Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
print("Exceeded global time limit {} / {}".format(
self._total_time, self._global_time_limit))
return True
for t in self._trials:
@@ -98,12 +99,12 @@ class TrialRunner(object):
for trial in self._trials:
if trial.status == Trial.PENDING:
if not self.has_resources(trial.resources):
raise TuneError((
"Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster only has {} "
"available.").format(
trial.resources.summary_string(),
self._avail_resources.summary_string()))
raise TuneError(
("Insufficient cluster resources to launch trial: "
"trial requested {} but the cluster only has {} "
"available.").format(
trial.resources.summary_string(),
self._avail_resources.summary_string()))
elif trial.status == Trial.PAUSED:
raise TuneError(
"There are paused trials, but no more pending "
@@ -165,24 +166,20 @@ class TrialRunner(object):
for state, trials in sorted(states.items()):
limit = limit_per_state[state]
messages.append("{} trials:".format(state))
for t in sorted(
trials, key=lambda t: t.experiment_tag)[:limit]:
for t in sorted(trials, key=lambda t: t.experiment_tag)[:limit]:
messages.append(" - {}:\t{}".format(t, t.progress_string()))
if len(trials) > limit:
messages.append(" ... {} more not shown".format(
len(trials) - limit))
messages.append(
" ... {} more not shown".format(len(trials) - limit))
return "\n".join(messages) + "\n"
def _debug_messages(self):
messages = ["== Status =="]
messages.append(self._scheduler_alg.debug_string())
if self._resources_initialized:
messages.append(
"Resources used: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu,
self._avail_resources.cpu,
self._committed_resources.gpu,
self._avail_resources.gpu))
messages.append("Resources used: {}/{} CPUs, {}/{} GPUs".format(
self._committed_resources.cpu, self._avail_resources.cpu,
self._committed_resources.gpu, self._avail_resources.gpu))
return messages
def has_resources(self, resources):
@@ -190,9 +187,8 @@ class TrialRunner(object):
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
return (
resources.cpu_total() <= cpu_avail and
resources.gpu_total() <= gpu_avail)
return (resources.cpu_total() <= cpu_avail
and resources.gpu_total() <= gpu_avail)
def _get_next_trial(self):
self._update_avail_resources()
@@ -307,8 +303,9 @@ class TrialRunner(object):
self._scheduler_alg.on_trial_remove(self, trial)
elif trial.status is Trial.RUNNING:
# NOTE: There should only be one...
result_id = [rid for rid, t in self._running.items()
if t is trial][0]
result_id = [
rid for rid, t in self._running.items() if t is trial
][0]
self._running.pop(result_id)
try:
result = ray.get(result_id)
@@ -339,9 +336,8 @@ class TrialRunner(object):
def _update_avail_resources(self):
clients = ray.global_state.client_table()
local_schedulers = [
entry for client in clients.values() for entry in client
if (entry['ClientType'] == 'local_scheduler' and not
entry['Deleted'])
entry for client in clients.values() for entry in client if
(entry['ClientType'] == 'local_scheduler' and not entry['Deleted'])
]
num_cpus = sum(ls['CPU'] for ls in local_schedulers)
num_gpus = sum(ls.get('GPU', 0) for ls in local_schedulers)
+4 -4
View File
@@ -99,12 +99,12 @@ class FIFOScheduler(TrialScheduler):
def choose_trial_to_run(self, trial_runner):
for trial in trial_runner.get_trials():
if (trial.status == Trial.PENDING and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PENDING
and trial_runner.has_resources(trial.resources)):
return trial
for trial in trial_runner.get_trials():
if (trial.status == Trial.PAUSED and
trial_runner.has_resources(trial.resources)):
if (trial.status == Trial.PAUSED
and trial_runner.has_resources(trial.resources)):
return trial
return None
+16 -11
View File
@@ -16,7 +16,6 @@ from ray.tune.trial_scheduler import FIFOScheduler
from ray.tune.web_server import TuneServer
from ray.tune.experiment import Experiment
_SCHEDULERS = {
"FIFO": FIFOScheduler,
"MedianStopping": MedianStoppingRule,
@@ -30,13 +29,15 @@ def _make_scheduler(args):
if args.scheduler in _SCHEDULERS:
return _SCHEDULERS[args.scheduler](**args.scheduler_config)
else:
raise TuneError(
"Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
raise TuneError("Unknown scheduler: {}, should be one of {}".format(
args.scheduler, _SCHEDULERS.keys()))
def run_experiments(experiments, scheduler=None, with_server=False,
server_port=TuneServer.DEFAULT_PORT, verbose=True):
def run_experiments(experiments,
scheduler=None,
with_server=False,
server_port=TuneServer.DEFAULT_PORT,
verbose=True):
"""Tunes experiments.
Args:
@@ -54,17 +55,21 @@ def run_experiments(experiments, scheduler=None, with_server=False,
scheduler = FIFOScheduler()
runner = TrialRunner(
scheduler, launch_web_server=with_server, server_port=server_port,
scheduler,
launch_web_server=with_server,
server_port=server_port,
verbose=verbose)
exp_list = experiments
if isinstance(experiments, Experiment):
exp_list = [experiments]
elif type(experiments) is dict:
exp_list = [Experiment.from_json(name, spec)
for name, spec in experiments.items()]
exp_list = [
Experiment.from_json(name, spec)
for name, spec in experiments.items()
]
if (type(exp_list) is list and
all(isinstance(exp, Experiment) for exp in exp_list)):
if (type(exp_list) is list
and all(isinstance(exp, Experiment) for exp in exp_list)):
for experiment in exp_list:
scheduler.add_experiment(experiment, runner)
else:
+5 -5
View File
@@ -7,7 +7,6 @@ import base64
import ray
from ray.tune.registry import _to_pinnable, _from_pinnable
_pinned_objects = []
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
@@ -15,14 +14,15 @@ PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
def pin_in_object_store(obj):
obj_id = ray.put(_to_pinnable(obj))
_pinned_objects.append(ray.get(obj_id))
return "{}{}".format(
PINNED_OBJECT_PREFIX, base64.b64encode(obj_id.id()).decode("utf-8"))
return "{}{}".format(PINNED_OBJECT_PREFIX,
base64.b64encode(obj_id.id()).decode("utf-8"))
def get_pinned_object(pinned_id):
from ray.local_scheduler import ObjectID
return _from_pinnable(ray.get(ObjectID(
base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
return _from_pinnable(
ray.get(
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
if __name__ == '__main__':
+5 -5
View File
@@ -163,8 +163,8 @@ def _generate_variants(spec):
for path, value in grid_vars:
resolved_vars[path] = _get_value(spec, path)
for k, v in resolved.items():
if (k in resolved_vars and v != resolved_vars[k] and
_is_resolved(resolved_vars[k])):
if (k in resolved_vars and v != resolved_vars[k]
and _is_resolved(resolved_vars[k])):
raise ValueError(
"The variable `{}` could not be unambiguously "
"resolved to a single value. Consider simplifying "
@@ -262,16 +262,16 @@ def _unresolved_values(spec):
for k, v in spec.items():
resolved, v = _try_resolve(v)
if not resolved:
found[(k,)] = v
found[(k, )] = v
elif isinstance(v, dict):
# Recurse into a dict
for (path, value) in _unresolved_values(v).items():
found[(k,) + path] = value
found[(k, ) + path] = value
elif isinstance(v, list):
# Recurse into a list
for i, elem in enumerate(v):
for (path, value) in _unresolved_values({i: elem}).items():
found[(k,) + path] = value
found[(k, ) + path] = value
return found
+7 -4
View File
@@ -61,8 +61,10 @@ def _resolve(directory, result_fname):
def load_results_to_df(directory, result_name="result.json"):
exp_directories = [dirpath for dirpath, dirs, files in os.walk(directory)
for f in files if f == result_name]
exp_directories = [
dirpath for dirpath, dirs, files in os.walk(directory) for f in files
if f == result_name
]
data = [_resolve(d, result_name) for d in exp_directories]
data = [d for d in data if d]
return pd.DataFrame(data)
@@ -76,8 +78,9 @@ def generate_plotly_dim_dict(df, field):
dim_dict["values"] = column
elif is_string_dtype(column):
texts = column.unique()
dim_dict["values"] = [np.argwhere(texts == x).flatten()[0]
for x in column]
dim_dict["values"] = [
np.argwhere(texts == x).flatten()[0] for x in column
]
dim_dict["tickvals"] = list(range(len(texts)))
dim_dict["ticktext"] = texts
else:
+19 -19
View File
@@ -39,28 +39,30 @@ class TuneClient(object):
def get_all_trials(self):
"""Returns a list of all trials (trial_id, config, status)."""
return self._get_response(
{"command": TuneClient.GET_LIST})
return self._get_response({"command": TuneClient.GET_LIST})
def get_trial(self, trial_id):
"""Returns the last result for queried trial."""
return self._get_response(
{"command": TuneClient.GET_TRIAL,
"trial_id": trial_id})
return self._get_response({
"command": TuneClient.GET_TRIAL,
"trial_id": trial_id
})
def add_trial(self, name, trial_spec):
"""Adds a trial of `name` with configurations."""
# TODO(rliaw): have better way of specifying a new trial
return self._get_response(
{"command": TuneClient.ADD,
"name": name,
"spec": trial_spec})
return self._get_response({
"command": TuneClient.ADD,
"name": name,
"spec": trial_spec
})
def stop_trial(self, trial_id):
"""Requests to stop trial."""
return self._get_response(
{"command": TuneClient.STOP,
"trial_id": trial_id})
return self._get_response({
"command": TuneClient.STOP,
"trial_id": trial_id
})
def _get_response(self, data):
payload = json.dumps(data).encode()
@@ -71,7 +73,6 @@ class TuneClient(object):
def RunnerHandler(runner):
class Handler(SimpleHTTPRequestHandler):
def do_GET(self):
content_len = int(self.headers.get('Content-Length'), 0)
raw_body = self.rfile.read(content_len)
@@ -82,8 +83,7 @@ def RunnerHandler(runner):
else:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(
response).encode())
self.wfile.write(json.dumps(response).encode())
def trial_info(self, trial):
if trial.last_result:
@@ -112,8 +112,9 @@ def RunnerHandler(runner):
response = {}
try:
if command == TuneClient.GET_LIST:
response["trials"] = [self.trial_info(t)
for t in runner.get_trials()]
response["trials"] = [
self.trial_info(t) for t in runner.get_trials()
]
elif command == TuneClient.GET_TRIAL:
trial = get_trial()
response["trial_info"] = self.trial_info(trial)
@@ -147,8 +148,7 @@ class TuneServer(threading.Thread):
self._port = port if port else self.DEFAULT_PORT
address = ('localhost', self._port)
print("Starting Tune Server...")
self._server = HTTPServer(
address, RunnerHandler(runner))
self._server = HTTPServer(address, RunnerHandler(runner))
self.start()
def run(self):
+11 -8
View File
@@ -46,7 +46,10 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def push_error_to_driver(redis_client, error_type, message, driver_id=None,
def push_error_to_driver(redis_client,
error_type,
message,
driver_id=None,
data=None):
"""Push an error message to the driver to be printed in the background.
@@ -64,9 +67,11 @@ def push_error_to_driver(redis_client, error_type, message, driver_id=None,
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = ERROR_KEY_PREFIX + driver_id + b":" + _random_string()
data = {} if data is None else data
redis_client.hmset(error_key, {"type": error_type,
"message": message,
"data": data})
redis_client.hmset(error_key, {
"type": error_type,
"message": message,
"data": data
})
redis_client.rpush("ErrorKeys", error_key)
@@ -134,10 +139,8 @@ def hex_to_binary(hex_identifier):
return binascii.unhexlify(hex_identifier)
FunctionProperties = collections.namedtuple("FunctionProperties",
["num_return_vals",
"resources",
"max_calls"])
FunctionProperties = collections.namedtuple(
"FunctionProperties", ["num_return_vals", "resources", "max_calls"])
"""FunctionProperties: A named tuple storing remote functions information."""
+398 -338
View File
File diff suppressed because it is too large Load Diff
+42 -25
View File
@@ -8,34 +8,51 @@ import traceback
import ray
import ray.actor
parser = argparse.ArgumentParser(description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument("--node-ip-address", required=True, type=str,
help="the ip address of the worker's node")
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
parser.add_argument("--object-store-name", required=True, type=str,
help="the object store's name")
parser.add_argument("--object-store-manager-name", required=False, type=str,
help="the object store manager's name")
parser.add_argument("--local-scheduler-name", required=False, type=str,
help="the local scheduler's name")
parser.add_argument("--raylet-name", required=False, type=str,
help="the raylet's name")
parser = argparse.ArgumentParser(
description=("Parse addresses for the worker "
"to connect to."))
parser.add_argument(
"--node-ip-address",
required=True,
type=str,
help="the ip address of the worker's node")
parser.add_argument(
"--redis-address",
required=True,
type=str,
help="the address to use for Redis")
parser.add_argument(
"--object-store-name",
required=True,
type=str,
help="the object store's name")
parser.add_argument(
"--object-store-manager-name",
required=False,
type=str,
help="the object store manager's name")
parser.add_argument(
"--local-scheduler-name",
required=False,
type=str,
help="the local scheduler's name")
parser.add_argument(
"--raylet-name", required=False, type=str, help="the raylet's name")
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name,
"raylet_socket_name": args.raylet_name}
info = {
"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name,
"raylet_socket_name": args.raylet_name
}
ray.worker.connect(info, mode=ray.WORKER_MODE,
use_raylet=(args.raylet_name is not None))
ray.worker.connect(
info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None))
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
@@ -54,8 +71,8 @@ if __name__ == "__main__":
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = ray.services.create_redis_client(args.redis_address)
ray.utils.push_error_to_driver(redis_client, "worker_crash",
traceback_str, driver_id=None)
ray.utils.push_error_to_driver(
redis_client, "worker_crash", traceback_str, driver_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to
+42 -41
View File
@@ -18,13 +18,11 @@ import setuptools.command.build_ext as _build_ext
ray_files = [
"ray/core/src/common/thirdparty/redis/src/redis-server",
"ray/core/src/common/redis_module/libray_redis_module.so",
"ray/core/src/plasma/plasma_store",
"ray/core/src/plasma/plasma_manager",
"ray/core/src/plasma/plasma_store", "ray/core/src/plasma/plasma_manager",
"ray/core/src/local_scheduler/local_scheduler",
"ray/core/src/local_scheduler/liblocal_scheduler_library.so",
"ray/core/src/global_scheduler/global_scheduler",
"ray/core/src/ray/raylet/raylet_monitor",
"ray/core/src/ray/raylet/raylet",
"ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet",
"ray/WebUI.ipynb"
]
@@ -35,14 +33,14 @@ ray_ui_files = [
"ray/core/src/catapult_files/trace_viewer_full.html"
]
ray_autoscaler_files = [
"ray/autoscaler/aws/example-full.yaml"
]
ray_autoscaler_files = ["ray/autoscaler/aws/example-full.yaml"]
if "RAY_USE_NEW_GCS" in os.environ and os.environ["RAY_USE_NEW_GCS"] == "on":
ray_files += ["ray/core/src/credis/build/src/libmember.so",
"ray/core/src/credis/build/src/libmaster.so",
"ray/core/src/credis/redis/src/redis-server"]
ray_files += [
"ray/core/src/credis/build/src/libmember.so",
"ray/core/src/credis/build/src/libmaster.so",
"ray/core/src/credis/redis/src/redis-server"
]
# The UI files are mandatory if the INCLUDE_UI environment variable equals 1.
# Otherwise, they are optional.
@@ -54,9 +52,8 @@ else:
optional_ray_files += ray_autoscaler_files
extras = {
"rllib": [
"tensorflow", "pyyaml", "gym[atari]", "opencv-python",
"lz4", "scipy"]
"rllib":
["tensorflow", "pyyaml", "gym[atari]", "opencv-python", "lz4", "scipy"]
}
@@ -73,8 +70,9 @@ class build_ext(_build_ext.build_ext):
pyarrow_files = [
os.path.join("ray/pyarrow_files/pyarrow", filename)
for filename in os.listdir("./ray/pyarrow_files/pyarrow")
if not os.path.isdir(os.path.join("ray/pyarrow_files/pyarrow",
filename))]
if not os.path.isdir(
os.path.join("ray/pyarrow_files/pyarrow", filename))
]
files_to_include = ray_files + pyarrow_files
@@ -84,8 +82,8 @@ class build_ext(_build_ext.build_ext):
generated_python_directory = "ray/core/generated"
for filename in os.listdir(generated_python_directory):
if filename[-3:] == ".py":
self.move_file(os.path.join(generated_python_directory,
filename))
self.move_file(
os.path.join(generated_python_directory, filename))
# Try to copy over the optional files.
for filename in optional_ray_files:
@@ -114,27 +112,30 @@ class BinaryDistribution(Distribution):
return True
setup(name="ray",
# The version string is also in __init__.py. TODO(pcm): Fix this.
version="0.4.0",
packages=find_packages(),
cmdclass={"build_ext": build_ext},
# The BinaryDistribution argument triggers build_ext.
distclass=BinaryDistribution,
install_requires=["numpy",
"funcsigs",
"click",
"colorama",
"psutil",
"pytest",
"pyyaml",
"redis",
# The six module is required by pyarrow.
"six >= 1.0.0",
"flatbuffers"],
setup_requires=["cython >= 0.23"],
extras_require=extras,
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
include_package_data=True,
zip_safe=False,
license="Apache 2.0")
setup(
name="ray",
# The version string is also in __init__.py. TODO(pcm): Fix this.
version="0.4.0",
packages=find_packages(),
cmdclass={"build_ext": build_ext},
# The BinaryDistribution argument triggers build_ext.
distclass=BinaryDistribution,
install_requires=[
"numpy",
"funcsigs",
"click",
"colorama",
"psutil",
"pytest",
"pyyaml",
"redis",
# The six module is required by pyarrow.
"six >= 1.0.0",
"flatbuffers"
],
setup_requires=["cython >= 0.23"],
extras_require=extras,
entry_points={"console_scripts": ["ray=ray.scripts.scripts:main"]},
include_package_data=True,
zip_safe=False,
license="Apache 2.0")