[xray] Implement timeline and profiling API. (#2306)

* Add profile table and store profiling information there.

* Code for dumping timeline.

* Improve color scheme.

* Push timeline events on driver only for raylet.

* Improvements to profiling and timeline visualization

* Some linting

* Small fix.

* Linting

* Propagate node IP address through profiling events.

* Fix test.

* object_id.hex() should return byte string in python 2.

* Include gcs.fbs in node_manager.fbs.

* Remove flatbuffer definition duplication.

* Decode to unicode in Python 3 and bytes in Python 2.

* Minor

* Submit profile events in a batch. Revert some CMake changes.

* Fix

* Workaround test failure.

* Fix linting

* Linting

* Don't return anything from chrome_tracing_dump when filename is provided.

* Remove some redundancy from profile table.

* Linting

* Move TODOs out of docstring.

* Minor
This commit is contained in:
Robert Nishihara
2018-07-04 23:23:48 -07:00
committed by Philipp Moritz
parent 8e687cbc98
commit b90e551b41
27 changed files with 777 additions and 147 deletions
+2 -2
View File
@@ -48,7 +48,7 @@ except ImportError as e:
from ray.local_scheduler import ObjectID, _config # noqa: E402
from ray.worker import (error_info, init, connect, disconnect, get, put, wait,
remote, log_event, log_span, flush_log, get_gpu_ids,
remote, profile, flush_profile_data, get_gpu_ids,
get_resource_ids, get_webui_url,
register_custom_serializer) # noqa: E402
from ray.worker import (SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
@@ -65,7 +65,7 @@ __version__ = "0.4.0"
__all__ = [
"error_info", "init", "connect", "disconnect", "get", "put", "wait",
"remote", "log_event", "log_span", "flush_log", "actor", "method",
"remote", "profile", "flush_profile_data", "actor", "method",
"get_gpu_ids", "get_resource_ids", "get_webui_url",
"register_custom_serializer", "SCRIPT_MODE", "WORKER_MODE", "PYTHON_MODE",
"SILENT_MODE", "global_state", "ObjectID", "_config", "__version__"
+4 -3
View File
@@ -14,6 +14,7 @@ import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
from ray.utils import (
decode,
_random_string,
check_oversized_pickle,
is_cython,
@@ -292,10 +293,10 @@ def fetch_and_register_actor(actor_class_key, worker):
"checkpoint_interval", "actor_method_names"
])
class_name = class_name.decode("ascii")
module = module.decode("ascii")
class_name = decode(class_name)
module = decode(module)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(actor_method_names.decode("ascii"))
actor_method_names = json.loads(decode(actor_method_names))
# Create a temporary actor with some temporary methods so that if the actor
# fails to be unpickled, the temporary actor can be used (just to produce
+202 -43
View File
@@ -357,8 +357,6 @@ class GlobalState(object):
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(i), 0)
task_table_message = ray.gcs_utils.Task.GetRootAsTask(
gcs_entries.Entries(0), 0)
execution_spec = task_table_message.TaskExecutionSpec()
task_spec = task_table_message.TaskSpecification()
task_spec = ray.local_scheduler.task_from_string(task_spec)
@@ -487,11 +485,10 @@ class GlobalState(object):
decode(value))
elif client_info[b"client_type"] == b"local_scheduler":
# The remaining fields are resource types.
client_info_parsed[field.decode("ascii")] = float(
client_info_parsed[decode(field)] = float(
decode(value))
else:
client_info_parsed[field.decode("ascii")] = decode(
value)
client_info_parsed[decode(field)] = decode(value)
node_info[node_ip_address].append(client_info_parsed)
@@ -513,21 +510,19 @@ class GlobalState(object):
gcs_entry.Entries(i), 0))
resources = {
client.ResourcesTotalLabel(i).decode("ascii"):
decode(client.ResourcesTotalLabel(i)):
client.ResourcesTotalCapacity(i)
for i in range(client.ResourcesTotalLabelLength())
}
node_info.append({
"ClientID": ray.utils.binary_to_hex(client.ClientId()),
"IsInsertion": client.IsInsertion(),
"NodeManagerAddress": client.NodeManagerAddress().decode(
"ascii"),
"NodeManagerAddress": decode(client.NodeManagerAddress()),
"NodeManagerPort": client.NodeManagerPort(),
"ObjectManagerPort": client.ObjectManagerPort(),
"ObjectStoreSocketName": client.ObjectStoreSocketName()
.decode("ascii"),
"RayletSocketName": client.RayletSocketName().decode(
"ascii"),
"ObjectStoreSocketName": decode(
client.ObjectStoreSocketName()),
"RayletSocketName": decode(client.RayletSocketName()),
"Resources": resources
})
return node_info
@@ -543,14 +538,14 @@ class GlobalState(object):
ip_filename_file = {}
for filename in relevant_files:
filename = filename.decode("ascii")
filename = decode(filename)
filename_components = filename.split(":")
ip_addr = filename_components[1]
file = self.redis_client.lrange(filename, 0, -1)
file_str = []
for x in file:
y = x.decode("ascii")
y = decode(x)
file_str.append(y)
if ip_addr not in ip_filename_file:
@@ -630,7 +625,7 @@ class GlobalState(object):
event_log_set, **params)
for (event, score) in event_list:
event_dict = json.loads(event.decode())
event_dict = json.loads(decode(event))
task_id = ""
for event in event_dict:
if "task_id" in event[3]:
@@ -643,31 +638,29 @@ class GlobalState(object):
heap_size += 1
for event in event_dict:
if event[1] == "ray:get_task" and event[2] == 1:
if event[1] == "get_task" and event[2] == 1:
task_info[task_id]["get_task_start"] = event[0]
if event[1] == "ray:get_task" and event[2] == 2:
if event[1] == "get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if (event[1] == "ray:import_remote_function"
if (event[1] == "register_remote_function"
and event[2] == 1):
task_info[task_id]["import_remote_start"] = event[0]
if (event[1] == "ray:import_remote_function"
if (event[1] == "register_remote_function"
and event[2] == 2):
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 2:
task_info[task_id]["acquire_lock_end"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 1:
if (event[1] == "task:deserialize_arguments"
and event[2] == 1):
task_info[task_id]["get_arguments_start"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 2:
if (event[1] == "task:deserialize_arguments"
and event[2] == 2):
task_info[task_id]["get_arguments_end"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 1:
if event[1] == "task:execute" and event[2] == 1:
task_info[task_id]["execute_start"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 2:
if event[1] == "task:execute" and event[2] == 2:
task_info[task_id]["execute_end"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 1:
if event[1] == "task:store_outputs" and event[2] == 1:
task_info[task_id]["store_outputs_start"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 2:
if event[1] == "task:store_outputs" and event[2] == 2:
task_info[task_id]["store_outputs_end"] = event[0]
if "worker_id" in event[3]:
task_info[task_id]["worker_id"] = event[3]["worker_id"]
@@ -685,6 +678,173 @@ class GlobalState(object):
return task_info
def _profile_table(self, component_id):
"""Get the profile events for a given component.
Args:
component_id: An identifier for a component.
Returns:
A list of the profile events for the specified process.
"""
# TODO(rkn): This method should support limiting the number of log
# events and should also support returning a window of events.
message = self._execute_command(component_id, "RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.PROFILE, "",
component_id.id())
if message is None:
return []
gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
message, 0)
profile_events = []
for i in range(gcs_entries.EntriesLength()):
profile_table_message = (
ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData(
gcs_entries.Entries(i), 0))
component_type = decode(profile_table_message.ComponentType())
component_id = binary_to_hex(profile_table_message.ComponentId())
node_ip_address = decode(profile_table_message.NodeIpAddress())
for j in range(profile_table_message.ProfileEventsLength()):
profile_event_message = profile_table_message.ProfileEvents(j)
profile_event = {
"event_type": decode(profile_event_message.EventType()),
"component_id": component_id,
"node_ip_address": node_ip_address,
"component_type": component_type,
"start_time": profile_event_message.StartTime(),
"end_time": profile_event_message.EndTime(),
"extra_data": json.loads(
decode(profile_event_message.ExtraData())),
}
profile_events.append(profile_event)
return profile_events
def profile_table(self):
if not self.use_raylet:
raise Exception("This method is only supported in the raylet "
"code path.")
profile_table_keys = self._keys(
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
component_identifiers_binary = [
key[len(ray.gcs_utils.TablePrefix_PROFILE_string):]
for key in profile_table_keys
]
return {
binary_to_hex(component_id): self._profile_table(
binary_to_object_id(component_id))
for component_id in component_identifiers_binary
}
def chrome_tracing_dump(self,
include_task_data=False,
filename=None,
open_browser=False):
"""Return a list of profiling events that can viewed as a timeline.
To view this information as a timeline, simply dump it as a json file
using json.dumps, and then load go to chrome://tracing in the Chrome
web browser and load the dumped file. Make sure to enable "Flow events"
in the "View Options" menu.
Args:
include_task_data: If true, we will include more task metadata such
as the task specifications in the json.
filename: If a filename is provided, the timeline is dumped to that
file.
open_browser: If true, we will attempt to automatically open the
timeline visualization in Chrome.
Returns:
If filename is not provided, this returns a list of profiling
events. Each profile event is a dictionary.
"""
# TODO(rkn): Support including the task specification data in the
# timeline.
# TODO(rkn): This should support viewing just a window of time or a
# limited number of events.
if include_task_data:
raise NotImplementedError("This flag has not been implented yet.")
if open_browser:
raise NotImplementedError("This flag has not been implented yet.")
profile_table = self.profile_table()
all_events = []
# Colors are specified at
# https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. # noqa: E501
default_color_mapping = defaultdict(
lambda: "generic_work", {
"get_task": "cq_build_abandoned",
"task": "rail_response",
"task:deserialize_arguments": "rail_load",
"task:execute": "rail_animation",
"task:store_outputs": "rail_idle",
"wait_for_function": "detailed_memory_dump",
"ray.get": "good",
"ray.put": "terrible",
"ray.wait": "vsync_highlight_color",
"submit_task": "background_memory_dump",
"fetch_and_run_function": "detailed_memory_dump",
"register_remote_function": "detailed_memory_dump",
})
def seconds_to_microseconds(time_in_seconds):
time_in_microseconds = 10**6 * time_in_seconds
return time_in_microseconds
for component_id_hex, component_events in profile_table.items():
for event in component_events:
new_event = {
# The category of the event.
"cat": event["event_type"],
# The string displayed on the event.
"name": event["event_type"],
# The identifier for the group of rows that the event
# appears in.
"pid": event["node_ip_address"],
# The identifier for the row that the event appears in.
"tid": event["component_type"] + ":" +
event["component_id"],
# The start time in microseconds.
"ts": seconds_to_microseconds(event["start_time"]),
# The duration in microseconds.
"dur": seconds_to_microseconds(event["end_time"] -
event["start_time"]),
# What is this?
"ph": "X",
# This is the name of the color to display the box in.
"cname": default_color_mapping[event["event_type"]],
# The extra user-defined data.
"args": event["extra_data"],
}
# Modify the json with the additional user-defined extra data.
# This can be used to add fields or override existing fields.
if "cname" in event["extra_data"]:
new_event["cname"] = event["extra_data"]["cname"]
if "name" in event["extra_data"]:
new_event["name"] = event["extra_data"]["name"]
all_events.append(new_event)
if filename is not None:
with open(filename, "w") as outfile:
json.dump(all_events, outfile)
else:
return all_events
def dump_catapult_trace(self,
path,
task_info,
@@ -1047,21 +1207,20 @@ class GlobalState(object):
worker_id = binary_to_hex(worker_key[len("Workers:"):])
workers_data[worker_id] = {
"local_scheduler_socket": (
worker_info[b"local_scheduler_socket"].decode("ascii")),
"node_ip_address": (worker_info[b"node_ip_address"]
.decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
.decode("ascii")),
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
.decode("ascii"))
"local_scheduler_socket": (decode(
worker_info[b"local_scheduler_socket"])),
"node_ip_address": decode(worker_info[b"node_ip_address"]),
"plasma_manager_socket": decode(
worker_info[b"plasma_manager_socket"]),
"plasma_store_socket": decode(
worker_info[b"plasma_store_socket"])
}
if b"stderr_file" in worker_info:
workers_data[worker_id]["stderr_file"] = (
worker_info[b"stderr_file"].decode("ascii"))
workers_data[worker_id]["stderr_file"] = decode(
worker_info[b"stderr_file"])
if b"stdout_file" in worker_info:
workers_data[worker_id]["stdout_file"] = (
worker_info[b"stdout_file"].decode("ascii"))
workers_data[worker_id]["stdout_file"] = decode(
worker_info[b"stdout_file"])
return workers_data
def actors(self):
@@ -1155,8 +1314,8 @@ class GlobalState(object):
error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData(
gcs_entries.Entries(i), 0)
error_message = {
"type": error_data.Type().decode("ascii"),
"message": error_data.ErrorMessage().decode("ascii"),
"type": decode(error_data.Type()),
"message": decode(error_data.ErrorMessage()),
"timestamp": error_data.Timestamp(),
}
error_messages.append(error_message)
+5 -3
View File
@@ -22,6 +22,7 @@ import ray.core.generated.ErrorTableData
from ray.core.generated.GcsTableEntry import GcsTableEntry
from ray.core.generated.ClientTableData import ClientTableData
from ray.core.generated.ErrorTableData import ErrorTableData
from ray.core.generated.ProfileTableData import ProfileTableData
from ray.core.generated.HeartbeatTableData import HeartbeatTableData
from ray.core.generated.ObjectTableData import ObjectTableData
from ray.core.generated.ray.protocol.Task import Task
@@ -33,9 +34,9 @@ __all__ = [
"SubscribeToNotificationsReply", "ResultTableReply",
"TaskExecutionDependencies", "TaskReply", "DriverTableMessage",
"LocalSchedulerInfoMessage", "SubscribeToDBClientTableReply", "TaskInfo",
"GcsTableEntry", "ClientTableData", "ErrorTableData", "HeartbeatTableData",
"ObjectTableData", "Task", "TablePrefix", "TablePubsub",
"construct_error_message"
"GcsTableEntry", "ClientTableData", "ErrorTableData", "ProfileTableData",
"HeartbeatTableData", "ObjectTableData", "Task", "TablePrefix",
"TablePubsub", "construct_error_message"
]
# These prefixes must be kept up-to-date with the definitions in
@@ -53,6 +54,7 @@ FUNCTION_PREFIX = "RemoteFunction:"
TablePrefix_RAYLET_TASK_string = "RAYLET_TASK"
TablePrefix_OBJECT_string = "OBJECT"
TablePrefix_ERROR_INFO_string = "ERROR_INFO"
TablePrefix_PROFILE_string = "PROFILE"
def construct_error_message(error_type, message, timestamp):
+2 -1
View File
@@ -10,6 +10,7 @@ import time
from ray.services import get_ip_address
from ray.services import get_port
from ray.services import logger
import ray.utils
class LogMonitor(object):
@@ -70,7 +71,7 @@ class LogMonitor(object):
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(
self.node_ip_address, log_filename.decode("ascii"))
self.node_ip_address, ray.utils.decode(log_filename))
self.redis_client.rpush(redis_key, *new_lines)
# Pass if we already failed to open the log file.
+2 -1
View File
@@ -10,6 +10,7 @@ import subprocess
import ray.services as services
from ray.autoscaler.commands import (create_or_update_cluster,
teardown_cluster, get_head_node_ip)
import ray.utils
def check_no_existing_redis_clients(node_ip_address, redis_client):
@@ -31,7 +32,7 @@ def check_no_existing_redis_clients(node_ip_address, redis_client):
if deleted:
continue
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
if ray.utils.decode(info[b"node_ip_address"]) == node_ip_address:
raise Exception("This Redis instance is already connected to "
"clients with this IP address.")
+3 -3
View File
@@ -386,7 +386,7 @@ def check_version_info(redis_client):
if redis_reply is None:
return
true_version_info = tuple(json.loads(redis_reply.decode("ascii")))
true_version_info = tuple(json.loads(ray.utils.decode(redis_reply)))
version_info = _compute_version_info()
if version_info != true_version_info:
node_ip_address = ray.services.get_node_ip_address()
@@ -776,7 +776,7 @@ def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
new_env["REDIS_ADDRESS"] = redis_address
# We generate the token used for authentication ourselves to avoid
# querying the jupyter server.
token = binascii.hexlify(os.urandom(24)).decode("ascii")
token = ray.utils.decode(binascii.hexlify(os.urandom(24)))
command = [
"jupyter", "notebook", "--no-browser", "--port={}".format(port),
"--NotebookApp.iopub_data_rate_limit=10000000000",
@@ -1373,7 +1373,7 @@ def start_ray_processes(address_info=None,
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [shard.decode("ascii") for shard in redis_shards]
redis_shards = [ray.utils.decode(shard) for shard in redis_shards]
address_info["redis_shards"] = redis_shards
# Start the log monitor, if necessary.
+2
View File
@@ -170,6 +170,8 @@ def random_string():
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if not isinstance(byte_str, bytes):
raise ValueError("The argument must be a bytes object.")
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
+207 -64
View File
@@ -562,7 +562,7 @@ class Worker(object):
Returns:
The return object IDs for this task.
"""
with log_span("ray:submit_task", worker=self):
with profile("submit_task", worker=self):
check_main_thread()
if actor_id is None:
assert actor_handle_id is None
@@ -867,7 +867,7 @@ class Worker(object):
# Get task arguments from the object store.
try:
with log_span("ray:task:get_arguments", worker=self):
with profile("task:deserialize_arguments", worker=self):
arguments = self._get_arguments_for_execution(
function_name, args)
except (RayGetError, RayGetArgumentError) as e:
@@ -882,7 +882,7 @@ class Worker(object):
# Execute the task.
try:
with log_span("ray:task:execute", worker=self):
with profile("task:execute", worker=self):
if task.actor_id().id() == NIL_ACTOR_ID:
outputs = function_executor(*arguments)
else:
@@ -901,7 +901,7 @@ class Worker(object):
# Store the outputs in the local object store.
try:
with log_span("ray:task:store_outputs", worker=self):
with profile("task:store_outputs", worker=self):
# If this is an actor task, then the last object ID returned by
# the task is a dummy output, not returned by the function
# itself. Decrement to get the correct number of return values.
@@ -976,7 +976,7 @@ class Worker(object):
# Wait until the function to be executed has actually been registered
# on this worker. We will push warnings to the user if we spend too
# long in this loop.
with log_span("ray:wait_for_function", worker=self):
with profile("wait_for_function", worker=self):
self._wait_for_function(function_id, driver_id)
# Execute the task.
@@ -984,22 +984,26 @@ class Worker(object):
# warning to the user if we are waiting too long to acquire the lock
# because that may indicate that the system is hanging, and it'd be
# good to know where the system is hanging.
log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self)
with self.lock:
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=self)
function_name = (self.function_execution_info[driver_id][
function_id.id()]).function_name
contents = {
"function_name": function_name,
"task_id": task.task_id().hex(),
"worker_id": binary_to_hex(self.worker_id)
}
with log_span("ray:task", contents=contents, worker=self):
if not self.use_raylet:
extra_data = {
"function_name": function_name,
"task_id": task.task_id().hex(),
"worker_id": binary_to_hex(self.worker_id)
}
else:
extra_data = {
"name": function_name,
"task_id": task.task_id().hex()
}
with profile("task", extra_data=extra_data, worker=self):
self._process_task(task)
# Push all of the log events to the global state store.
flush_log()
flush_profile_data()
# Increase the task execution counter.
self.num_task_executions[driver_id][function_id.id()] += 1
@@ -1017,7 +1021,7 @@ class Worker(object):
Returns:
A task from the local scheduler.
"""
with log_span("ray:get_task", worker=self):
with profile("get_task", worker=self):
task = self.local_scheduler_client.get_task()
# Automatically restrict the GPUs available to this task.
@@ -1103,7 +1107,7 @@ def _webui_url_helper(client):
The URL of the web UI as a string.
"""
result = client.hmget("webui", "url")[0]
return result.decode("ascii") if result is not None else result
return ray.utils.decode(result) if result is not None else result
def get_webui_url():
@@ -1194,9 +1198,9 @@ def error_info(worker=global_worker):
if error_applies_to_driver(error_key, worker=worker):
error_contents = worker.redis_client.hgetall(error_key)
error_contents = {
"type": error_contents[b"type"].decode("ascii"),
"message": error_contents[b"message"].decode("ascii"),
"data": error_contents[b"data"].decode("ascii")
"type": ray.utils.decode(error_contents[b"type"]),
"message": ray.utils.decode(error_contents[b"message"]),
"data": ray.utils.decode(error_contents[b"data"])
}
errors.append(error_contents)
@@ -1296,13 +1300,14 @@ def get_address_info_from_redis_helper(redis_address,
assert b"ray_client_id" in info
assert b"node_ip_address" in info
assert b"client_type" in info
client_node_ip_address = info[b"node_ip_address"].decode("ascii")
client_node_ip_address = ray.utils.decode(info[b"node_ip_address"])
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
if info[b"client_type"].decode("ascii") == "plasma_manager":
if ray.utils.decode(info[b"client_type"]) == "plasma_manager":
plasma_managers.append(info)
elif info[b"client_type"].decode("ascii") == "local_scheduler":
elif (ray.utils.decode(
info[b"client_type"]) == "local_scheduler"):
local_schedulers.append(info)
# Make sure that we got at least one plasma manager and local
# scheduler.
@@ -1311,16 +1316,16 @@ def get_address_info_from_redis_helper(redis_address,
# Build the address information.
object_store_addresses = []
for manager in plasma_managers:
address = manager[b"manager_address"].decode("ascii")
address = ray.utils.decode(manager[b"manager_address"])
port = services.get_port(address)
object_store_addresses.append(
services.ObjectStoreAddress(
name=manager[b"store_socket_name"].decode("ascii"),
manager_name=manager[b"manager_socket_name"].decode(
"ascii"),
name=ray.utils.decode(manager[b"store_socket_name"]),
manager_name=ray.utils.decode(
manager[b"manager_socket_name"]),
manager_port=port))
scheduler_names = [
scheduler[b"local_scheduler_socket_name"].decode("ascii")
ray.utils.decode(scheduler[b"local_scheduler_socket_name"])
for scheduler in local_schedulers
]
client_info = {
@@ -1343,8 +1348,8 @@ def get_address_info_from_redis_helper(redis_address,
for client_message in clients:
client = ray.gcs_utils.ClientTableData.GetRootAsClientTableData(
client_message, 0)
client_node_ip_address = client.NodeManagerAddress().decode(
"ascii")
client_node_ip_address = ray.utils.decode(
client.NodeManagerAddress())
if (client_node_ip_address == node_ip_address or
(client_node_ip_address == "127.0.0.1"
and redis_ip_address == ray.services.get_node_ip_address())):
@@ -1352,12 +1357,12 @@ def get_address_info_from_redis_helper(redis_address,
object_store_addresses = [
services.ObjectStoreAddress(
name=raylet.ObjectStoreSocketName().decode("ascii"),
name=ray.utils.decode(raylet.ObjectStoreSocketName()),
manager_name=None,
manager_port=None) for raylet in raylets
]
raylet_socket_names = [
raylet.RayletSocketName().decode("ascii") for raylet in raylets
ray.utils.decode(raylet.RayletSocketName()) for raylet in raylets
]
return {
"node_ip_address": node_ip_address,
@@ -1807,6 +1812,21 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
def _flush_profile_events(worker):
"""Drivers run this as a thread to flush profile data in the background."""
# Note(rkn): This is run on a background thread in the driver. It uses the
# local scheduler client. This should be ok because it doesn't read from
# the local scheduler client and we have the GIL here. However, if either
# of those things changes, then we could run into issues.
try:
while True:
time.sleep(1)
flush_profile_data(worker=worker)
except AttributeError:
# This is to suppress errors that occur at shutdown.
pass
def print_error_messages_raylet(worker):
"""Print error messages in the background on the driver.
@@ -1858,7 +1878,7 @@ def print_error_messages_raylet(worker):
if job_id not in [worker.task_driver_id.id(), NIL_JOB_ID]:
continue
error_message = error_data.ErrorMessage().decode("ascii")
error_message = ray.utils.decode(error_data.ErrorMessage())
if error_message not in old_error_messages:
logger.error(error_message)
@@ -1900,8 +1920,8 @@ def print_error_messages(worker):
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
for error_key in error_keys:
if error_applies_to_driver(error_key, worker=worker):
error_message = worker.redis_client.hget(
error_key, "message").decode("ascii")
error_message = ray.utils.decode(
worker.redis_client.hget(error_key, "message"))
if error_message not in old_error_messages:
logger.error(error_message)
old_error_messages.add(error_message)
@@ -1915,8 +1935,8 @@ def print_error_messages(worker):
for error_key in worker.redis_client.lrange(
"ErrorKeys", num_errors_received, -1):
if error_applies_to_driver(error_key, worker=worker):
error_message = worker.redis_client.hget(
error_key, "message").decode("ascii")
error_message = ray.utils.decode(
worker.redis_client.hget(error_key, "message"))
if error_message not in old_error_messages:
logger.error(error_message)
old_error_messages.add(error_message)
@@ -1939,9 +1959,9 @@ def fetch_and_register_remote_function(key, worker=global_worker):
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
function_name = function_name.decode("ascii")
function_name = ray.utils.decode(function_name)
max_calls = int(max_calls)
module = module.decode("ascii")
module = ray.utils.decode(module)
# This is a placeholder in case the function can't be unpickled. This will
# be overwritten if the function is successfully registered.
@@ -2031,15 +2051,18 @@ def import_thread(worker, mode):
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
fetch_and_execute_function_to_run(key, worker=worker)
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
# Continue because FunctionsToRun are the only things that the
# driver should import.
continue
if key.startswith(b"RemoteFunction"):
fetch_and_register_remote_function(key, worker=worker)
with profile("register_remote_function", worker=worker):
fetch_and_register_remote_function(key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
fetch_and_execute_function_to_run(key, worker=worker)
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(key, worker=worker)
elif key.startswith(b"ActorClass"):
# Keep track of the fact that this actor class has been
# exported so that we know it is safe to turn this worker into
@@ -2063,9 +2086,8 @@ def import_thread(worker, mode):
# Handle the driver case first.
if mode != WORKER_MODE:
if key.startswith(b"FunctionsToRun"):
with log_span(
"ray:import_function_to_run",
worker=worker):
with profile(
"fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
# Continue because FunctionsToRun are the only things
@@ -2073,13 +2095,12 @@ def import_thread(worker, mode):
continue
if key.startswith(b"RemoteFunction"):
with log_span(
"ray:import_remote_function", worker=worker):
with profile(
"register_remote_function", worker=worker):
fetch_and_register_remote_function(
key, worker=worker)
elif key.startswith(b"FunctionsToRun"):
with log_span(
"ray:import_function_to_run", worker=worker):
with profile("fetch_and_run_function", worker=worker):
fetch_and_execute_function_to_run(
key, worker=worker)
elif key.startswith(b"ActorClass"):
@@ -2333,6 +2354,13 @@ def connect(info,
t.daemon = True
t.start()
if mode in [SCRIPT_MODE, SILENT_MODE] and worker.use_raylet:
t = threading.Thread(target=_flush_profile_events, args=(worker, ))
# Making the thread a daemon causes it to exit when the main thread
# exits.
t.daemon = True
t.start()
if mode in [SCRIPT_MODE, SILENT_MODE]:
# Add the directory containing the script that is running to the Python
# paths of the workers. Also add the current directory. Note that this
@@ -2526,7 +2554,8 @@ class RayLogSpan(object):
def __enter__(self):
"""Log the beginning of a span event."""
log(event_type=self.event_type,
_log(
event_type=self.event_type,
contents=self.contents,
kind=LOG_SPAN_START,
worker=self.worker)
@@ -2534,11 +2563,13 @@ class RayLogSpan(object):
def __exit__(self, type, value, tb):
"""Log the end of a span event. Log any exception that occurred."""
if type is None:
log(event_type=self.event_type,
_log(
event_type=self.event_type,
kind=LOG_SPAN_END,
worker=self.worker)
else:
log(event_type=self.event_type,
_log(
event_type=self.event_type,
contents={
"type": str(type),
"value": value,
@@ -2548,19 +2579,109 @@ class RayLogSpan(object):
worker=self.worker)
def log_span(event_type, contents=None, worker=global_worker):
return RayLogSpan(event_type, contents=contents, worker=worker)
class RayLogSpanRaylet(object):
"""An object used to enable logging a span of events with a with statement.
Attributes:
event_type (str): The type of the event being logged.
contents: Additional information to log.
"""
def __init__(self, event_type, extra_data=None, worker=global_worker):
"""Initialize a RayLogSpan object."""
self.event_type = event_type
self.extra_data = extra_data if extra_data is not None else {}
self.worker = worker
def set_attribute(self, key, value):
"""Add a key-value pair to the extra_data dict.
This can be used to add attributes that are not available when
ray.profile was called.
Args:
key: The attribute name.
value: The attribute value.
"""
if not isinstance(key, str) or not isinstance(value, str):
raise ValueError("The extra_data argument must be a "
"dictionary mapping strings to strings.")
self.extra_data[key] = value
def __enter__(self):
"""Log the beginning of a span event.
Returns:
The object itself is returned so that if the block is opened using
"with ray.profile(...) as prof:", we can call
"prof.set_attribute" inside the block.
"""
self.start_time = time.time()
return self
def __exit__(self, type, value, tb):
"""Log the end of a span event. Log any exception that occurred."""
for key, value in self.extra_data.items():
if not isinstance(key, str) or not isinstance(value, str):
raise ValueError("The extra_data argument must be a "
"dictionary mapping strings to strings.")
event = {
"event_type": self.event_type,
"start_time": self.start_time,
"end_time": time.time(),
"extra_data": json.dumps(self.extra_data),
}
if type is not None:
event["extra_data"] = json.dumps({
"type": str(type),
"value": str(value),
"traceback": str(traceback.format_exc()),
})
self.worker.events.append(event)
def log_event(event_type, contents=None, worker=global_worker):
log(event_type, kind=LOG_POINT, contents=contents, worker=worker)
def profile(event_type, extra_data=None, worker=global_worker):
"""Profile a span of time so that it appears in the timeline visualization.
This function can be used as follows (both on the driver or within a task).
with ray.profile("custom event", extra_data={'key': 'value'}):
# Do some computation here.
Optionally, a dictionary can be passed as the "extra_data" argument, and
it can have keys "name" and "cname" if you want to override the default
timeline display text and box color. Other values will appear at the bottom
of the chrome tracing GUI when you click on the box corresponding to this
profile span.
Args:
event_type: A string describing the type of the event.
extra_data: This must be a dictionary mapping strings to strings. This
data will be added to the json objects that are used to populate
the timeline, so if you want to set a particular color, you can
simply set the "cname" attribute to an appropriate color.
Similarly, if you set the "name" attribute, then that will set the
text displayed on the box in the timeline.
Returns:
An object that can profile a span of time via a "with" statement.
"""
if not worker.use_raylet:
return RayLogSpan(event_type, contents=extra_data, worker=worker)
else:
return RayLogSpanRaylet(
event_type, extra_data=extra_data, worker=worker)
def log(event_type, kind, contents=None, worker=global_worker):
def _log(event_type, kind, contents=None, worker=global_worker):
"""Log an event to the global state store.
This adds the event to a buffer of events locally. The buffer can be
flushed and written to the global state store by calling flush_log().
flushed and written to the global state store by calling
flush_profile_data().
Args:
event_type (str): The type of the event.
@@ -2571,6 +2692,9 @@ def log(event_type, kind, contents=None, worker=global_worker):
time, and it is LOG_SPAN_END if we are finishing logging a span of
time.
"""
if worker.use_raylet:
raise Exception(
"This method is not supported in the raylet code path.")
# TODO(rkn): This code currently takes around half a microsecond. Since we
# call it tens of times per task, this adds up. We will need to redo the
# logging code, perhaps in C.
@@ -2584,13 +2708,32 @@ def log(event_type, kind, contents=None, worker=global_worker):
worker.events.append((time.time(), event_type, kind, contents))
def flush_log(worker=global_worker):
"""Send the logged worker events to the global state store."""
event_log_key = b"event_log:" + worker.worker_id
event_log_value = json.dumps(worker.events)
# TODO(rkn): Support calling this function in the middle of a task, and also
# call this periodically in the background from the driver.
def flush_profile_data(worker=global_worker):
"""Push the logged profiling data to the global control store.
By default, profiling information for a given task won't appear in the
timeline until after the task has completed. For very long-running tasks,
we may want profiling information to appear more quickly. In such cases,
this function can be called. Note that as an alternative, we could start
a thread in the background on workers that calls this automatically.
"""
if not worker.use_raylet:
event_log_key = b"event_log:" + worker.worker_id
event_log_value = json.dumps(worker.events)
worker.local_scheduler_client.log_event(event_log_key, event_log_value,
time.time())
else:
if worker.mode == WORKER_MODE:
component_type = "worker"
else:
component_type = "driver"
worker.local_scheduler_client.push_profile_events(
component_type, ray.ObjectID(worker.worker_id),
worker.node_ip_address, worker.events)
worker.events = []
@@ -2611,7 +2754,7 @@ def get(object_ids, worker=global_worker):
A Python object or a list of Python objects.
"""
worker.check_connected()
with log_span("ray:get", worker=worker):
with profile("ray.get", worker=worker):
check_main_thread()
if worker.mode == PYTHON_MODE:
@@ -2644,7 +2787,7 @@ def put(value, worker=global_worker):
The object ID assigned to this value.
"""
worker.check_connected()
with log_span("ray:put", worker=worker):
with profile("ray.put", worker=worker):
check_main_thread()
if worker.mode == PYTHON_MODE:
@@ -2702,7 +2845,7 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
type(object_id)))
worker.check_connected()
with log_span("ray:wait", worker=worker):
with profile("ray.wait", worker=worker):
check_main_thread()
# When Ray is run in PYTHON_MODE, all functions are run immediately,