Move global state API out of global_state object. (#4857)

This commit is contained in:
Robert Nishihara
2019-05-26 11:27:53 -07:00
committed by Philipp Moritz
parent ea8d7b4dc0
commit 6703519144
29 changed files with 387 additions and 403 deletions
+12 -4
View File
@@ -66,6 +66,9 @@ from ray._raylet import (
_config = _Config()
from ray.profiling import profile # noqa: E402
from ray.state import (global_state, nodes, tasks, objects, timeline,
object_transfer_timeline, cluster_resources,
available_resources, errors) # noqa: E402
from ray.worker import (
LOCAL_MODE,
PYTHON_MODE,
@@ -73,12 +76,10 @@ from ray.worker import (
WORKER_MODE,
connect,
disconnect,
error_info,
get,
get_gpu_ids,
get_resource_ids,
get_webui_url,
global_state,
init,
is_initialized,
put,
@@ -98,6 +99,15 @@ from ray.runtime_context import _get_runtime_context # noqa: E402
__version__ = "0.8.0.dev0"
__all__ = [
"global_state",
"nodes",
"tasks",
"objects",
"timeline",
"object_transfer_timeline",
"cluster_resources",
"available_resources",
"errors",
"LOCAL_MODE",
"PYTHON_MODE",
"SCRIPT_MODE",
@@ -108,12 +118,10 @@ __all__ = [
"actor",
"connect",
"disconnect",
"error_info",
"get",
"get_gpu_ids",
"get_resource_ids",
"get_webui_url",
"global_state",
"init",
"internal",
"is_initialized",
+2 -2
View File
@@ -811,7 +811,7 @@ def exit_actor():
worker.raylet_client.disconnect()
ray.disconnect()
# Disconnect global state from GCS.
ray.global_state.disconnect()
ray.state.state.disconnect()
sys.exit(0)
assert False, "This process should have terminated."
else:
@@ -931,7 +931,7 @@ def get_checkpoints_for_actor(actor_id):
"""Get the available checkpoints for the given actor ID, return a list
sorted by checkpoint timestamp in descending order.
"""
checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id)
checkpoint_info = ray.state.state.actor_checkpoint_info(actor_id)
if checkpoint_info is None:
return []
checkpoints = [
+3 -10
View File
@@ -2,10 +2,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .features import (
flush_redis_unsafe, flush_task_and_object_metadata_unsafe,
flush_finished_tasks_unsafe, flush_evicted_objects_unsafe,
_flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard)
from .gcs_flush_policy import (set_flushing_policy, GcsFlushPolicy,
SimpleGcsFlushPolicy)
from .named_actors import get_actor, register_actor
@@ -20,10 +16,7 @@ def TensorFlowVariables(*args, **kwargs):
__all__ = [
"TensorFlowVariables", "flush_redis_unsafe",
"flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe",
"flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard",
"_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor",
"get", "wait", "set_flushing_policy", "GcsFlushPolicy",
"SimpleGcsFlushPolicy", "set_resource"
"TensorFlowVariables", "get_actor", "register_actor", "get", "wait",
"set_flushing_policy", "GcsFlushPolicy", "SimpleGcsFlushPolicy",
"set_resource"
]
-186
View File
@@ -1,186 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ray
from ray.utils import binary_to_hex
# Redis key prefixes matched by the flush helpers in this module.
# Prefix of the per-object information entries.
OBJECT_INFO_PREFIX = b"OI:"
# Prefix of the object location entries.
OBJECT_LOCATION_PREFIX = b"OL:"
# Prefix of the task table entries.
TASK_PREFIX = b"TT:"
def flush_redis_unsafe(redis_client=None):
    """Delete log-file and event-log entries from the primary Redis shard.

    Only non-critical state is removed. This can relieve out-of-memory
    pressure caused by metadata accumulating in Redis, but only partially:
    the task table and object table hold most of the data and are left
    untouched.

    Args:
        redis_client: optional Redis client to operate on. When omitted, the
            client of the connected global worker is used, so ray.init()
            must have been called.
    """
    if redis_client is None:
        worker = ray.worker.global_worker
        worker.check_connected()
        redis_client = worker.redis_client

    # (key pattern, report message) pairs, processed in this order.
    targets = (
        ("LOGFILE:*", "Deleted {} log files from Redis."),
        ("event_log:*", "Deleted {} event logs from Redis."),
    )
    for pattern, report in targets:
        matched = redis_client.keys(pattern)
        # Only issue the delete when at least one key matched.
        if len(matched) > 0:
            num_deleted = redis_client.delete(*matched)
        else:
            num_deleted = 0
        print(report.format(num_deleted))
def flush_task_and_object_metadata_unsafe():
    """Delete all task and object metadata from every Redis shard.

    This removes critical state. In a multitenant environment the metadata
    of all jobs is flushed, which may be undesirable. It can be used to try
    to address out-of-memory errors caused by metadata accumulating in
    Redis, but fault tolerance will most likely not work afterwards.
    """
    ray.worker.global_worker.check_connected()

    # (key prefix, report message) pairs, flushed in this order per shard.
    # Note that flushing the task table also removes the driver tasks.
    targets = (
        (TASK_PREFIX, "Deleted {} task keys from Redis."),
        (OBJECT_INFO_PREFIX, "Deleted {} object info keys from Redis."),
        (OBJECT_LOCATION_PREFIX,
         "Deleted {} object location keys from Redis."),
    )

    def flush_shard(redis_client):
        for prefix, report in targets:
            num_deleted = 0
            for key in redis_client.scan_iter(match=prefix + b"*"):
                num_deleted += redis_client.delete(key)
            print(report.format(num_deleted))

    # Apply the flush to every shard in turn.
    for shard_client in ray.worker.global_state.redis_clients:
        flush_shard(shard_client)
def _task_table_shard(shard_index):
    """Collect the task table entries whose keys live on one Redis shard.

    Args:
        shard_index: index into ray.global_state.redis_clients.

    Returns:
        A dict mapping the hex task ID to the value returned by
        ray.global_state._task_table for that task.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(TASK_PREFIX)
    table = {}
    for redis_key in shard_client.keys(TASK_PREFIX + b"*"):
        # Drop the table prefix to recover the raw task ID bytes.
        raw_task_id = redis_key[prefix_length:]
        entry = ray.global_state._task_table(ray.TaskID(raw_task_id))
        table[binary_to_hex(raw_task_id)] = entry
    return table
def _object_table_shard(shard_index):
    """Collect the object table entries whose keys live on one Redis shard.

    Args:
        shard_index: index into ray.global_state.redis_clients.

    Returns:
        A dict mapping the hex object ID to the value returned by
        ray.global_state._object_table for that object.
    """
    shard_client = ray.global_state.redis_clients[shard_index]
    prefix_length = len(OBJECT_LOCATION_PREFIX)
    table = {}
    for redis_key in shard_client.keys(OBJECT_LOCATION_PREFIX + b"*"):
        # Drop the location prefix to recover the raw object ID bytes.
        raw_object_id = redis_key[prefix_length:]
        entry = ray.global_state._object_table(ray.ObjectID(raw_object_id))
        table[binary_to_hex(raw_object_id)] = entry
    return table
def _flush_finished_tasks_unsafe_shard(shard_index):
    """Delete the task table keys of finished tasks on one Redis shard.

    Args:
        shard_index: index into ray.global_state.redis_clients.
    """
    ray.worker.global_worker.check_connected()
    shard_client = ray.global_state.redis_clients[shard_index]

    # Rebuild the Redis key of every task whose state is "done".
    finished_keys = [
        TASK_PREFIX + ray.utils.hex_to_binary(task_id)
        for task_id, task_info in _task_table_shard(shard_index).items()
        if task_info["State"] == ray.experimental.state.TASK_STATUS_DONE
    ]

    # Skip the delete command entirely when nothing matched.
    if len(finished_keys) > 0:
        num_deleted = shard_client.execute_command("del", *finished_keys)
    else:
        num_deleted = 0
    print("Deleted {} finished tasks from Redis shard.".format(num_deleted))
def _flush_evicted_objects_unsafe_shard(shard_index):
    """Delete the metadata keys of evicted objects on one Redis shard.

    An object is treated as evicted when its "ManagerIDs" entry is an empty
    list. Both its location key and its info key are deleted.

    Args:
        shard_index: index into ray.global_state.redis_clients.
    """
    ray.worker.global_worker.check_connected()
    shard_client = ray.global_state.redis_clients[shard_index]

    stale_keys = []
    for object_id, object_info in _object_table_shard(shard_index).items():
        if object_info["ManagerIDs"] == []:
            raw_object_id = ray.utils.hex_to_binary(object_id)
            stale_keys.append(OBJECT_LOCATION_PREFIX + raw_object_id)
            stale_keys.append(OBJECT_INFO_PREFIX + raw_object_id)

    # Skip the delete command entirely when nothing matched.
    if len(stale_keys) > 0:
        num_deleted = shard_client.execute_command("del", *stale_keys)
    else:
        num_deleted = 0
    print("Deleted {} keys for evicted objects from Redis.".format(
        num_deleted))
def flush_finished_tasks_unsafe():
    """Delete the metadata of all finished tasks from every Redis shard.

    This removes critical state. In a multitenant environment the metadata
    of all jobs is flushed, which may be undesirable. It can be used to try
    to address out-of-memory errors caused by metadata accumulating in
    Redis, but fault tolerance will most likely not work afterwards.
    """
    ray.worker.global_worker.check_connected()
    shard_count = len(ray.global_state.redis_clients)
    # Shards are flushed one at a time, in index order.
    for shard_index in range(shard_count):
        _flush_finished_tasks_unsafe_shard(shard_index)
def flush_evicted_objects_unsafe():
    """Delete the metadata of all evicted objects from every Redis shard.

    This removes critical state. In a multitenant environment the metadata
    of all jobs is flushed, which may be undesirable. It can be used to try
    to address out-of-memory errors caused by metadata accumulating in
    Redis, but fault tolerance will most likely not work afterwards.
    """
    ray.worker.global_worker.check_connected()
    shard_count = len(ray.global_state.redis_clients)
    # Shards are flushed one at a time, in index order.
    for shard_index in range(shard_count):
        _flush_evicted_objects_unsafe_shard(shard_index)
+8 -9
View File
@@ -37,8 +37,7 @@ class Monitor(object):
def __init__(self, redis_address, autoscaling_config, redis_password=None):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(
ray.state.state._initialize_global_state(
args.redis_address, redis_password=redis_password)
self.redis = ray.services.create_redis_client(
redis_address, password=redis_password)
@@ -149,7 +148,7 @@ class Monitor(object):
xray_object_table_prefix = (
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
task_table_objects = self.state.task_table()
task_table_objects = ray.tasks()
driver_id_hex = binary_to_hex(driver_id)
driver_task_id_bins = set()
for task_id_hex, task_info in task_table_objects.items():
@@ -161,7 +160,7 @@ class Monitor(object):
driver_task_id_bins.add(hex_to_binary(task_id_hex))
# Get objects associated with the driver.
object_table_objects = self.state.object_table()
object_table_objects = ray.objects()
driver_object_id_bins = set()
for object_id, _ in object_table_objects.items():
task_id_bin = ray._raylet.compute_task_id(object_id).binary()
@@ -171,13 +170,13 @@ class Monitor(object):
def to_shard_index(id_bin):
if len(id_bin) == ray.TaskID.size():
return binary_to_task_id(id_bin).redis_shard_hash() % len(
self.state.redis_clients)
ray.state.state.redis_clients)
else:
return binary_to_object_id(id_bin).redis_shard_hash() % len(
self.state.redis_clients)
ray.state.state.redis_clients)
# Form the redis keys to delete.
sharded_keys = [[] for _ in range(len(self.state.redis_clients))]
sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
for task_id_bin in driver_task_id_bins:
sharded_keys[to_shard_index(task_id_bin)].append(
xray_task_table_prefix + task_id_bin)
@@ -190,7 +189,7 @@ class Monitor(object):
keys = sharded_keys[shard_index]
if len(keys) == 0:
continue
redis = self.state.redis_clients[shard_index]
redis = ray.state.state.redis_clients[shard_index]
num_deleted = redis.delete(*keys)
logger.info("Monitor: "
"Removed {} dead redis entries of the "
@@ -256,7 +255,7 @@ class Monitor(object):
message_handler(channel, data)
def update_raylet_map(self):
all_raylet_nodes = self.state.client_table()
all_raylet_nodes = ray.nodes()
self.raylet_id_to_ip_map = {}
for raylet_info in all_raylet_nodes:
client_id = (raylet_info.get("DBClientID")
+1 -1
View File
@@ -746,7 +746,7 @@ def timeline(redis_address):
ray.init(redis_address=redis_address)
time = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
filename = "/tmp/ray-timeline-{}.json".format(time)
ray.global_state.chrome_tracing_dump(filename=filename)
ray.timeline(filename=filename)
size = os.path.getsize(filename)
logger.info("Trace file written to {} ({} bytes).".format(filename, size))
logger.info(
+1 -1
View File
@@ -101,7 +101,7 @@ def get_address_info_from_redis_helper(redis_address,
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = create_redis_client(redis_address, password=redis_password)
client_table = ray.experimental.state.parse_client_table(redis_client)
client_table = ray.state._parse_client_table(redis_client)
if len(client_table) == 0:
raise Exception(
"Redis has started but no raylets have registered yet.")
@@ -4,6 +4,7 @@ from __future__ import print_function
from collections import defaultdict
import json
import logging
import sys
import time
@@ -17,8 +18,10 @@ from ray.core.generated.EntryType import EntryType
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)
logger = logging.getLogger(__name__)
def parse_client_table(redis_client):
def _parse_client_table(redis_client):
"""Read the client table.
Args:
@@ -128,11 +131,11 @@ class GlobalState(object):
yet.
"""
if self.redis_client is None:
raise Exception("The ray.global_state API cannot be used before "
raise Exception("The ray global state API cannot be used before "
"ray.init has been called.")
if self.redis_clients is None:
raise Exception("The ray.global_state API cannot be used before "
raise Exception("The ray global state API cannot be used before "
"ray.init has been called.")
def disconnect(self):
@@ -408,7 +411,7 @@ class GlobalState(object):
"""
self._check_connected()
return parse_client_table(self.redis_client)
return _parse_client_table(self.redis_client)
def _profile_table(self, batch_id):
"""Get the profile events for a given batch of profile events.
@@ -461,6 +464,7 @@ class GlobalState(object):
return profile_events
def profile_table(self):
self._check_connected()
profile_table_keys = self._keys(
ray.gcs_utils.TablePrefix_PROFILE_string + "*")
batch_identifiers_binary = [
@@ -561,6 +565,8 @@ class GlobalState(object):
# TODO(rkn): This should support viewing just a window of time or a
# limited number of events.
self._check_connected()
profile_table = self.profile_table()
all_events = []
@@ -626,8 +632,10 @@ class GlobalState(object):
If filename is not provided, this returns a list of profiling
events. Each profile event is a dictionary.
"""
self._check_connected()
client_id_to_address = {}
for client_info in ray.global_state.client_table():
for client_info in self.client_table():
client_id_to_address[client_info["ClientID"]] = "{}:{}".format(
client_info["NodeManagerAddress"],
client_info["ObjectManagerPort"])
@@ -703,6 +711,8 @@ class GlobalState(object):
def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
self._check_connected()
worker_keys = self.redis_client.keys("Worker*")
workers_data = {}
@@ -723,22 +733,6 @@ class GlobalState(object):
worker_info[b"stdout_file"])
return workers_data
def actors(self):
    """Return the actor records found under "Actor:*" in self.redis_client.

    Returns:
        A dict mapping the hex actor ID to a dict with the keys
        "class_id", "driver_id", "raylet_id", "num_gpus" and "removed".
    """
    prefix_length = len("Actor:")
    summary = {}
    for redis_key in self.redis_client.keys("Actor:*"):
        fields = self.redis_client.hgetall(redis_key)
        # Everything after the "Actor:" prefix is the raw actor ID.
        raw_actor_id = redis_key[prefix_length:]
        assert len(raw_actor_id) == ID_SIZE
        record = {
            "class_id": binary_to_hex(fields[b"class_id"]),
            "driver_id": binary_to_hex(fields[b"driver_id"]),
            "raylet_id": binary_to_hex(fields[b"raylet_id"]),
            "num_gpus": int(fields[b"num_gpus"]),
            "removed": decode(fields[b"removed"]) == "True"
        }
        summary[binary_to_hex(raw_actor_id)] = record
    return summary
def _job_length(self):
event_log_sets = self.redis_client.keys("event_log*")
overall_smallest = sys.maxsize
@@ -769,6 +763,8 @@ class GlobalState(object):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
self._check_connected()
resources = defaultdict(int)
clients = self.client_table()
for client in clients:
@@ -798,6 +794,8 @@ class GlobalState(object):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
self._check_connected()
available_resources_by_id = {}
subscribe_clients = [
@@ -899,6 +897,8 @@ class GlobalState(object):
A dictionary mapping driver ID to a list of the error messages for
that driver.
"""
self._check_connected()
if driver_id is not None:
assert isinstance(driver_id, ray.DriverID)
return self._error_messages(driver_id)
@@ -954,3 +954,194 @@ class GlobalState(object):
entry.Timestamps(i) for i in range(num_checkpoints)
],
}
class DeprecatedGlobalState(object):
    """A class used to print errors when the old global state API is used.

    Each method mirrors one entry point of the old ``ray.global_state``
    object. Methods that have a replacement log a deprecation warning and
    forward the call; methods without one raise ``DeprecationWarning``.
    """

    def object_table(self, object_id=None):
        """Deprecated; log a warning and forward to ``ray.objects()``."""
        logger.warning(
            "ray.global_state.object_table() is deprecated and will be "
            "removed in a subsequent release. Use ray.objects() instead.")
        return ray.objects(object_id=object_id)

    def task_table(self, task_id=None):
        """Deprecated; log a warning and forward to ``ray.tasks()``."""
        logger.warning(
            "ray.global_state.task_table() is deprecated and will be "
            "removed in a subsequent release. Use ray.tasks() instead.")
        return ray.tasks(task_id=task_id)

    def function_table(self, function_id=None):
        """Deprecated without replacement; always raises.

        Raises:
            DeprecationWarning: on every call.
        """
        raise DeprecationWarning(
            "ray.global_state.function_table() is deprecated.")

    def client_table(self):
        """Deprecated; log a warning and forward to ``ray.nodes()``."""
        logger.warning(
            "ray.global_state.client_table() is deprecated and will be "
            "removed in a subsequent release. Use ray.nodes() instead.")
        return ray.nodes()

    def profile_table(self):
        """Deprecated without replacement; always raises.

        Raises:
            DeprecationWarning: on every call.
        """
        raise DeprecationWarning(
            "ray.global_state.profile_table() is deprecated.")

    def chrome_tracing_dump(self, filename=None):
        """Deprecated; log a warning and forward to ``ray.timeline()``."""
        logger.warning(
            "ray.global_state.chrome_tracing_dump() is deprecated and will be "
            "removed in a subsequent release. Use ray.timeline() instead.")
        return ray.timeline(filename=filename)

    def chrome_tracing_object_transfer_dump(self, filename=None):
        """Deprecated; forward to ``ray.object_transfer_timeline()``."""
        logger.warning(
            "ray.global_state.chrome_tracing_object_transfer_dump() is "
            "deprecated and will be removed in a subsequent release. Use "
            "ray.object_transfer_timeline() instead.")
        return ray.object_transfer_timeline(filename=filename)

    def workers(self):
        """Deprecated without replacement; always raises.

        Raises:
            DeprecationWarning: on every call.
        """
        raise DeprecationWarning("ray.global_state.workers() is deprecated.")

    def cluster_resources(self):
        """Deprecated; forward to ``ray.cluster_resources()``."""
        logger.warning(
            "ray.global_state.cluster_resources() is deprecated and will be "
            "removed in a subsequent release. Use ray.cluster_resources() "
            "instead.")
        return ray.cluster_resources()

    def available_resources(self):
        """Deprecated; forward to ``ray.available_resources()``."""
        logger.warning(
            "ray.global_state.available_resources() is deprecated and will be "
            "removed in a subsequent release. Use ray.available_resources() "
            "instead.")
        return ray.available_resources()

    def error_messages(self, driver_id=None):
        """Deprecated; log a warning and return the error messages.

        Args:
            driver_id: optional driver ID passed through to the global
                state object's ``error_messages``.

        Returns:
            Whatever ``ray.state.state.error_messages`` returns for
            ``driver_id``.
        """
        logger.warning(
            "ray.global_state.error_messages() is deprecated and will be "
            "removed in a subsequent release. Use ray.errors() "
            "instead.")
        # ray.errors() only accepts ``include_cluster_errors``; forwarding
        # ``driver_id`` to it raises TypeError. Query the global state
        # object directly so the old signature keeps working.
        return ray.state.state.error_messages(driver_id=driver_id)
# Shared GlobalState instance; the module-level helpers in this file
# (nodes(), tasks(), objects(), ...) delegate to it.
state = GlobalState()
"""A global object used to access the cluster's global state."""
# Instance of the deprecated-API shim class, kept so that the old
# ``global_state.<method>()`` entry points still resolve.
global_state = DeprecatedGlobalState()
def nodes():
    """Get a list of the nodes in the cluster.

    Returns:
        The client table of the global state object, i.e. information
        about the Ray clients in the cluster.
    """
    client_table = state.client_table()
    return client_table
def tasks(task_id=None):
    """Fetch and parse the task table information for one or more task IDs.

    Args:
        task_id: A hex string of the task ID to fetch information about.
            If this is None, then the entire task table is fetched.

    Returns:
        Information from the task table.
    """
    task_table = state.task_table(task_id=task_id)
    return task_table
def objects(object_id=None):
    """Fetch and parse the object table info for one or more object IDs.

    Args:
        object_id: An object ID to fetch information about. If this is
            None, then the entire object table is fetched.

    Returns:
        Information from the object table.
    """
    object_table = state.object_table(object_id=object_id)
    return object_table
def timeline(filename=None):
    """Return a list of profiling events that can be viewed as a timeline.

    To view this information as a timeline, dump it as a JSON file (either
    by passing in "filename" or by using json.dump), then go to
    chrome://tracing in the Chrome web browser and load the dumped file.

    Args:
        filename: If a filename is provided, the timeline is dumped to that
            file.

    Returns:
        If filename is not provided, this returns a list of profiling
        events. Each profile event is a dictionary.
    """
    profiling_events = state.chrome_tracing_dump(filename=filename)
    return profiling_events
def object_transfer_timeline(filename=None):
    """Return a list of transfer events that can be viewed as a timeline.

    To view this information as a timeline, dump it as a JSON file (either
    by passing in "filename" or by using json.dump), then go to
    chrome://tracing in the Chrome web browser and load the dumped file.
    Make sure to enable "Flow events" in the "View Options" menu.

    Args:
        filename: If a filename is provided, the timeline is dumped to that
            file.

    Returns:
        If filename is not provided, this returns a list of profiling
        events. Each profile event is a dictionary.
    """
    transfer_events = state.chrome_tracing_object_transfer_dump(
        filename=filename)
    return transfer_events
def cluster_resources():
    """Get the current total cluster resources.

    Note that this information can grow stale as nodes are added to or
    removed from the cluster.

    Returns:
        A dictionary mapping resource name to the total quantity of that
        resource in the cluster.
    """
    total_resources = state.cluster_resources()
    return total_resources
def available_resources():
    """Get the current available cluster resources.

    This is different from `cluster_resources` in that this will return
    idle (available) resources rather than total resources.

    Note that this information can grow stale as tasks start and finish.

    Returns:
        A dictionary mapping resource name to the currently available
        quantity of that resource in the cluster.
    """
    idle_resources = state.available_resources()
    return idle_resources
def errors(include_cluster_errors=True):
    """Get error messages from the cluster.

    Args:
        include_cluster_errors: True if the error messages recorded under
            the nil driver ID should be included as well, and False if only
            the error messages for this specific driver should be returned.

    Returns:
        Error messages pushed from the cluster.
    """
    this_driver_id = ray.worker.global_worker.task_driver_id
    messages = state.error_messages(driver_id=this_driver_id)
    if not include_cluster_errors:
        return messages
    # Cluster-wide errors are stored under the nil driver ID.
    messages += state.error_messages(driver_id=ray.DriverID.nil())
    return messages
+2 -2
View File
@@ -141,7 +141,7 @@ class Cluster(object):
start_time = time.time()
while time.time() - start_time < timeout:
clients = ray.experimental.state.parse_client_table(redis_client)
clients = ray.state._parse_client_table(redis_client)
object_store_socket_names = [
client["ObjectStoreSocketName"] for client in clients
]
@@ -174,7 +174,7 @@ class Cluster(object):
start_time = time.time()
while time.time() - start_time < timeout:
clients = ray.experimental.state.parse_client_table(redis_client)
clients = ray.state._parse_client_table(redis_client)
live_clients = [
client for client in clients
if client["EntryType"] == EntryType.INSERTION
+2 -2
View File
@@ -2439,7 +2439,7 @@ def test_checkpointing_save_exception(ray_start_regular,
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
# Check that checkpointing errors were pushed to the driver.
errors = ray.error_info()
errors = ray.errors()
assert len(errors) > 0
for error in errors:
# An error for the actor process dying may also get pushed.
@@ -2483,7 +2483,7 @@ def test_checkpointing_load_exception(ray_start_regular,
assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False
# Check that checkpointing errors were pushed to the driver.
errors = ray.error_info()
errors = ray.errors()
assert len(errors) > 0
for error in errors:
# An error for the actor process dying may also get pushed.
+28 -45
View File
@@ -935,7 +935,7 @@ def test_many_fractional_resources(shutdown_only):
stop_time = time.time() + 10
correct_available_resources = False
while time.time() < stop_time:
if ray.global_state.available_resources() == {
if ray.available_resources() == {
"CPU": 2.0,
"GPU": 2.0,
"Custom": 2.0,
@@ -1176,7 +1176,7 @@ def test_profiling_api(ray_start_2_cpus):
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for information in "
"profile table.")
profile_data = ray.global_state.chrome_tracing_dump()
profile_data = ray.timeline()
event_types = {event["cat"] for event in profile_data}
expected_types = [
"worker_idle",
@@ -1252,7 +1252,7 @@ def test_object_transfer_dump(ray_start_cluster):
# The profiling information only flushes once every second.
time.sleep(1.1)
transfer_dump = ray.global_state.chrome_tracing_object_transfer_dump()
transfer_dump = ray.object_transfer_timeline()
# Make sure the transfer dump can be serialized with JSON.
json.loads(json.dumps(transfer_dump))
assert len(transfer_dump) >= num_nodes**2
@@ -1559,12 +1559,12 @@ def test_free_objects_multi_node(ray_start_cluster):
# Case3: These cases test the deleting creating tasks for the object.
(a, b, c) = run_one_test(actors, False, False)
task_table = ray.global_state.task_table()
task_table = ray.tasks()
for obj in [a, b, c]:
assert ray._raylet.compute_task_id(obj).hex() in task_table
(a, b, c) = run_one_test(actors, False, True)
task_table = ray.global_state.task_table()
task_table = ray.tasks()
for obj in [a, b, c]:
assert ray._raylet.compute_task_id(obj).hex() not in task_table
@@ -2026,7 +2026,7 @@ def test_multiple_raylets(ray_start_cluster):
results.append(run_on_0_2.remote())
return names, results
client_table = ray.global_state.client_table()
client_table = ray.nodes()
store_names = []
store_names += [
client["ObjectStoreSocketName"] for client in client_table
@@ -2214,13 +2214,13 @@ def test_zero_capacity_deletion_semantics(shutdown_only):
ray.init(num_cpus=2, num_gpus=1, resources={"test_resource": 1})
def test():
resources = ray.global_state.available_resources()
resources = ray.available_resources()
MAX_RETRY_ATTEMPTS = 5
retry_count = 0
while resources and retry_count < MAX_RETRY_ATTEMPTS:
time.sleep(0.1)
resources = ray.global_state.available_resources()
resources = ray.available_resources()
retry_count += 1
if retry_count >= MAX_RETRY_ATTEMPTS:
@@ -2394,7 +2394,7 @@ def test_load_balancing_with_dependencies(ray_start_cluster):
def wait_for_num_tasks(num_tasks, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
if len(ray.global_state.task_table()) >= num_tasks:
if len(ray.tasks()) >= num_tasks:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for global state.")
@@ -2403,7 +2403,7 @@ def wait_for_num_tasks(num_tasks, timeout=10):
def wait_for_num_objects(num_objects, timeout=10):
start_time = time.time()
while time.time() - start_time < timeout:
if len(ray.global_state.object_table()) >= num_objects:
if len(ray.objects()) >= num_objects:
return
time.sleep(0.1)
raise Exception("Timed out while waiting for global state.")
@@ -2414,31 +2414,27 @@ def wait_for_num_objects(num_objects, timeout=10):
reason="New GCS API doesn't have a Python API yet.")
def test_global_state_api(shutdown_only):
with pytest.raises(Exception):
ray.global_state.object_table()
ray.objects()
with pytest.raises(Exception):
ray.global_state.task_table()
ray.tasks()
with pytest.raises(Exception):
ray.global_state.client_table()
with pytest.raises(Exception):
ray.global_state.function_table()
ray.nodes()
ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1})
resources = {"CPU": 5, "GPU": 3, "CustomResource": 1}
assert ray.global_state.cluster_resources() == resources
assert ray.cluster_resources() == resources
assert ray.global_state.object_table() == {}
assert ray.objects() == {}
driver_id = ray.experimental.state.binary_to_hex(
ray.worker.global_worker.worker_id)
driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
driver_task_id = ray.worker.global_worker.current_task_id.hex()
# One task is put in the task table which corresponds to this driver.
wait_for_num_tasks(1)
task_table = ray.global_state.task_table()
task_table = ray.tasks()
assert len(task_table) == 1
assert driver_task_id == list(task_table.keys())[0]
task_spec = task_table[driver_task_id]["TaskSpec"]
@@ -2451,7 +2447,7 @@ def test_global_state_api(shutdown_only):
assert task_spec["FunctionID"] == nil_id_hex
assert task_spec["ReturnObjectIDs"] == []
client_table = ray.global_state.client_table()
client_table = ray.nodes()
node_ip_address = ray.worker.global_worker.node_ip_address
assert len(client_table) == 1
@@ -2466,24 +2462,19 @@ def test_global_state_api(shutdown_only):
# Wait for one additional task to complete.
wait_for_num_tasks(1 + 1)
task_table = ray.global_state.task_table()
task_table = ray.tasks()
assert len(task_table) == 1 + 1
task_id_set = set(task_table.keys())
task_id_set.remove(driver_task_id)
task_id = list(task_id_set)[0]
function_table = ray.global_state.function_table()
task_spec = task_table[task_id]["TaskSpec"]
assert task_spec["ActorID"] == nil_id_hex
assert task_spec["Args"] == [1, "hi", x_id]
assert task_spec["DriverID"] == driver_id
assert task_spec["ReturnObjectIDs"] == [result_id]
function_table_entry = function_table[task_spec["FunctionID"]]
assert function_table_entry["Name"] == "ray.tests.test_basic.f"
assert function_table_entry["DriverID"] == driver_id
assert function_table_entry["Module"] == "ray.tests.test_basic"
assert task_table[task_id] == ray.global_state.task_table(task_id)
assert task_table[task_id] == ray.tasks(task_id)
# Wait for two objects, one for the x_id and one for result_id.
wait_for_num_objects(2)
@@ -2492,7 +2483,7 @@ def test_global_state_api(shutdown_only):
timeout = 10
start_time = time.time()
while time.time() - start_time < timeout:
object_table = ray.global_state.object_table()
object_table = ray.objects()
tables_ready = (object_table[x_id]["ManagerIDs"] is not None and
object_table[result_id]["ManagerIDs"] is not None)
if tables_ready:
@@ -2501,11 +2492,11 @@ def test_global_state_api(shutdown_only):
raise Exception("Timed out while waiting for object table to "
"update.")
object_table = ray.global_state.object_table()
object_table = ray.objects()
assert len(object_table) == 2
assert object_table[x_id] == ray.global_state.object_table(x_id)
object_table_entry = ray.global_state.object_table(result_id)
assert object_table[x_id] == ray.objects(x_id)
object_table_entry = ray.objects(result_id)
assert object_table[result_id] == object_table_entry
@@ -2611,14 +2602,6 @@ def test_workers(shutdown_only):
while len(worker_ids) != num_workers:
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
worker_info = ray.global_state.workers()
assert len(worker_info) >= num_workers
for worker_id, info in worker_info.items():
assert "node_ip_address" in info
assert "plasma_store_socket" in info
assert "stderr_file" in info
assert "stdout_file" in info
def test_specific_driver_id():
dummy_driver_id = ray.DriverID(b"00112233445566778899")
@@ -2816,7 +2799,7 @@ def test_socket_dir_not_existing(shutdown_only):
def test_raylet_is_robust_to_random_messages(ray_start_regular):
node_manager_address = None
node_manager_port = None
for client in ray.global_state.client_table():
for client in ray.nodes():
if "NodeManagerAddress" in client:
node_manager_address = client["NodeManagerAddress"]
node_manager_port = client["NodeManagerPort"]
@@ -2908,7 +2891,7 @@ def test_shutdown_disconnect_global_state():
ray.shutdown()
with pytest.raises(Exception) as e:
ray.global_state.object_table()
ray.objects()
assert str(e.value).endswith("ray.init has been called.")
@@ -2922,8 +2905,8 @@ def test_redis_lru_with_set(ray_start_object_store_memory):
removed = False
start_time = time.time()
while time.time() < start_time + 10:
if ray.global_state.redis_clients[0].delete(b"OBJECT" +
x_id.binary()) == 1:
if ray.state.state.redis_clients[0].delete(b"OBJECT" +
x_id.binary()) == 1:
removed = True
break
assert removed
+29 -44
View File
@@ -23,8 +23,8 @@ def test_dynamic_res_creation(ray_start_regular):
ray.get(set_res.remote(res_name, res_capacity))
available_res = ray.global_state.available_resources()
cluster_res = ray.global_state.cluster_resources()
available_res = ray.available_resources()
cluster_res = ray.cluster_resources()
assert available_res[res_name] == res_capacity
assert cluster_res[res_name] == res_capacity
@@ -43,8 +43,8 @@ def test_dynamic_res_deletion(shutdown_only):
ray.get(delete_res.remote(res_name))
available_res = ray.global_state.available_resources()
cluster_res = ray.global_state.cluster_resources()
available_res = ray.available_resources()
cluster_res = ray.cluster_resources()
assert res_name not in available_res
assert res_name not in cluster_res
@@ -69,7 +69,7 @@ def test_dynamic_res_infeasible_rescheduling(ray_start_regular):
oid = remote_task.remote() # This is infeasible
ray.get(set_res.remote(res_name, res_capacity)) # Now should be feasible
available_res = ray.global_state.available_resources()
available_res = ray.available_resources()
assert available_res[res_name] == res_capacity
successful, unsuccessful = ray.wait([oid], timeout=1)
@@ -88,7 +88,7 @@ def test_dynamic_res_updation_clientid(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
target_clientid = ray.global_state.client_table()[1]["ClientID"]
target_clientid = ray.nodes()[1]["ClientID"]
@ray.remote
def set_res(resource_name, resource_capacity, client_id):
@@ -102,7 +102,7 @@ def test_dynamic_res_updation_clientid(ray_start_cluster):
new_capacity = res_capacity + 1
ray.get(set_res.remote(res_name, new_capacity, target_clientid))
target_client = next(client for client in ray.global_state.client_table()
target_client = next(client for client in ray.nodes()
if client["ClientID"] == target_clientid)
resources = target_client["Resources"]
@@ -122,7 +122,7 @@ def test_dynamic_res_creation_clientid(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
target_clientid = ray.global_state.client_table()[1]["ClientID"]
target_clientid = ray.nodes()[1]["ClientID"]
@ray.remote
def set_res(resource_name, resource_capacity, res_client_id):
@@ -130,7 +130,7 @@ def test_dynamic_res_creation_clientid(ray_start_cluster):
resource_name, resource_capacity, client_id=res_client_id)
ray.get(set_res.remote(res_name, res_capacity, target_clientid))
target_client = next(client for client in ray.global_state.client_table()
target_client = next(client for client in ray.nodes()
if client["ClientID"] == target_clientid)
resources = target_client["Resources"]
@@ -152,9 +152,7 @@ def test_dynamic_res_creation_clientid_multiple(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
target_clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
target_clientids = [client["ClientID"] for client in ray.nodes()]
@ray.remote
def set_res(resource_name, resource_capacity, res_client_id):
@@ -172,9 +170,8 @@ def test_dynamic_res_creation_clientid_multiple(ray_start_cluster):
while time.time() - start_time < TIMEOUT and not success:
resources_created = []
for cid in target_clientids:
target_client = next(client
for client in ray.global_state.client_table()
if client["ClientID"] == cid)
target_client = next(
client for client in ray.nodes() if client["ClientID"] == cid)
resources = target_client["Resources"]
resources_created.append(resources[res_name] == res_capacity)
success = all(resources_created)
@@ -196,7 +193,7 @@ def test_dynamic_res_deletion_clientid(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
target_clientid = ray.global_state.client_table()[1]["ClientID"]
target_clientid = ray.nodes()[1]["ClientID"]
# Launch the delete task
@ray.remote
@@ -206,10 +203,10 @@ def test_dynamic_res_deletion_clientid(ray_start_cluster):
ray.get(delete_res.remote(res_name, target_clientid))
target_client = next(client for client in ray.global_state.client_table()
target_client = next(client for client in ray.nodes()
if client["ClientID"] == target_clientid)
resources = target_client["Resources"]
print(ray.global_state.cluster_resources())
print(ray.cluster_resources())
assert res_name not in resources
@@ -228,9 +225,7 @@ def test_dynamic_res_creation_scheduler_consistency(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
@ray.remote
def set_res(resource_name, resource_capacity, res_client_id):
@@ -267,9 +262,7 @@ def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
@ray.remote
def delete_res(resource_name, res_client_id):
@@ -284,7 +277,7 @@ def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster):
# Create the resource on node1
target_clientid = clientids[1]
ray.get(set_res.remote(res_name, res_capacity, target_clientid))
assert ray.global_state.cluster_resources()[res_name] == res_capacity
assert ray.cluster_resources()[res_name] == res_capacity
# Delete the resource
ray.get(delete_res.remote(res_name, target_clientid))
@@ -322,9 +315,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
target_clientid = clientids[1]
@ray.remote
@@ -334,7 +325,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
# Create the resource on node 1
ray.get(set_res.remote(res_name, res_capacity, target_clientid))
assert ray.global_state.cluster_resources()[res_name] == res_capacity
assert ray.cluster_resources()[res_name] == res_capacity
# Task to hold the resource till the driver signals to finish
@ray.remote
@@ -376,7 +367,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster):
}) # This should be infeasible
successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION)
assert unsuccessful # The task did not complete because it's infeasible
assert ray.global_state.available_resources()[res_name] == updated_capacity
assert ray.available_resources()[res_name] == updated_capacity
def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
@@ -403,9 +394,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
target_clientid = clientids[1]
@ray.remote
@@ -415,7 +404,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
# Create the resource on node 1
ray.get(set_res.remote(res_name, res_capacity, target_clientid))
assert ray.global_state.cluster_resources()[res_name] == res_capacity
assert ray.cluster_resources()[res_name] == res_capacity
# Task to hold the resource till the driver signals to finish
@ray.remote
@@ -457,7 +446,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster):
}) # This should be infeasible
successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION)
assert unsuccessful # The task did not complete because it's infeasible
assert ray.global_state.available_resources()[res_name] == updated_capacity
assert ray.available_resources()[res_name] == updated_capacity
def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
@@ -482,9 +471,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
target_clientid = clientids[1]
@ray.remote
@@ -499,7 +486,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
# Create the resource on node 1
ray.get(set_res.remote(res_name, res_capacity, target_clientid))
assert ray.global_state.cluster_resources()[res_name] == res_capacity
assert ray.cluster_resources()[res_name] == res_capacity
# Task to hold the resource till the driver signals to finish
@ray.remote
@@ -534,7 +521,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster):
args=[], resources={res_name: 1}) # This should be infeasible
successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION)
assert unsuccessful # The task did not complete because it's infeasible
assert res_name not in ray.global_state.available_resources()
assert res_name not in ray.available_resources()
def test_dynamic_res_creation_stress(ray_start_cluster):
@@ -553,9 +540,7 @@ def test_dynamic_res_creation_stress(ray_start_cluster):
ray.init(redis_address=cluster.redis_address)
clientids = [
client["ClientID"] for client in ray.global_state.client_table()
]
clientids = [client["ClientID"] for client in ray.nodes()]
target_clientid = clientids[1]
@ray.remote
@@ -578,7 +563,7 @@ def test_dynamic_res_creation_stress(ray_start_cluster):
start_time = time.time()
while time.time() - start_time < TIMEOUT and not success:
resources = ray.global_state.cluster_resources()
resources = ray.cluster_resources()
all_resources_created = []
for i in range(0, NUM_RES_TO_CREATE):
all_resources_created.append(str(i) in resources)
+5 -4
View File
@@ -164,7 +164,7 @@ def temporary_helper_function():
return 1
# There should be no errors yet.
assert len(ray.error_info()) == 0
assert len(ray.errors()) == 0
# Create an actor.
foo = Foo.remote()
@@ -376,8 +376,9 @@ def test_actor_scope_or_intentionally_killed_message(ray_start_regular):
a = Actor.remote()
a.__ray_terminate__.remote()
time.sleep(1)
assert len(ray.error_info()) == 0, (
"Should not have propogated an error - {}".format(ray.error_info()))
assert len(
ray.errors()) == 0, ("Should not have propogated an error - {}".format(
ray.errors()))
@pytest.mark.skip("This test does not work yet.")
@@ -653,7 +654,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes):
cluster = ray_start_cluster_2_nodes
cluster.wait_for_nodes()
client_ids = {item["ClientID"] for item in ray.global_state.client_table()}
client_ids = {item["ClientID"] for item in ray.nodes()}
# Try to make sure that the monitor has received at least one heartbeat
# from the node.
+9 -9
View File
@@ -18,8 +18,8 @@ import ray
reason="Timeout package not installed; skipping test that may hang.")
@pytest.mark.timeout(10)
def test_replenish_resources(ray_start_regular):
cluster_resources = ray.global_state.cluster_resources()
available_resources = ray.global_state.available_resources()
cluster_resources = ray.cluster_resources()
available_resources = ray.available_resources()
assert cluster_resources == available_resources
@ray.remote
@@ -30,7 +30,7 @@ def test_replenish_resources(ray_start_regular):
resources_reset = False
while not resources_reset:
available_resources = ray.global_state.available_resources()
available_resources = ray.available_resources()
resources_reset = (cluster_resources == available_resources)
assert resources_reset
@@ -40,7 +40,7 @@ def test_replenish_resources(ray_start_regular):
reason="Timeout package not installed; skipping test that may hang.")
@pytest.mark.timeout(10)
def test_uses_resources(ray_start_regular):
cluster_resources = ray.global_state.cluster_resources()
cluster_resources = ray.cluster_resources()
@ray.remote
def cpu_task():
@@ -50,7 +50,7 @@ def test_uses_resources(ray_start_regular):
resource_used = False
while not resource_used:
available_resources = ray.global_state.available_resources()
available_resources = ray.available_resources()
resource_used = available_resources.get(
"CPU", 0) == cluster_resources.get("CPU", 0) - 1
@@ -64,17 +64,17 @@ def test_uses_resources(ray_start_regular):
def test_add_remove_cluster_resources(ray_start_cluster_head):
"""Tests that Global State API is consistent with actual cluster."""
cluster = ray_start_cluster_head
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
nodes = []
nodes += [cluster.add_node(num_cpus=1)]
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
assert ray.cluster_resources()["CPU"] == 2
cluster.remove_node(nodes.pop())
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
for i in range(5):
nodes += [cluster.add_node(num_cpus=1)]
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 6
assert ray.cluster_resources()["CPU"] == 6
+8 -9
View File
@@ -30,17 +30,16 @@ def _test_cleanup_on_driver_exit(num_redis_shards):
time.sleep(2)
def StateSummary():
obj_tbl_len = len(ray.global_state.object_table())
task_tbl_len = len(ray.global_state.task_table())
func_tbl_len = len(ray.global_state.function_table())
return obj_tbl_len, task_tbl_len, func_tbl_len
obj_tbl_len = len(ray.objects())
task_tbl_len = len(ray.tasks())
return obj_tbl_len, task_tbl_len
def Driver(success):
success.value = True
# Start driver.
ray.init(redis_address=redis_address)
summary_start = StateSummary()
if (0, 1) != summary_start[:2]:
if (0, 1) != summary_start:
success.value = False
# Two new objects.
@@ -54,7 +53,7 @@ def _test_cleanup_on_driver_exit(num_redis_shards):
# 1 new function.
attempts = 0
while (2, 1, summary_start[2]) != StateSummary():
while (2, 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
@@ -63,7 +62,7 @@ def _test_cleanup_on_driver_exit(num_redis_shards):
ray.get(f.remote())
attempts = 0
while (4, 2, summary_start[2] + 1) != StateSummary():
while (4, 2) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
@@ -83,12 +82,12 @@ def _test_cleanup_on_driver_exit(num_redis_shards):
# Check that objects and tasks are cleaned up.
ray.init(redis_address=redis_address)
attempts = 0
while (0, 1) != StateSummary()[:2]:
while (0, 1) != StateSummary():
time.sleep(0.1)
attempts += 1
if attempts == max_attempts_before_failing:
break
assert (0, 1) == StateSummary()[:2]
assert (0, 1) == StateSummary()
ray.shutdown()
subprocess.Popen(["ray", "stop"]).wait()
+11 -11
View File
@@ -19,7 +19,7 @@ def test_error_isolation(call_ray_start):
ray.init(redis_address=redis_address)
# There shouldn't be any errors yet.
assert len(ray.error_info()) == 0
assert len(ray.errors()) == 0
error_string1 = "error_string1"
error_string2 = "error_string2"
@@ -33,13 +33,13 @@ def test_error_isolation(call_ray_start):
ray.get(f.remote())
# Wait for the error to appear in Redis.
while len(ray.error_info()) != 1:
while len(ray.errors()) != 1:
time.sleep(0.1)
print("Waiting for error to appear.")
# Make sure we got the error.
assert len(ray.error_info()) == 1
assert error_string1 in ray.error_info()[0]["message"]
assert len(ray.errors()) == 1
assert error_string1 in ray.errors()[0]["message"]
# Start another driver and make sure that it does not receive this
# error. Make the other driver throw an error, and make sure it
@@ -51,7 +51,7 @@ import time
ray.init(redis_address="{}")
time.sleep(1)
assert len(ray.error_info()) == 0
assert len(ray.errors()) == 0
@ray.remote
def f():
@@ -62,12 +62,12 @@ try:
except Exception as e:
pass
while len(ray.error_info()) != 1:
print(len(ray.error_info()))
while len(ray.errors()) != 1:
print(len(ray.errors()))
time.sleep(0.1)
assert len(ray.error_info()) == 1
assert len(ray.errors()) == 1
assert "{}" in ray.error_info()[0]["message"]
assert "{}" in ray.errors()[0]["message"]
print("success")
""".format(redis_address, error_string2, error_string2)
@@ -78,8 +78,8 @@ print("success")
# Make sure that the other error message doesn't show up for this
# driver.
assert len(ray.error_info()) == 1
assert error_string1 in ray.error_info()[0]["message"]
assert len(ray.errors()) == 1
assert error_string1 in ray.errors()[0]["message"]
def test_remote_function_isolation(call_ray_start):
+4 -4
View File
@@ -52,10 +52,10 @@ def test_internal_config(ray_start_cluster_head):
cluster.remove_node(worker)
time.sleep(1)
assert ray.global_state.cluster_resources()["CPU"] == 2
assert ray.cluster_resources()["CPU"] == 2
time.sleep(2)
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
def test_wait_for_nodes(ray_start_cluster_head):
@@ -70,12 +70,12 @@ def test_wait_for_nodes(ray_start_cluster_head):
[cluster.remove_node(w) for w in workers]
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
worker2 = cluster.add_node()
cluster.wait_for_nodes()
cluster.remove_node(worker2)
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
def test_worker_plasma_store_failure(ray_start_cluster_head):
+2 -2
View File
@@ -80,7 +80,7 @@ def test_object_broadcast(ray_start_cluster_with_resource):
# Wait for profiling information to be pushed to the profile table.
time.sleep(1)
transfer_events = ray.global_state.chrome_tracing_object_transfer_dump()
transfer_events = ray.object_transfer_timeline()
# Make sure that each object was transferred a reasonable number of times.
for x_id in object_ids:
@@ -160,7 +160,7 @@ def test_actor_broadcast(ray_start_cluster_with_resource):
# Wait for profiling information to be pushed to the profile table.
time.sleep(1)
transfer_events = ray.global_state.chrome_tracing_object_transfer_dump()
transfer_events = ray.object_transfer_timeline()
# Make sure that each object was transferred a reasonable number of times.
for x_id in object_ids:
+1 -1
View File
@@ -393,7 +393,7 @@ def wait_for_errors(error_check):
errors = []
time_left = 100
while time_left > 0:
errors = ray.error_info()
errors = ray.errors()
if error_check(errors):
break
time_left -= 1
+1 -1
View File
@@ -84,7 +84,7 @@ def run_string_as_driver_nonblocking(driver_script):
def relevant_errors(error_type):
return [info for info in ray.error_info() if info["type"] == error_type]
return [info for info in ray.errors() if info["type"] == error_type]
def wait_for_errors(error_type, num_errors, timeout=10):
+1 -1
View File
@@ -356,7 +356,7 @@ class RayTrialExecutor(TrialExecutor):
def _update_avail_resources(self, num_retries=5):
for i in range(num_retries):
try:
resources = ray.global_state.cluster_resources()
resources = ray.cluster_resources()
except Exception:
# TODO(rliaw): Remove this when local mode is fixed.
# https://github.com/ray-project/ray/issues/4147
+5 -5
View File
@@ -71,7 +71,7 @@ def test_counting_resources(start_connected_cluster):
"""Tests that Tune accounting is consistent with actual cluster."""
cluster = start_connected_cluster
nodes = []
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
runner = TrialRunner(BasicVariantGenerator())
kwargs = {"stopping_criterion": {"training_iteration": 10}}
@@ -82,17 +82,17 @@ def test_counting_resources(start_connected_cluster):
runner.step() # run 1
nodes += [cluster.add_node(num_cpus=1)]
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 2
assert ray.cluster_resources()["CPU"] == 2
cluster.remove_node(nodes.pop())
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
runner.step() # run 2
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1
for i in range(5):
nodes += [cluster.add_node(num_cpus=1)]
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 6
assert ray.cluster_resources()["CPU"] == 6
runner.step() # 1 result
assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2
@@ -120,7 +120,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
cluster.remove_node(node)
cluster.add_node(num_cpus=1)
cluster.wait_for_nodes()
assert ray.global_state.cluster_resources()["CPU"] == 1
assert ray.cluster_resources()["CPU"] == 1
for i in range(3):
runner.step()
+1 -1
View File
@@ -1532,7 +1532,7 @@ class TrialRunnerTest(unittest.TestCase):
runner.add_trial(Trial("__fake", **kwargs))
trials = runner.get_trials()
with patch("ray.global_state.cluster_resources") as resource_mock:
with patch("ray.cluster_resources") as resource_mock:
resource_mock.return_value = {"CPU": 1, "GPU": 1}
runner.step()
self.assertEqual(trials[0].status, Trial.RUNNING)
+10 -19
View File
@@ -25,7 +25,6 @@ import pyarrow.plasma as plasma
import ray.cloudpickle as pickle
import ray.experimental.signal as ray_signal
import ray.experimental.no_return
import ray.experimental.state as state
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
import ray.node
@@ -35,6 +34,7 @@ import ray.remote_function
import ray.serialization as serialization
import ray.services as services
import ray.signature
import ray.state
from ray import (
ActorHandleID,
@@ -1108,8 +1108,6 @@ We use a global Worker object to ensure that there is a single worker object
per worker process.
"""
global_state = state.GlobalState()
_global_node = None
"""ray.node.Node: The global node object that is created by ray.init()."""
@@ -1134,14 +1132,6 @@ def print_failed_task(task_status):
task_status["error_message"]))
def error_info():
    """Return the error messages visible to the current driver.

    The result is the error messages stored under this worker's
    ``task_driver_id`` followed by the error messages stored under the
    nil driver ID.

    ``check_connected()`` is called on the global worker first, so any
    exception it raises propagates to the caller (presumably when
    ``ray.init`` has not been called yet -- TODO confirm).
    """
    worker = global_worker
    worker.check_connected()
    # Look up this driver's own errors first, then the ones filed under
    # the nil driver ID; the concatenation order matches that sequence.
    own_errors = global_state.error_messages(
        driver_id=worker.task_driver_id)
    nil_driver_errors = global_state.error_messages(driver_id=DriverID.nil())
    return own_errors + nil_driver_errors
def _initialize_serialization(driver_id, worker=global_worker):
"""Initialize the serialization library.
@@ -1488,7 +1478,7 @@ def shutdown(exiting_interpreter=False):
disconnect()
# Disconnect global state from GCS.
global_state.disconnect()
ray.state.state.disconnect()
# Shut down the Ray processes.
global _global_node
@@ -1644,7 +1634,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
try:
# Get the error messages that occurred before the call to subscribe.
error_messages = global_state.error_messages(worker.task_driver_id)
error_messages = ray.errors(include_cluster_errors=False)
for error_message in error_messages:
logger.error(error_message)
@@ -1774,7 +1764,7 @@ def connect(node,
worker.lock = threading.RLock()
# Create an object for interfacing with the global state.
global_state._initialize_global_state(
ray.state.state._initialize_global_state(
node.redis_address, redis_password=node.redis_password)
# Register the worker with Redis.
@@ -1881,11 +1871,12 @@ def connect(node,
)
# Add the driver task to the task table.
global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().binary(),
driver_task._serialized_raylet_task())
ray.state.state._execute_command(driver_task.task_id(),
"RAY.TABLE_ADD",
ray.gcs_utils.TablePrefix.RAYLET_TASK,
ray.gcs_utils.TablePubsub.RAYLET_TASK,
driver_task.task_id().binary(),
driver_task._serialized_raylet_task())
# Set the driver's current task ID to the task ID assigned to the
# driver task.