diff --git a/.style.yapf b/.style.yapf new file mode 100644 index 000000000..c782c475e --- /dev/null +++ b/.style.yapf @@ -0,0 +1,190 @@ +[style] +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent=True + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys=False + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas=False + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring=False + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def=False + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets=False + +# The column limit. +column_limit=79 + +# Indent width used for line continuations. +continuation_indent_width=4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets=False + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line=True + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment= + +# The i18n function call names. The presence of this function stops +# reformattting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call= + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value=False + +# The number of columns to use for indentation. +indent_width=4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines=True + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with a value "*,/": +# +# 1 + 2*3 - 4/5 +# +no_spaces_around_selected_binary_operators=set([]) + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign=False + +# Use spaces around the power operator. +spaces_around_power_operator=False + +# The number of spaces required before a trailing comment. +spaces_before_comment=2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket=True + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated=False + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator=True + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator=True + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument=False + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator=True + +# Split named assignments onto individual lines. +split_before_named_assigns=True + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket=30 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator=10000 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr=0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator=300 + +# The penalty for characters over the column limit. +split_penalty_excess_character=4500 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split=30 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names=0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator=300 + +# Use the Tab character for indentation. +use_tabs=False + diff --git a/.travis.yml b/.travis.yml index 619284ed7..49d049516 100644 --- a/.travis.yml +++ b/.travis.yml @@ -117,5 +117,6 @@ script: - python test/component_failures_test.py - python test/multi_node_test.py - python test/recursion_test.py + - python test/monitor_test.py - python -m pytest python/ray/rllib/test/test_catalog.py diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index b9fe9a563..e3fd4a4fa 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -52,12 +52,18 @@ TASK_STATUS_MAPPING = { class GlobalState(object): """A class used to interface with the Ray control state. + # TODO(zongheng): In the future move this to use Ray's redis module in the + # backend to cut down on # of request RPCs. + Attributes: - redis_client: The redis client used to query the redis server. + redis_client: The redis client used to query the redis server. """ def __init__(self): """Create a GlobalState object.""" + # The redis server storing metadata, such as function table, client + # table, log files, event logs, workers/actions info. self.redis_client = None + # A list of redis shards, storing the object table & task table. self.redis_clients = None def _check_connected(self): diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 2616b2d46..f480ba182 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -3,22 +3,22 @@ from __future__ import division from __future__ import print_function import argparse -from collections import Counter import json import logging -import redis import time +from collections import Counter, defaultdict import ray import ray.utils -from ray.services import get_ip_address, get_port -from ray.utils import binary_to_object_id, binary_to_hex, hex_to_binary -from ray.worker import NIL_ACTOR_ID - +import redis # Import flatbuffer bindings. -from ray.core.generated.SubscribeToDBClientTableReply \ - import SubscribeToDBClientTableReply from ray.core.generated.DriverTableMessage import DriverTableMessage +from ray.core.generated.SubscribeToDBClientTableReply import \ + SubscribeToDBClientTableReply +from ray.core.generated.TaskInfo import TaskInfo +from ray.services import get_ip_address, get_port +from ray.utils import binary_to_hex, binary_to_object_id, hex_to_binary +from ray.worker import NIL_ACTOR_ID # These variables must be kept in sync with the C codebase. # common/common.h @@ -26,17 +26,24 @@ HEARTBEAT_TIMEOUT_MILLISECONDS = 100 NUM_HEARTBEATS_TIMEOUT = 100 DB_CLIENT_ID_SIZE = 20 NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE + # common/task.h TASK_STATUS_LOST = 32 + # common/state/redis.cc PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" DRIVER_DEATH_CHANNEL = b"driver_deaths" + # common/redis_module/ray_redis_module.cc -OBJECT_PREFIX = "OL:" -DB_CLIENT_PREFIX = "CL:" +OBJECT_INFO_PREFIX = b"OI:" +OBJECT_LOCATION_PREFIX = b"OL:" +TASK_TABLE_PREFIX = b"TT:" +DB_CLIENT_PREFIX = b"CL:" DB_CLIENT_TABLE_NAME = b"db_clients" + # local_scheduler/local_scheduler.h LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler" + # plasma/plasma_manager.cc PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager" @@ -69,12 +76,13 @@ class Monitor(object): dead_plasma_managers: A set of the plasma manager IDs of all the plasma managers that were up at one point and have died since then. """ + def __init__(self, redis_address, redis_port): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() self.state._initialize_global_state(redis_address, redis_port) - self.redis = redis.StrictRedis(host=redis_address, port=redis_port, - db=0) + self.redis = redis.StrictRedis( + host=redis_address, port=redis_port, db=0) # TODO(swang): Update pubsub client to use ray.experimental.state once # subscriptions are implemented there. self.subscribe_client = self.redis.pubsub() @@ -109,8 +117,9 @@ class Monitor(object): info["local_scheduler_id"] in self.dead_local_schedulers): # Choose a new local scheduler to run the actor. local_scheduler_id = ray.utils.select_local_scheduler( - info["driver_id"], self.state.local_schedulers(), - info["num_gpus"], self.redis) + info["driver_id"], + self.state.local_schedulers(), info["num_gpus"], + self.redis) import sys sys.stdout.flush() # The new local scheduler should not be the same as the old @@ -121,8 +130,9 @@ class Monitor(object): # Announce to all of the local schedulers that the actor should # be recreated on this new local scheduler. ray.utils.publish_actor_creation( - hex_to_binary(actor_id), hex_to_binary(info["driver_id"]), - local_scheduler_id, True, self.redis) + hex_to_binary(actor_id), + hex_to_binary(info["driver_id"]), local_scheduler_id, True, + self.redis) log.info("Actor {} for driver {} was on dead local scheduler " "{}. It is being recreated on local scheduler {}" .format(actor_id, info["driver_id"], @@ -160,7 +170,7 @@ class Monitor(object): # The dummy object should exist on at most one plasma # manager, the manager associated with the local scheduler # that died. - assert(len(manager_ids) <= 1) + assert len(manager_ids) <= 1 # Remove the dummy object from the plasma manager # associated with the dead local scheduler, if any. for manager in manager_ids: @@ -175,7 +185,8 @@ class Monitor(object): # task as lost. key = binary_to_object_id(hex_to_binary(task_id)) ok = self.state._execute_command( - key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id), + key, "RAY.TASK_TABLE_UPDATE", + hex_to_binary(task_id), ray.experimental.state.TASK_STATUS_LOST, NIL_ID) if ok != b"OK": log.warn("Failed to update lost task for dead scheduler.") @@ -238,7 +249,7 @@ class Monitor(object): log.debug("Subscribed to {}, data was {}".format(channel, data)) self.subscribed[channel] = True - def db_client_notification_handler(self, channel, data): + def db_client_notification_handler(self, unused_channel, data): """Handle a notification from the db_client table from Redis. This handler processes notifications from the db_client table. @@ -247,9 +258,8 @@ class Monitor(object): the associated state in the state tables should be handled by the caller. """ - notification_object = (SubscribeToDBClientTableReply - .GetRootAsSubscribeToDBClientTableReply(data, - 0)) + notification_object = (SubscribeToDBClientTableReply. + GetRootAsSubscribeToDBClientTableReply(data, 0)) db_client_id = binary_to_hex(notification_object.DbClientId()) client_type = notification_object.ClientType() is_insertion = notification_object.IsInsertion() @@ -271,7 +281,7 @@ class Monitor(object): # already dead. del self.live_plasma_managers[db_client_id] - def plasma_manager_heartbeat_handler(self, channel, data): + def plasma_manager_heartbeat_handler(self, unused_channel, data): """Handle a plasma manager heartbeat from Redis. This resets the number of heartbeats that we've missed from this plasma @@ -283,7 +293,134 @@ class Monitor(object): # manager. self.live_plasma_managers[db_client_id] = 0 - def driver_removed_handler(self, channel, data): + def _entries_for_driver_in_shard(self, driver_id, redis_shard_index): + """Collect IDs of control-state entries for a driver from a shard. + + Args: + driver_id: The ID of the driver. + redis_shard_index: The index of the Redis shard to query. + + Returns: + Lists of IDs: (returned_object_ids, task_ids, put_objects). The + first two are relevant to the driver and are safe to delete. + The last contains all "put" objects in this redis shard; each + element is an (object_id, corresponding task_id) pair. + """ + # TODO(zongheng): consider adding save & restore functionalities. + redis = self.state.redis_clients[redis_shard_index] + task_table_infos = {} # task id -> TaskInfo messages + + # Scan the task table & filter to get the list of tasks belong to this + # driver. Use a cursor in order not to block the redis shards. + for key in redis.scan_iter(match=TASK_TABLE_PREFIX + b"*"): + entry = redis.hgetall(key) + task_info = TaskInfo.GetRootAsTaskInfo(entry[b"TaskSpec"], 0) + if driver_id != task_info.DriverId(): + # Ignore tasks that aren't from this driver. + continue + task_table_infos[task_info.TaskId()] = task_info + + # Get the list of objects returned by these tasks. Note these might + # not belong to this redis shard. + returned_object_ids = [] + for task_info in task_table_infos.values(): + returned_object_ids.extend([ + task_info.Returns(i) for i in range(task_info.ReturnsLength()) + ]) + + # Also record all the ray.put()'d objects. + put_objects = [] + for key in redis.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): + entry = redis.hgetall(key) + if entry[b"is_put"] == "0": + continue + object_id = key.split(OBJECT_INFO_PREFIX)[1] + task_id = entry[b"task"] + put_objects.append((object_id, task_id)) + + return returned_object_ids, task_table_infos.keys(), put_objects + + def _clean_up_entries_from_shard(self, object_ids, task_ids, shard_index): + redis = self.state.redis_clients[shard_index] + # Clean up (in the future, save) entries for non-empty objects. + object_ids_locs = set() + object_ids_infos = set() + for object_id in object_ids: + # OL. + obj_loc = redis.zrange(OBJECT_LOCATION_PREFIX + object_id, 0, -1) + if obj_loc: + object_ids_locs.add(object_id) + # OI. + obj_info = redis.hgetall(OBJECT_INFO_PREFIX + object_id) + if obj_info: + object_ids_infos.add(object_id) + + # Form the redis keys to delete. + keys = [TASK_TABLE_PREFIX + k for k in task_ids] + keys.extend([OBJECT_LOCATION_PREFIX + k for k in object_ids_locs]) + keys.extend([OBJECT_INFO_PREFIX + k for k in object_ids_infos]) + + if not keys: + return + # Remove with best effort. + num_deleted = redis.delete(*keys) + log.info( + "Removed {} dead redis entries of the driver from redis shard {}.". + format(num_deleted, shard_index)) + if num_deleted != len(keys): + log.warning( + "Failed to remove {} relevant redis entries" + " from redis shard {}.".format(len(keys) - num_deleted)) + + def _clean_up_entries_for_driver(self, driver_id): + """Remove this driver's object/task entries from all redis shards. + + Specifically, removes control-state entries of: + * all objects (OI and OL entries) created by `ray.put()` from the + driver + * all tasks belonging to the driver. + """ + # TODO(zongheng): handle function_table, client_table, log_files -- + # these are in the metadata redis server, not in the shards. + driver_object_ids = [] + driver_task_ids = [] + all_put_objects = [] + + # Collect relevant ids. + # TODO(zongheng): consider parallelizing this loop. + for shard_index in range(len(self.state.redis_clients)): + returned_object_ids, task_ids, put_objects = \ + self._entries_for_driver_in_shard(driver_id, shard_index) + driver_object_ids.extend(returned_object_ids) + driver_task_ids.extend(task_ids) + all_put_objects.extend(put_objects) + + # For the put objects, keep those from relevant tasks. + driver_task_ids_set = set(driver_task_ids) + for object_id, task_id in all_put_objects: + if task_id in driver_task_ids_set: + driver_object_ids.append(object_id) + + # Partition IDs and distribute to shards. + object_ids_per_shard = defaultdict(list) + task_ids_per_shard = defaultdict(list) + + def ToShardIndex(index): + return binary_to_object_id(index).redis_shard_hash() % len( + self.state.redis_clients) + + for object_id in driver_object_ids: + object_ids_per_shard[ToShardIndex(object_id)].append(object_id) + for task_id in driver_task_ids: + task_ids_per_shard[ToShardIndex(task_id)].append(task_id) + + # TODO(zongheng): consider parallelizing this loop. + for shard_index in range(len(self.state.redis_clients)): + self._clean_up_entries_from_shard( + object_ids_per_shard[shard_index], + task_ids_per_shard[shard_index], shard_index) + + def driver_removed_handler(self, unused_channel, data): """Handle a notification that a driver has been removed. This releases any GPU resources that were reserved for that driver in @@ -291,8 +428,8 @@ class Monitor(object): """ message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0) driver_id = message.DriverId() - log.info("Driver {} has been removed." - .format(binary_to_hex(driver_id))) + log.info( + "Driver {} has been removed.".format(binary_to_hex(driver_id))) # Get a list of the local schedulers. client_table = ray.global_state.client_table() @@ -302,6 +439,8 @@ class Monitor(object): if client["ClientType"] == "local_scheduler": local_schedulers.append(client) + self._clean_up_entries_for_driver(driver_id) + # Release any GPU resources that have been reserved for this driver in # Redis. for local_scheduler in local_schedulers: @@ -321,8 +460,8 @@ class Monitor(object): result = pipe.hget(local_scheduler_id, "gpus_in_use") - gpus_in_use = (dict() if result is None - else json.loads(result)) + gpus_in_use = (dict() if result is None else + json.loads(result)) driver_id_hex = binary_to_hex(driver_id) if driver_id_hex in gpus_in_use: @@ -345,9 +484,9 @@ class Monitor(object): continue log.info("Driver {} is returning GPU IDs {} to local " - "scheduler {}.".format(binary_to_hex(driver_id), - num_gpus_returned, - local_scheduler_id)) + "scheduler {}.".format( + binary_to_hex(driver_id), num_gpus_returned, + local_scheduler_id)) def process_messages(self): """Process all messages ready in the subscription channels. @@ -371,22 +510,23 @@ class Monitor(object): # to an initial subscription request. message_handler = self.subscribe_handler elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: - assert(self.subscribed[channel]) + assert self.subscribed[channel] # The message was a heartbeat from a plasma manager. message_handler = self.plasma_manager_heartbeat_handler elif channel == DB_CLIENT_TABLE_NAME: - assert(self.subscribed[channel]) + assert self.subscribed[channel] # The message was a notification from the db_client table. message_handler = self.db_client_notification_handler elif channel == DRIVER_DEATH_CHANNEL: - assert(self.subscribed[channel]) + assert self.subscribed[channel] # The message was a notification that a driver was removed. + log.info("message-handler: driver_removed_handler") message_handler = self.driver_removed_handler else: raise Exception("This code should be unreachable.") # Call the handler. - assert(message_handler is not None) + assert (message_handler is not None) message_handler(channel, data) def run(self): @@ -439,8 +579,8 @@ class Monitor(object): # Handle plasma managers that timed out during this round. plasma_manager_ids = list(self.live_plasma_managers.keys()) for plasma_manager_id in plasma_manager_ids: - if ((self.live_plasma_managers - [plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT): + if ((self.live_plasma_managers[plasma_manager_id]) >= + NUM_HEARTBEATS_TIMEOUT): log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) # Remove the plasma manager from the managers whose # heartbeats we're tracking. @@ -465,8 +605,11 @@ class Monitor(object): if __name__ == "__main__": parser = argparse.ArgumentParser(description=("Parse Redis server for the " "monitor to connect to.")) - parser.add_argument("--redis-address", required=True, type=str, - help="the address to use for Redis") + parser.add_argument( + "--redis-address", + required=True, + type=str, + help="the address to use for Redis") args = parser.parse_args() redis_ip_address = get_ip_address(args.redis_address) diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index 1ee25fa08..cabf8ce4b 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -8,22 +8,25 @@ #include "common_protocol.h" -/** - * Various tables are maintained in redis: - * - * == OBJECT TABLE == - * - * This consists of two parts: - * - The object location table, indexed by OL:object_id, which is the set of - * plasma manager indices that have access to the object. - * - The object info table, indexed by OI:object_id, which is a hashmap with key - * "hash" for the hash of the object and key "data_size" for the size of the - * object in bytes. - * - * == TASK TABLE == - * - * TODO(pcm): Fill this out. - */ +// Various tables are maintained in redis: +// +// == OBJECT TABLE == +// +// This consists of two parts: +// - The object location table, indexed by OL:object_id, which is the set of +// plasma manager indices that have access to the object. +// (In redis this is represented by a zset (sorted set).) +// +// - The object info table, indexed by OI:object_id, which is a hashmap of: +// "hash" -> the hash of the object, +// "data_size" -> the size of the object in bytes, +// "task" -> the task ID that generated this object. +// "is_put" -> 0 or 1. +// +// == TASK TABLE == +// +// TODO(pcm): Fill this out. +// #define OBJECT_INFO_PREFIX "OI:" #define OBJECT_LOCATION_PREFIX "OL:" diff --git a/test/monitor_test.py b/test/monitor_test.py new file mode 100644 index 000000000..8fb7ae62e --- /dev/null +++ b/test/monitor_test.py @@ -0,0 +1,93 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import multiprocessing +import subprocess +import time +import unittest + +import ray + + +class MonitorTest(unittest.TestCase): + def _testCleanupOnDriverExit(self, num_redis_shards): + stdout = subprocess.check_output([ + "ray", + "start", + "--head", + "--num-redis-shards", + str(num_redis_shards), + ]).decode("ascii") + lines = [m.strip() for m in stdout.split("\n")] + init_cmd = [m for m in lines if m.startswith("ray.init")] + self.assertEqual(1, len(init_cmd)) + redis_address = init_cmd[0].split("redis_address=\"")[-1][:-2] + + def StateSummary(): + obj_tbl_len = len(ray.global_state.object_table()) + task_tbl_len = len(ray.global_state.task_table()) + func_tbl_len = len(ray.global_state.function_table()) + return obj_tbl_len, task_tbl_len, func_tbl_len + + def Driver(success): + success.value = True + # Start driver. + ray.init(redis_address=redis_address) + summary_start = StateSummary() + if (0, 1) != summary_start[:2]: + success.value = False + + # Two new objects. + ray.get(ray.put(1111)) + ray.get(ray.put(1111)) + if (2, 1, summary_start[2]) != StateSummary(): + success.value = False + + @ray.remote + def f(): + ray.put(1111) # Yet another object. + return 1111 # A returned object as well. + + # 1 new function. + if (2, 1, summary_start[2] + 1) != StateSummary(): + success.value = False + + ray.get(f.remote()) + if (4, 2, summary_start[2] + 1) != StateSummary(): + success.value = False + + ray.worker.cleanup() + + success = multiprocessing.Value('b', False) + driver = multiprocessing.Process(target=Driver, args=(success, )) + driver.start() + # Wait for client to exit. + driver.join() + time.sleep(5) + + # Just make sure Driver() is run and succeeded. Note(rkn), if the below + # assertion starts failing, then the issue may be that the summary + # values computed in the Driver function are being updated slowly and + # so the call to StateSummary() is getting outdated values. This could + # be fixed by looping until StateSummary() returns the desired values. + self.assertTrue(success.value) + # Check that objects, tasks, and functions are cleaned up. + ray.init(redis_address=redis_address) + # The assertion below can fail if the monitor is too slow to clean up + # the global state. + self.assertEqual((0, 1), StateSummary()[:2]) + + ray.worker.cleanup() + subprocess.Popen(["ray", "stop"]).wait() + + def testCleanupOnDriverExitSingleRedisShard(self): + self._testCleanupOnDriverExit(num_redis_shards=1) + + def testCleanupOnDriverExitManyRedisShards(self): + self._testCleanupOnDriverExit(num_redis_shards=5) + self._testCleanupOnDriverExit(num_redis_shards=31) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/runtest.py b/test/runtest.py index 684832524..b6c1ca367 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -1,18 +1,17 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -from collections import defaultdict, namedtuple -import numpy as np import os -import ray import re import shutil import string import sys import time import unittest +from collections import defaultdict, namedtuple +import numpy as np + +import ray import ray.test.test_functions as test_functions import ray.test.test_utils @@ -21,11 +20,11 @@ if sys.version_info >= (3, 0): def assert_equal(obj1, obj2): - module_numpy = (type(obj1).__module__ == np.__name__ or - type(obj2).__module__ == np.__name__) + module_numpy = (type(obj1).__module__ == np.__name__ + or type(obj2).__module__ == np.__name__) if module_numpy: - empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) or - (hasattr(obj2, "shape") and obj2.shape == ())) + empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) + or (hasattr(obj2, "shape") and obj2.shape == ())) if empty_shape: # This is a special case because currently np.testing.assert_equal # fails because we do not properly handle different numerical @@ -36,13 +35,11 @@ def assert_equal(obj1, obj2): np.testing.assert_equal(obj1, obj2) elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): special_keys = ["_pytype_"] - assert (set(list(obj1.__dict__.keys()) + special_keys) == - set(list(obj2.__dict__.keys()) + special_keys)), ("Objects {} " - "and {} are " - "different." - .format( - obj1, - obj2)) + assert (set(list(obj1.__dict__.keys()) + special_keys) == set( + list(obj2.__dict__.keys()) + special_keys)), ("Objects {} " + "and {} are " + "different.".format( + obj1, obj2)) for key in obj1.__dict__.keys(): if key not in special_keys: assert_equal(obj1.__dict__[key], obj2.__dict__[key]) @@ -52,49 +49,76 @@ def assert_equal(obj1, obj2): assert_equal(obj1[key], obj2[key]) elif type(obj1) is list or type(obj2) is list: assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " - "different lengths." - .format(obj1, obj2)) + "different lengths.".format( + obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) elif type(obj1) is tuple or type(obj2) is tuple: assert len(obj1) == len(obj2), ("Objects {} and {} are tuples with " - "different lengths." - .format(obj1, obj2)) + "different lengths.".format( + obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) - elif (ray.serialization.is_named_tuple(type(obj1)) or - ray.serialization.is_named_tuple(type(obj2))): + elif (ray.serialization.is_named_tuple(type(obj1)) + or ray.serialization.is_named_tuple(type(obj2))): assert len(obj1) == len(obj2), ("Objects {} and {} are named tuples " - "with different lengths." - .format(obj1, obj2)) + "with different lengths.".format( + obj1, obj2)) for i in range(len(obj1)): assert_equal(obj1[i], obj2[i]) else: - assert obj1 == obj2, "Objects {} and {} are different.".format(obj1, - obj2) + assert obj1 == obj2, "Objects {} and {} are different.".format( + obj1, obj2) if sys.version_info >= (3, 0): long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])] else: - long_extras = [long(0), np.array([["hi", u"hi"], [1.3, long(1)]])] # noqa: E501,F821 -PRIMITIVE_OBJECTS = [0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, - [1 << 100, [1 << 100]], "a", string.printable, "\u262F", - u"hello world", u"\xff\xfe\x9c\x001\x000\x00", None, True, - False, [], (), {}, np.int8(3), np.int32(4), np.int64(5), - np.uint8(3), np.uint32(4), np.uint64(5), np.float32(1.9), - np.float64(1.9), np.zeros([100, 100]), - np.random.normal(size=[100, 100]), np.array(["hi", 3]), - np.array(["hi", 3], dtype=object)] + long_extras + long_extras = [ + long(0), # noqa: E501,F821 + np.array([ + ["hi", u"hi"], + [1.3, long(1)] # noqa: E501,F821 + ]) + ] + +PRIMITIVE_OBJECTS = [ + 0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a", + string.printable, "\u262F", u"hello world", u"\xff\xfe\x9c\x001\x000\x00", + None, True, False, [], (), {}, + np.int8(3), + np.int32(4), + np.int64(5), + np.uint8(3), + np.uint32(4), + np.uint64(5), + np.float32(1.9), + np.float64(1.9), + np.zeros([100, 100]), + np.random.normal(size=[100, 100]), + np.array(["hi", 3]), + np.array(["hi", 3], dtype=object) +] + long_extras COMPLEX_OBJECTS = [ [[[[[[[[[[[[]]]]]]]]]]]], - {"obj{}".format(i): np.random.normal(size=[100, 100]) for i in range(10)}, + {"obj{}".format(i): np.random.normal(size=[100, 100]) + for i in range(10)}, # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): { # (): {(): {}}}}}}}}}}}}}, - ((((((((((),),),),),),),),),), - {"a": {"b": {"c": {"d": {}}}}}] + ( + (((((((((), ), ), ), ), ), ), ), ), ), + { + "a": { + "b": { + "c": { + "d": {} + } + } + } + } +] class Foo(object): @@ -141,21 +165,32 @@ Point = namedtuple("Point", ["x", "y"]) NamedTupleExample = namedtuple("Example", "field1, field2, field3, field4, field5") -CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22), - Foo(), Bar(), Baz(), # Qux(), SubQux(), - NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3])] +CUSTOM_OBJECTS = [ + Exception("Test object."), + CustomError(), + Point(11, y=22), + Foo(), + Bar(), + Baz(), # Qux(), SubQux(), + NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]) +] BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS] -TUPLE_OBJECTS = [(obj,) for obj in BASE_OBJECTS] +TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS] # The check that type(obj).__module__ != "numpy" should be unnecessary, but # otherwise this seems to fail on Mac OS X on Travis. -DICT_OBJECTS = ([{obj: obj} for obj in PRIMITIVE_OBJECTS - if (obj.__hash__ is not None and - type(obj).__module__ != "numpy")] + - [{0: obj} for obj in BASE_OBJECTS] + - [{Foo(123): Foo(456)}]) +DICT_OBJECTS = ( + [{ + obj: obj + } for obj in PRIMITIVE_OBJECTS + if (obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{ + 0: + obj + } for obj in BASE_OBJECTS] + [{ + Foo(123): Foo(456) + }]) RAY_TEST_OBJECTS = BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS @@ -171,7 +206,6 @@ except AttributeError: class SerializationTest(unittest.TestCase): - def testRecursiveObjects(self): ray.init(num_workers=0) @@ -253,15 +287,15 @@ class SerializationTest(unittest.TestCase): class WorkerTest(unittest.TestCase): - def testPythonWorkers(self): # Test the codepath for starting workers from the Python script, # instead of the local scheduler. This codepath is for debugging # purposes only. num_workers = 4 - ray.worker._init(num_workers=num_workers, - start_workers_from_local_scheduler=False, - start_ray_local=True) + ray.worker._init( + num_workers=num_workers, + start_workers_from_local_scheduler=False, + start_ray_local=True) @ray.remote def f(x): @@ -275,13 +309,13 @@ class WorkerTest(unittest.TestCase): ray.init(num_workers=0) for i in range(100): - value_before = i * 10 ** 6 + value_before = i * 10**6 objectid = ray.put(value_before) value_after = ray.get(objectid) self.assertEqual(value_before, value_after) for i in range(100): - value_before = i * 10 ** 6 * 1.0 + value_before = i * 10**6 * 1.0 objectid = ray.put(value_before) value_after = ray.get(objectid) self.assertEqual(value_before, value_after) @@ -302,7 +336,6 @@ class WorkerTest(unittest.TestCase): class APITest(unittest.TestCase): - def init_ray(self, kwargs=None): if kwargs is None: kwargs = {} @@ -318,6 +351,7 @@ class APITest(unittest.TestCase): # throws an exception. class TempClass(object): pass + ray.get(ray.put(TempClass())) # Note that the below actually returns a dictionary and not a @@ -525,14 +559,14 @@ class APITest(unittest.TestCase): return x, y, args self.assertEqual(ray.get(f1.remote()), ()) - self.assertEqual(ray.get(f1.remote(1)), (1,)) + self.assertEqual(ray.get(f1.remote(1)), (1, )) self.assertEqual(ray.get(f1.remote(1, 2, 3)), (1, 2, 3)) with self.assertRaises(Exception): f2.remote() with self.assertRaises(Exception): f2.remote(1) self.assertEqual(ray.get(f2.remote(1, 2)), (1, 2, ())) - self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3,))) + self.assertEqual(ray.get(f2.remote(1, 2, 3)), (1, 2, (3, ))) self.assertEqual(ray.get(f2.remote(1, 2, 3, 4)), (1, 2, (3, 4))) def testNoArgs(self): @@ -548,12 +582,14 @@ class APITest(unittest.TestCase): @ray.remote def f(x): return x + 1 + self.assertEqual(ray.get(f.remote(0)), 1) # Test that we can redefine the remote function. @ray.remote def f(x): return x + 10 + while True: val = ray.get(f.remote(0)) self.assertTrue(val in [1, 10]) @@ -563,23 +599,29 @@ class APITest(unittest.TestCase): print("Still using old definition of f, trying again.") # Test that we can close over plain old data. - data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, - {"a": np.zeros(3)}] + data = [ + np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, { + "a": np.zeros(3) + } + ] @ray.remote def g(): return data + ray.get(g.remote()) # Test that we can close over modules. @ray.remote def h(): return np.zeros([3, 5]) + assert_equal(ray.get(h.remote()), np.zeros([3, 5])) @ray.remote def j(): return time.time() + ray.get(j.remote()) # Test that we can define remote functions that call other remote @@ -595,6 +637,7 @@ class APITest(unittest.TestCase): @ray.remote def m(x): return ray.get(l.remote(x)) + self.assertEqual(ray.get(k.remote(1)), 2) self.assertEqual(ray.get(l.remote(1)), 2) self.assertEqual(ray.get(m.remote(1)), 2) @@ -618,8 +661,12 @@ class APITest(unittest.TestCase): time.sleep(delay) return 1 - objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), - f.remote(0.5)] + objectids = [ + f.remote(1.0), + f.remote(0.5), + f.remote(0.5), + f.remote(0.5) + ] ready_ids, remaining_ids = ray.wait(objectids) self.assertEqual(len(ready_ids), 1) self.assertEqual(len(remaining_ids), 3) @@ -627,17 +674,25 @@ class APITest(unittest.TestCase): self.assertEqual(set(ready_ids), set(objectids)) self.assertEqual(remaining_ids, []) - objectids = [f.remote(0.5), f.remote(0.5), f.remote(0.5), - f.remote(0.5)] + objectids = [ + f.remote(0.5), + f.remote(0.5), + f.remote(0.5), + f.remote(0.5) + ] start_time = time.time() - ready_ids, remaining_ids = ray.wait(objectids, timeout=1750, - num_returns=4) + ready_ids, remaining_ids = ray.wait( + objectids, timeout=1750, num_returns=4) self.assertLess(time.time() - start_time, 2) self.assertEqual(len(ready_ids), 3) self.assertEqual(len(remaining_ids), 1) ray.wait(objectids) - objectids = [f.remote(1.0), f.remote(0.5), f.remote(0.5), - f.remote(0.5)] + objectids = [ + f.remote(1.0), + f.remote(0.5), + f.remote(0.5), + f.remote(0.5) + ] start_time = time.time() ready_ids, remaining_ids = ray.wait(objectids, timeout=5000) self.assertTrue(time.time() - start_time < 5) @@ -684,18 +739,22 @@ class APITest(unittest.TestCase): # is connected. def f(worker_info): sys.path.append(1) + ray.worker.global_worker.run_function_on_all_workers(f) def f(worker_info): sys.path.append(2) + ray.worker.global_worker.run_function_on_all_workers(f) def g(worker_info): sys.path.append(3) + ray.worker.global_worker.run_function_on_all_workers(g) def f(worker_info): sys.path.append(4) + ray.worker.global_worker.run_function_on_all_workers(f) self.init_ray() @@ -716,6 +775,7 @@ class APITest(unittest.TestCase): sys.path.pop() sys.path.pop() sys.path.pop() + ray.worker.global_worker.run_function_on_all_workers(f) def testRunningFunctionOnAllWorkers(self): @@ -723,15 +783,18 @@ class APITest(unittest.TestCase): def f(worker_info): sys.path.append("fake_directory") + ray.worker.global_worker.run_function_on_all_workers(f) @ray.remote def get_path1(): return sys.path + self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1]) def f(worker_info): sys.path.pop(-1) + ray.worker.global_worker.run_function_on_all_workers(f) # Create a second remote function to guarantee that when we call @@ -740,6 +803,7 @@ class APITest(unittest.TestCase): @ray.remote def get_path2(): return sys.path + self.assertTrue("fake_directory" not in ray.get(get_path2.remote())) def testLoggingAPI(self): @@ -751,8 +815,8 @@ class APITest(unittest.TestCase): keys = ray.worker.global_worker.redis_client.keys("event_log:*") res = [] for key in keys: - res.extend(ray.worker.global_worker.redis_client.zrange(key, 0, - -1)) + res.extend( + ray.worker.global_worker.redis_client.zrange(key, 0, -1)) return res def wait_for_num_events(num_events, timeout=10): @@ -806,26 +870,31 @@ class APITest(unittest.TestCase): @ray.remote def f(): return 1 + results1 = [f.remote() for _ in range(num_calls)] @ray.remote def f(): return 2 + results2 = [f.remote() for _ in range(num_calls)] @ray.remote def f(): return 3 + results3 = [f.remote() for _ in range(num_calls)] @ray.remote def f(): return 4 + results4 = [f.remote() for _ in range(num_calls)] @ray.remote def f(): return 5 + results5 = [f.remote() for _ in range(num_calls)] self.assertEqual(ray.get(results1), num_calls * [1]) @@ -870,7 +939,6 @@ class APITest(unittest.TestCase): class APITestSharded(APITest): - def init_ray(self, kwargs=None): if kwargs is None: kwargs = {} @@ -881,7 +949,6 @@ class APITestSharded(APITest): class PythonModeTest(unittest.TestCase): - def testPythonMode(self): reload(test_functions) ray.init(driver_mode=ray.PYTHON_MODE) @@ -889,6 +956,7 @@ class PythonModeTest(unittest.TestCase): @ray.remote def f(): return np.ones([3, 4, 5]) + xref = f.remote() # Remote functions should return by value. assert_equal(xref, np.ones([3, 4, 5])) @@ -911,8 +979,8 @@ class PythonModeTest(unittest.TestCase): # first list and the remaining values as the second list num_returns = 5 object_ids = [ray.put(i) for i in range(20)] - ready, remaining = ray.wait(object_ids, num_returns=num_returns, - timeout=None) + ready, remaining = ray.wait( + object_ids, num_returns=num_returns, timeout=None) assert_equal(ready, object_ids[:num_returns]) assert_equal(remaining, object_ids[num_returns:]) @@ -949,7 +1017,6 @@ class PythonModeTest(unittest.TestCase): class UtilsTest(unittest.TestCase): - def testCopyingDirectory(self): # The functionality being tested here is really multi-node # functionality, but this test just uses a single node. @@ -999,7 +1066,6 @@ class UtilsTest(unittest.TestCase): class ResourcesTest(unittest.TestCase): - def testResourceConstraints(self): num_workers = 20 ray.init(num_workers=num_workers, num_cpus=10, num_gpus=2) @@ -1012,9 +1078,13 @@ class ResourcesTest(unittest.TestCase): def get_worker_id(): time.sleep(1) return sys.path[-1] + while True: - if len(set(ray.get([get_worker_id.remote() - for _ in range(num_workers)]))) == num_workers: + if len( + set( + ray.get([ + get_worker_id.remote() for _ in range(num_workers) + ]))) == num_workers: break time_buffer = 0.3 @@ -1088,9 +1158,13 @@ class ResourcesTest(unittest.TestCase): def get_worker_id(): time.sleep(1) return sys.path[-1] + while True: - if len(set(ray.get([get_worker_id.remote() - for _ in range(num_workers)]))) == num_workers: + if len( + set( + ray.get([ + get_worker_id.remote() for _ in range(num_workers) + ]))) == num_workers: break @ray.remote(num_cpus=1, num_gpus=9) @@ -1192,7 +1266,7 @@ class ResourcesTest(unittest.TestCase): list_of_ids = ray.get([f1.remote() for _ in range(10)]) set_of_ids = set([tuple(gpu_ids) for gpu_ids in list_of_ids]) - self.assertEqual(set_of_ids, set([(i,) for i in range(10)])) + self.assertEqual(set_of_ids, set([(i, ) for i in range(10)])) list_of_ids = ray.get([f2.remote(), f4.remote(), f4.remote()]) all_ids = [gpu_id for gpu_ids in list_of_ids for gpu_id in gpu_ids] @@ -1218,11 +1292,12 @@ class ResourcesTest(unittest.TestCase): # This test will define a bunch of tasks that can only be assigned to # specific local schedulers, and we will check that they are assigned # to the correct local schedulers. - address_info = ray.worker._init(start_ray_local=True, - num_local_schedulers=3, - num_workers=1, - num_cpus=[100, 5, 10], - num_gpus=[0, 5, 1]) + address_info = ray.worker._init( + start_ray_local=True, + num_local_schedulers=3, + num_workers=1, + num_cpus=[100, 5, 10], + num_gpus=[0, 5, 1]) # Define a bunch of remote functions that all return the socket name of # the plasma store. Since there is a one-to-one correspondence between @@ -1284,8 +1359,10 @@ class ResourcesTest(unittest.TestCase): results.append(run_on_0_2.remote()) return names, results - store_names = [object_store_address.name for object_store_address - in address_info["object_store_addresses"]] + store_names = [ + object_store_address.name + for object_store_address in address_info["object_store_addresses"] + ] def validate_names_and_results(names, results): for name, result in zip(names, ray.get(results)): @@ -1296,8 +1373,9 @@ class ResourcesTest(unittest.TestCase): elif name == "run_on_2": self.assertIn(result, [store_names[2]]) elif name == "run_on_0_1_2": - self.assertIn(result, [store_names[0], store_names[1], - store_names[2]]) + self.assertIn(result, [ + store_names[0], store_names[1], store_names[2] + ]) elif name == "run_on_1_2": self.assertIn(result, [store_names[1], store_names[2]]) elif name == "run_on_0_2": @@ -1327,8 +1405,11 @@ class ResourcesTest(unittest.TestCase): ray.worker.cleanup() def testCustomResources(self): - ray.worker._init(start_ray_local=True, num_local_schedulers=2, - num_cpus=3, num_custom_resource=[0, 1]) + ray.worker._init( + start_ray_local=True, + num_local_schedulers=2, + num_cpus=3, + num_custom_resource=[0, 1]) @ray.remote def f(): @@ -1373,13 +1454,12 @@ class ResourcesTest(unittest.TestCase): ray.get(ray.remote(num_custom_resource=2)(f).remote()) ray.get(ray.remote(num_custom_resource=4)(f).remote()) ray.get(ray.remote(num_custom_resource=8)(f).remote()) - ray.get(ray.remote(num_custom_resource=(10 ** 10))(f).remote()) + ray.get(ray.remote(num_custom_resource=(10**10))(f).remote()) ray.worker.cleanup() class WorkerPoolTests(unittest.TestCase): - def tearDown(self): ray.worker.cleanup() @@ -1450,19 +1530,22 @@ class WorkerPoolTests(unittest.TestCase): class SchedulingAlgorithm(unittest.TestCase): - - def attempt_to_load_balance(self, remote_function, args, total_tasks, - num_local_schedulers, minimum_count, + def attempt_to_load_balance(self, + remote_function, + args, + total_tasks, + num_local_schedulers, + minimum_count, num_attempts=20): attempts = 0 while attempts < num_attempts: - locations = ray.get([remote_function.remote(*args) - for _ in range(total_tasks)]) + locations = ray.get( + [remote_function.remote(*args) for _ in range(total_tasks)]) names = set(locations) counts = [locations.count(name) for name in names] print("Counts are {}.".format(counts)) - if (len(names) == num_local_schedulers and - all([count >= minimum_count for count in counts])): + if (len(names) == num_local_schedulers + and all([count >= minimum_count for count in counts])): break attempts += 1 self.assertLess(attempts, num_attempts) @@ -1472,9 +1555,10 @@ class SchedulingAlgorithm(unittest.TestCase): # schedulers in a roughly equal manner. num_local_schedulers = 3 num_cpus = 7 - ray.worker._init(start_ray_local=True, - num_local_schedulers=num_local_schedulers, - num_cpus=num_cpus) + ray.worker._init( + start_ray_local=True, + num_local_schedulers=num_local_schedulers, + num_cpus=num_cpus) @ray.remote def f(): @@ -1492,8 +1576,10 @@ class SchedulingAlgorithm(unittest.TestCase): # dependencies. num_workers = 3 num_local_schedulers = 3 - ray.worker._init(start_ray_local=True, num_workers=num_workers, - num_local_schedulers=num_local_schedulers) + ray.worker._init( + start_ray_local=True, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(x): @@ -1528,7 +1614,6 @@ def wait_for_num_objects(num_objects, timeout=10): class GlobalStateAPI(unittest.TestCase): - def testGlobalStateAPI(self): with self.assertRaises(Exception): ray.global_state.object_table() @@ -1572,15 +1657,16 @@ class GlobalStateAPI(unittest.TestCase): driver_id) self.assertEqual(task_table[driver_task_id]["TaskSpec"]["FunctionID"], ID_SIZE * "ff") - self.assertEqual((task_table[driver_task_id]["TaskSpec"] - ["ReturnObjectIDs"]), - []) + self.assertEqual( + (task_table[driver_task_id]["TaskSpec"]["ReturnObjectIDs"]), []) client_table = ray.global_state.client_table() node_ip_address = ray.worker.global_worker.node_ip_address self.assertEqual(len(client_table[node_ip_address]), 3) - manager_client = [c for c in client_table[node_ip_address] - if c["ClientType"] == "plasma_manager"][0] + manager_client = [ + c for c in client_table[node_ip_address] + if c["ClientType"] == "plasma_manager" + ][0] @ray.remote def f(*xs): @@ -1624,8 +1710,8 @@ class GlobalStateAPI(unittest.TestCase): while time.time() - start_time < timeout: object_table = ray.global_state.object_table() tables_ready = ( - object_table[x_id]["ManagerIDs"] is not None and - object_table[result_id]["ManagerIDs"] is not None) + object_table[x_id]["ManagerIDs"] is not None + and object_table[result_id]["ManagerIDs"] is not None) if tables_ready: return time.sleep(0.1) @@ -1701,8 +1787,8 @@ class GlobalStateAPI(unittest.TestCase): while time.time() - start_time < 10: profiles = ray.global_state.task_profiles( 100, start=0, end=time.time()) - limited_profiles = ray.global_state.task_profiles(1, start=0, - end=time.time()) + limited_profiles = ray.global_state.task_profiles( + 1, start=0, end=time.time()) if len(profiles) == num_calls and len(limited_profiles) == 1: break time.sleep(0.1) @@ -1722,8 +1808,10 @@ class GlobalStateAPI(unittest.TestCase): def testWorkers(self): num_workers = 3 - ray.init(redirect_output=True, num_cpus=num_workers, - num_workers=num_workers) + ray.init( + redirect_output=True, + num_cpus=num_workers, + num_workers=num_workers) @ray.remote def f():