Merge branch 'master' into py39

This commit is contained in:
Akash Patel
2020-12-24 13:13:30 -05:00
committed by GitHub
407 changed files with 14044 additions and 12568 deletions
+49
View File
@@ -0,0 +1,49 @@
import os
from contextlib import contextmanager
from functools import wraps
# True when this process runs in Ray client mode. Read once at import time
# from the RAY_CLIENT_MODE environment variable; only the exact value "1"
# turns it on.
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
# Module-wide switch consulted by the wrappers that client_mode_hook creates.
# Delegation to the client only happens while both flags are truthy.
_client_hook_enabled = True
def _enable_client_hook(val: bool):
    """Set the module-wide client hook flag.

    Args:
        val: New value for ``_client_hook_enabled``. Typically the value
            previously returned by ``_disable_client_hook``.
    """
    global _client_hook_enabled
    _client_hook_enabled = val
def _disable_client_hook():
    """Turn the client hook off and report what it was before.

    Returns:
        The value ``_client_hook_enabled`` held prior to this call, so the
        caller can restore it later.
    """
    global _client_hook_enabled
    previous, _client_hook_enabled = _client_hook_enabled, False
    return previous
def _explicitly_enable_client_mode():
    """Force client mode on, regardless of the RAY_CLIENT_MODE variable."""
    global client_mode_enabled
    client_mode_enabled = True
@contextmanager
def disable_client_hook():
    """Context manager that switches the client hook off for its body.

    The flag value seen on entry is restored on exit, including when the
    body raises.
    """
    previous = _disable_client_hook()
    try:
        yield
    finally:
        _enable_client_hook(previous)
def client_mode_hook(func):
    """Decorate a ray module function so it can be routed to the ray client.

    While both ``client_mode_enabled`` and ``_client_hook_enabled`` are
    truthy, a call to the wrapper is forwarded to the attribute of the
    client ``ray`` object that has the same name as ``func``. Otherwise
    ``func`` itself is called.

    Args:
        func: The function to wrap.

    Returns:
        The wrapping function (metadata copied from ``func`` via ``wraps``).
    """
    # Imported at decoration time, inside the function body.
    from ray.experimental.client import ray as client_ray

    @wraps(func)
    def wrapper(*args, **kwargs):
        delegate = client_mode_enabled and _client_hook_enabled
        if not delegate:
            return func(*args, **kwargs)
        client_method = getattr(client_ray, func.__name__)
        return client_method(*args, **kwargs)

    return wrapper
@@ -0,0 +1,83 @@
import inspect
import logging
import sys
from ray.experimental.client.ray_client_helpers import ray_start_client_server
from ray._private.ray_microbenchmark_helpers import timeit
from ray._private.ray_microbenchmark_helpers import ray_setup_and_teardown
def benchmark_get_calls(ray):
    """Time repeated ``ray.get`` calls on a single object stored up front.

    Args:
        ray: Object exposing ``put`` and ``get`` (the client API handle
            passed in by ``main``).
    """
    stored = ray.put(0)

    def fetch_once():
        ray.get(stored)

    timeit("client: get calls", fetch_once)
def benchmark_put_calls(ray):
    """Time repeated ``ray.put`` calls of a small value.

    Args:
        ray: Object exposing ``put`` (the client API handle passed in by
            ``main``).
    """

    def store_once():
        ray.put(0)

    timeit("client: put calls", store_once)
def benchmark_remote_put_calls(ray):
    """Time ``ray.put`` calls issued from inside remote tasks.

    Each timed iteration launches 10 remote tasks that each perform 100
    puts, hence the multiplier of 1000 passed to ``timeit``.

    Args:
        ray: Object exposing ``remote``, ``put`` and ``get``.
    """

    @ray.remote
    def do_put_small():
        for _ in range(100):
            ray.put(0)

    def run_batch():
        pending = [do_put_small.remote() for _ in range(10)]
        ray.get(pending)

    timeit("client: remote put calls", run_batch, 1000)
def benchmark_simple_actor(ray):
    """Time synchronous, batched and concurrent calls against one actor.

    Three measurements are taken: one call per ``ray.get``, 1000 calls per
    ``ray.get``, and 1000 calls per ``ray.get`` against an actor created
    with ``max_concurrency=16``.

    Args:
        ray: Object exposing ``remote`` and ``get``.
    """

    @ray.remote(num_cpus=0)
    class Actor:
        def small_value(self):
            return b"ok"

        def small_value_arg(self, x):
            return b"ok"

        def small_value_batch(self, n):
            ray.get([self.small_value.remote() for _ in range(n)])

    handle = Actor.remote()

    def call_sync():
        ray.get(handle.small_value.remote())

    timeit("client: 1:1 actor calls sync", call_sync)

    def call_async():
        ray.get([handle.small_value.remote() for _ in range(1000)])

    timeit("client: 1:1 actor calls async", call_async, 1000)

    # Rebind the same name (as the original does) so the first actor handle
    # is dropped before the concurrent measurement starts.
    handle = Actor.options(max_concurrency=16).remote()

    def call_concurrent():
        ray.get([handle.small_value.remote() for _ in range(1000)])

    timeit("client: 1:1 actor calls concurrent", call_concurrent, 1000)
def main():
    """Run every ``benchmark_*`` member of this module.

    Ray is initialised once via ``ray_setup_and_teardown``; each benchmark
    then gets its own ``ray_start_client_server`` context and receives the
    value that context yields.
    """
    system_config = {"put_small_object_in_memory_store": True}
    setup = ray_setup_and_teardown(
        logging_level=logging.WARNING, _system_config=system_config)
    with setup:
        for name, obj in inspect.getmembers(sys.modules[__name__]):
            if name.startswith("benchmark_"):
                with ray_start_client_server() as ray:
                    obj(ray)


if __name__ == "__main__":
    main()
@@ -0,0 +1,39 @@
import time
import os
import ray
import numpy as np
from contextlib import contextmanager
# Only run tests matching this filter pattern.
filter_pattern = os.environ.get("TESTS_TO_RUN", "")


def timeit(name, fn, multiplier=1, *, warmup_s=1, rounds=4, round_s=2):
    """Benchmark ``fn`` and print its throughput.

    ``fn`` is first called repeatedly for ``warmup_s`` seconds, then for
    ``rounds`` measured rounds of roughly ``round_s`` seconds each. The mean
    and standard deviation of the per-round rate are printed.

    Timing uses ``time.perf_counter`` rather than ``time.time`` so that
    system clock adjustments cannot distort (or make negative) a round's
    elapsed time.

    Args:
        name: Label printed with the result. The benchmark is skipped
            unless ``filter_pattern`` is a substring of it.
        fn: Zero-argument callable to measure.
        multiplier: Number of logical operations performed by one call to
            ``fn``; scales the reported rate.
        warmup_s: Warmup duration in seconds.
        rounds: Number of measured rounds; should be at least 1.
        round_s: Target duration of each measured round in seconds.
    """
    if filter_pattern not in name:
        return
    clock = time.perf_counter
    # warmup
    start = clock()
    while clock() - start < warmup_s:
        fn()
    # real run
    stats = []
    for _ in range(rounds):
        start = clock()
        count = 0
        while clock() - start < round_s:
            fn()
            count += 1
        end = clock()
        stats.append(multiplier * count / (end - start))
    print(name, "per second", round(np.mean(stats), 2), "+-",
          round(np.std(stats), 2))
@contextmanager
def ray_setup_and_teardown(**init_args):
    """Context manager that runs ``ray.init`` on entry, ``ray.shutdown`` on exit.

    Args:
        **init_args: Keyword arguments forwarded unchanged to ``ray.init``.

    ``ray.shutdown`` runs even if the body raises. It is not called when
    ``ray.init`` itself fails.
    """
    ray.init(**init_args)
    try:
        yield
    finally:
        ray.shutdown()
+12 -13
View File
@@ -279,8 +279,7 @@ def get_address_info_from_redis_helper(redis_address,
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
redis_password=None,
no_warning=False):
redis_password=None):
counter = 0
while True:
try:
@@ -291,11 +290,10 @@ def get_address_info_from_redis(redis_address,
raise
# Some of the information may not be in Redis yet, so wait a little
# bit.
if not no_warning:
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
time.sleep(1)
counter += 1
@@ -1618,12 +1616,13 @@ def determine_plasma_store_config(object_store_memory,
logger.warning(
"WARNING: The object store is using {} instead of "
"/dev/shm because /dev/shm has only {} bytes available. "
"This may slow down performance! You may be able to free "
"up space by deleting files in /dev/shm or terminating "
"any running plasma_store_server processes. If you are "
"inside a Docker container, you may need to pass an "
"argument with the flag '--shm-size' to 'docker run'.".
format(ray.utils.get_user_temp_dir(), shm_avail))
"This will harm performance! You may be able to free up "
"space by deleting files in /dev/shm. If you are inside a "
"Docker container, you can increase /dev/shm size by "
"passing '--shm-size=Xgb' to 'docker run' (or add it to "
"the run_options list in a Ray cluster config). Make sure "
"to set this to more than 2gb.".format(
ray.utils.get_user_temp_dir(), shm_avail))
else:
plasma_directory = ray.utils.get_user_temp_dir()
+19 -19
View File
@@ -107,6 +107,10 @@ from ray.exceptions import (
TaskCancelledError
)
from ray.utils import decode
from ray._private.client_mode_hook import (
_enable_client_hook,
_disable_client_hook,
)
import msgpack
cimport cpython
@@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler(
with gil:
try:
client_was_enabled = _disable_client_hook()
try:
# The call to execute_task should never raise an exception. If
# it does, that indicates that there was an internal error.
@@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler(
else:
logger.exception("SystemExit was raised from the worker")
return CRayStatus.UnexpectedSystemExit()
finally:
_enable_client_hook(client_was_enabled)
return CRayStatus.OK()
@@ -638,9 +645,11 @@ cdef c_vector[c_string] spill_objects_handler(
return return_urls
cdef void restore_spilled_objects_handler(
cdef int64_t restore_spilled_objects_handler(
const c_vector[CObjectID]& object_ids_to_restore,
const c_vector[c_string]& object_urls) nogil:
cdef:
int64_t bytes_restored = 0
with gil:
urls = []
size = object_urls.size()
@@ -651,7 +660,8 @@ cdef void restore_spilled_objects_handler(
with ray.worker._changeproctitle(
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER,
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE):
external_storage.restore_spilled_objects(object_refs, urls)
bytes_restored = external_storage.restore_spilled_objects(
object_refs, urls)
except Exception:
exception_str = (
"An unexpected internal error occurred while the IO worker "
@@ -662,6 +672,7 @@ cdef void restore_spilled_objects_handler(
"restore_spilled_objects_error",
traceback.format_exc() + exception_str,
job_id=None)
return bytes_restored
cdef void delete_spilled_objects_handler(
@@ -873,7 +884,8 @@ cdef class CoreWorker:
return self.plasma_event_handler
def get_objects(self, object_refs, TaskID current_task_id,
int64_t timeout_ms=-1, plasma_objects_only=False):
int64_t timeout_ms=-1,
plasma_objects_only=False):
cdef:
c_vector[shared_ptr[CRayObject]] results
CTaskID c_task_id = current_task_id.native()
@@ -1004,7 +1016,7 @@ cdef class CoreWorker:
return c_object_id.Binary()
def wait(self, object_refs, int num_returns, int64_t timeout_ms,
TaskID current_task_id):
TaskID current_task_id, c_bool fetch_local):
cdef:
c_vector[CObjectID] wait_ids
c_vector[c_bool] results
@@ -1013,7 +1025,7 @@ cdef class CoreWorker:
wait_ids = ObjectRefsToVector(object_refs)
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
wait_ids, num_returns, timeout_ms, &results))
wait_ids, num_returns, timeout_ms, &results, fetch_local))
assert len(results) == len(object_refs)
@@ -1026,14 +1038,13 @@ cdef class CoreWorker:
return ready, not_ready
def free_objects(self, object_refs, c_bool local_only,
c_bool delete_creating_tasks):
def free_objects(self, object_refs, c_bool local_only):
cdef:
c_vector[CObjectID] free_ids = ObjectRefsToVector(object_refs)
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
free_ids, local_only, delete_creating_tasks))
free_ids, local_only))
def global_gc(self):
with nogil:
@@ -1573,17 +1584,6 @@ cdef class CoreWorker:
resource_name.encode("ascii"), capacity,
CNodeID.FromBinary(client_id.binary()))
def force_spill_objects(self, object_refs):
cdef c_vector[CObjectID] object_ids
object_ids = ObjectRefsToVector(object_refs)
assert not RayConfig.instance().automatic_object_deletion_enabled(), (
"Automatic object deletion is not supported for"
"force_spill_objects yet. Please set"
"automatic_object_deletion_enabled: False in Ray's system config.")
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker()
.SpillObjects(object_ids))
cdef void async_set_result(shared_ptr[CRayObject] obj,
CObjectID object_ref,
void *future) with gil:
+103 -74
View File
@@ -1,4 +1,4 @@
from collections import defaultdict, namedtuple
from collections import defaultdict, namedtuple, Counter
from typing import Any, Optional, Dict, List
from urllib3.exceptions import MaxRetryError
import copy
@@ -16,8 +16,10 @@ from ray.experimental.internal_kv import _internal_kv_put, \
from ray.autoscaler.tags import (
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE, NODE_KIND_WORKER,
NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE,
NODE_KIND_WORKER, NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
from ray.autoscaler._private.legacy_info_string import legacy_log_info_string
from ray.autoscaler._private.providers import _get_node_provider
from ray.autoscaler._private.updater import NodeUpdaterThread
from ray.autoscaler._private.node_launcher import NodeLauncher
@@ -25,8 +27,8 @@ from ray.autoscaler._private.resource_demand_scheduler import \
get_bin_pack_residual, ResourceDemandScheduler, NodeType, NodeID, NodeIP, \
ResourceDict
from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \
with_head_node_ip, hash_launch_conf, hash_runtime_conf, add_prefix, \
DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR
with_head_node_ip, hash_launch_conf, hash_runtime_conf, \
DEBUG_AUTOSCALING_ERROR, format_info_string
from ray.autoscaler._private.constants import \
AUTOSCALER_MAX_NUM_FAILURES, AUTOSCALER_MAX_LAUNCH_BATCH, \
AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \
@@ -41,20 +43,23 @@ UpdateInstructions = namedtuple(
"UpdateInstructions",
["node_id", "init_commands", "start_ray_commands", "docker_config"])
AutoscalerSummary = namedtuple(
"AutoscalerSummary",
["active_nodes", "pending_nodes", "pending_launches", "failed_nodes"])
class StandardAutoscaler:
"""The autoscaling control loop for a Ray cluster.
There are two ways to start an autoscaling cluster: manually by running
`ray start --head --autoscaling-config=/path/to/config.yaml` on a
instance that has permission to launch other instances, or you can also use
`ray up /path/to/config.yaml` from your laptop, which will
configure the right AWS/Cloud roles automatically.
StandardAutoscaler's `update` method is periodically called by `monitor.py`
to add and remove nodes as necessary. Currently, load-based autoscaling is
not implemented, so all this class does is try to maintain a constant
cluster size.
`ray start --head --autoscaling-config=/path/to/config.yaml` on a instance
that has permission to launch other instances, or you can also use `ray up
/path/to/config.yaml` from your laptop, which will configure the right
AWS/Cloud roles automatically. See the documentation for a full definition
of autoscaling behavior:
https://docs.ray.io/en/master/cluster/autoscaling.html
StandardAutoscaler's `update` method is periodically called in
`monitor.py`'s monitoring loop.
StandardAutoscaler is also used to bootstrap clusters (by adding workers
until the cluster size that can handle the resource demand is met).
@@ -120,9 +125,6 @@ class StandardAutoscaler:
for local_path in self.config["file_mounts"].values():
assert os.path.exists(local_path)
# List of resource bundles the user is requesting of the cluster.
self.resource_demand_vector = []
logger.info("StandardAutoscaler: {}".format(self.config))
def update(self):
@@ -149,7 +151,6 @@ class StandardAutoscaler:
def _update(self):
now = time.time()
# Throttle autoscaling updates to this interval to avoid exceeding
# rate limits on API calls.
if now - self.last_update_time < self.update_interval_s:
@@ -162,7 +163,6 @@ class StandardAutoscaler:
self.provider.internal_ip(node_id)
for node_id in self.all_workers()
])
self.log_info_string(nodes)
# Terminate any idle or out of date nodes
last_used = self.load_metrics.last_used_time_by_ip
@@ -176,7 +176,7 @@ class StandardAutoscaler:
sorted_node_ids = self._sort_based_on_last_used(nodes, last_used)
# Don't terminate nodes needed by request_resources()
nodes_allowed_to_terminate: Dict[NodeID, bool] = {}
if self.resource_demand_vector:
if self.load_metrics.get_resource_requests():
nodes_allowed_to_terminate = self._get_nodes_allowed_to_terminate(
sorted_node_ids)
@@ -202,7 +202,6 @@ class StandardAutoscaler:
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
# Terminate nodes if there are too many
nodes_to_terminate = []
@@ -217,8 +216,6 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
to_launch = self.resource_demand_scheduler.get_nodes_to_launch(
self.provider.non_terminated_nodes(tag_filters={}),
self.pending_launches.breakdown(),
@@ -226,7 +223,7 @@ class StandardAutoscaler:
self.load_metrics.get_resource_utilization(),
self.load_metrics.get_pending_placement_groups(),
self.load_metrics.get_static_node_resources_by_ip(),
ensure_min_cluster_size=self.resource_demand_vector)
ensure_min_cluster_size=self.load_metrics.get_resource_requests())
for node_type, count in to_launch.items():
self.launch_new_node(count, node_type=node_type)
@@ -256,7 +253,6 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
# Update nodes with out-of-date files.
# TODO(edoakes): Spawning these threads directly seems to cause
@@ -282,6 +278,9 @@ class StandardAutoscaler:
for node_id in nodes:
self.recover_if_needed(node_id, now)
logger.info(self.info_string())
legacy_log_info_string(self, nodes)
def _sort_based_on_last_used(self, nodes: List[NodeID],
last_used: Dict[str, float]) -> List[NodeID]:
"""Sort the nodes based on the last time they were used.
@@ -333,7 +332,7 @@ class StandardAutoscaler:
NodeIP,
ResourceDict] = \
self.load_metrics.get_static_node_resources_by_ip()
head_node_resources = static_nodes[head_ip]
head_node_resources = static_nodes.get(head_ip, {})
else:
head_node_resources = {}
@@ -362,7 +361,7 @@ class StandardAutoscaler:
used_resource_requests: List[ResourceDict]
_, used_resource_requests = \
get_bin_pack_residual(max_node_resources,
self.resource_demand_vector)
self.load_metrics.get_resource_requests())
# Remove the first entry (the head node).
max_node_resources.pop(0)
# Remove the first entry (the head node).
@@ -482,11 +481,13 @@ class StandardAutoscaler:
# for legacy yamls.
self.resource_demand_scheduler.reset_config(
self.provider, self.available_node_types,
self.config["max_workers"], upscaling_speed)
self.config["max_workers"], self.config["head_node_type"],
upscaling_speed)
else:
self.resource_demand_scheduler = ResourceDemandScheduler(
self.provider, self.available_node_types,
self.config["max_workers"], upscaling_speed)
self.config["max_workers"], self.config["head_node_type"],
upscaling_speed)
except Exception as e:
if errors_fatal:
@@ -532,15 +533,17 @@ class StandardAutoscaler:
if not self.can_update(node_id):
return
key = self.provider.internal_ip(node_id)
if key not in self.load_metrics.last_heartbeat_time_by_ip:
self.load_metrics.last_heartbeat_time_by_ip[key] = now
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key]
delta = now - last_heartbeat_time
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
if key in self.load_metrics.last_heartbeat_time_by_ip:
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[
key]
delta = now - last_heartbeat_time
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
logger.warning("StandardAutoscaler: "
"{}: No heartbeat in {}s, "
"restarting Ray to recover...".format(node_id, delta))
"{}: No recent heartbeat, "
"restarting Ray to recover...".format(node_id))
updater = NodeUpdaterThread(
node_id=node_id,
provider_config=self.config["provider"],
@@ -677,43 +680,6 @@ class StandardAutoscaler:
return self.provider.non_terminated_nodes(
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_UNMANAGED})
def log_info_string(self, nodes):
tmp = "Cluster status: "
tmp += self.info_string(nodes)
tmp += "\n"
tmp += self.load_metrics.info_string()
tmp += "\n"
tmp += self.resource_demand_scheduler.debug_string(
nodes, self.pending_launches.breakdown(),
self.load_metrics.get_resource_utilization())
if _internal_kv_initialized():
_internal_kv_put(DEBUG_AUTOSCALING_STATUS, tmp, overwrite=True)
if self.prefix_cluster_info:
tmp = add_prefix(tmp, self.config["cluster_name"])
logger.debug(tmp)
def info_string(self, nodes):
suffix = ""
if self.updaters:
suffix += " ({} updating)".format(len(self.updaters))
if self.num_failed_updates:
suffix += " ({} failed to update)".format(
len(self.num_failed_updates))
return "{} nodes{}".format(len(nodes), suffix)
def request_resources(self, resources: List[dict]):
"""Called by monitor to request resources.
Args:
resources: A list of resource bundles.
"""
if resources:
logger.info(
"StandardAutoscaler: resource_requests={}".format(resources))
assert isinstance(resources, list), resources
self.resource_demand_vector = resources
def kill_workers(self):
logger.error("StandardAutoscaler: kill_workers triggered")
nodes = self.workers()
@@ -721,3 +687,66 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes)
logger.error("StandardAutoscaler: terminated {} node(s)".format(
len(nodes)))
def summary(self):
    """Summarize active, pending and failed nodes plus pending launches.

    A node counts as active when ``load_metrics.is_active`` is true for its
    internal IP. A non-active node whose status tag is uninitialized,
    waiting for ssh, syncing files or setting up is pending. Any other
    non-active node is reported as failed. Unmanaged nodes are ignored.

    Returns:
        AutoscalerSummary: ``active_nodes`` is a Counter keyed by node
        type, ``pending_nodes`` and ``failed_nodes`` are lists of
        ``(ip, node_type)`` tuples, and ``pending_launches`` maps node type
        to a non-zero launch count.
    """
    # Status tags that mean "still coming up".
    pending_states = (
        STATUS_UNINITIALIZED,
        STATUS_WAITING_FOR_SSH,
        STATUS_SYNCING_FILES,
        STATUS_SETTING_UP,
    )

    active_nodes = Counter()
    pending_nodes = []
    failed_nodes = []

    for node_id in self.provider.non_terminated_nodes(tag_filters={}):
        ip = self.provider.internal_ip(node_id)
        tags = self.provider.node_tags(node_id)
        if tags[TAG_RAY_NODE_KIND] == NODE_KIND_UNMANAGED:
            continue
        node_type = tags[TAG_RAY_USER_NODE_TYPE]

        # TODO (Alex): If a node's raylet has died, it shouldn't be marked
        # as active.
        if self.load_metrics.is_active(ip):
            active_nodes[node_type] += 1
            continue

        if tags[TAG_RAY_NODE_STATUS] in pending_states:
            pending_nodes.append((ip, node_type))
        else:
            # TODO (Alex): Failed nodes are now immediately killed, so
            # this list will almost always be empty. We should ideally
            # keep a cache of recently failed nodes and their startup
            # logs.
            failed_nodes.append((ip, node_type))

    # The concurrent counter leaves some 0 counts in, so we need to
    # manually filter those out.
    pending_launches = {
        node_type: count
        for node_type, count in self.pending_launches.breakdown().items()
        if count
    }

    return AutoscalerSummary(
        active_nodes=active_nodes,
        pending_nodes=pending_nodes,
        pending_launches=pending_launches,
        failed_nodes=failed_nodes)
def info_string(self):
    """Return the formatted cluster status, prefixed with a newline.

    Combines the load metrics summary with this autoscaler's summary via
    ``format_info_string``.
    """
    lm_summary = self.load_metrics.summary()
    autoscaler_summary = self.summary()
    report = format_info_string(lm_summary, autoscaler_summary)
    return "\n" + report
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
HASH_MAX_LENGTH = 10
KUBECTL_RSYNC = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
MAX_HOME_RETRIES = 3
HOME_RETRY_DELAY_S = 5
_config = {"use_login_shells": True, "silent_rsync": True}
@@ -248,16 +250,31 @@ class KubernetesCommandRunner(CommandRunnerInterface):
@property
def _home(self):
    """Home directory of the container, fetched lazily and cached.

    Up to ``MAX_HOME_RETRIES`` attempts are made. After each failed attempt
    except the last, the error is logged and the call sleeps for
    ``HOME_RETRY_DELAY_S`` seconds. An exception from the final attempt
    propagates to the caller.
    """
    if self._home_cached is not None:
        return self._home_cached

    retries_before_last = MAX_HOME_RETRIES - 1
    for _ in range(retries_before_last):
        try:
            self._home_cached = self._try_to_get_home()
        except Exception:
            # TODO (Dmitri): Identify the exception we're trying to avoid.
            logger.info("Error reading container's home directory. "
                        f"Retrying in {HOME_RETRY_DELAY_S} seconds.")
            time.sleep(HOME_RETRY_DELAY_S)
        else:
            return self._home_cached

    # Last try: no handler here, so a failure surfaces to the caller.
    self._home_cached = self._try_to_get_home()
    return self._home_cached
def _try_to_get_home(self):
# TODO (Dmitri): Think about how to use the node's HOME variable
# without making an extra kubectl exec call.
if self._home_cached is None:
cmd = self.kubectl + [
"exec", "-it", self.node_id, "--", "printenv", "HOME"
]
joined_cmd = " ".join(cmd)
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
self._home_cached = raw_out.decode().strip("\n\r")
return self._home_cached
cmd = self.kubectl + [
"exec", "-it", self.node_id, "--", "printenv", "HOME"
]
joined_cmd = " ".join(cmd)
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
home = raw_out.decode().strip("\n\r")
return home
class SSHOptions:
+16 -3
View File
@@ -43,6 +43,10 @@ from ray.worker import global_worker # type: ignore
from ray.util.debug import log_once
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
from ray.autoscaler._private.load_metrics import LoadMetricsSummary
from ray.autoscaler._private.autoscaler import AutoscalerSummary
from ray.autoscaler._private.util import format_info_string, \
format_info_string_no_node_types
logger = logging.getLogger(__name__)
@@ -94,6 +98,14 @@ def debug_status() -> str:
status = "No cluster status."
else:
status = status.decode("utf-8")
as_dict = json.loads(status)
lm_summary = LoadMetricsSummary(**as_dict["load_metrics_report"])
if "autoscaler_report" in as_dict:
autoscaler_summary = AutoscalerSummary(
**as_dict["autoscaler_report"])
status = format_info_string(lm_summary, autoscaler_summary)
else:
status = format_info_string_no_node_types(lm_summary)
if error:
status += "\n"
status += error.decode("utf-8")
@@ -280,9 +292,10 @@ def _bootstrap_config(config: Dict[str, Any],
f"Failed to autodetect node resources: {str(exc)}. "
"You can see full stack trace with higher verbosity.")
# NOTE: if `resources` field is missing, validate_config for non-AWS will
# fail (the schema error will ask the user to manually fill the resources)
# as we currently support autofilling resources for AWS instances only.
# NOTE: if `resources` field is missing, validate_config for providers
# other than AWS and Kubernetes will fail (the schema error will ask the
# user to manually fill the resources) as we currently support autofilling
# resources for AWS and Kubernetes only.
validate_config(config)
resolved_config = provider_cls.bootstrap_config(config)
@@ -60,6 +60,13 @@ def bootstrap_kubernetes(config):
def fillout_resources_kubernetes(config):
"""Fills CPU and GPU resources by reading pod spec of each available node
type.
For each node type and each of CPU/GPU, looks at container's resources
and limits, takes min of the two. The result is rounded up, as Ray does
not currently support fractional CPU.
"""
if "available_node_types" not in config:
return config["available_node_types"]
node_types = copy.deepcopy(config["available_node_types"])
@@ -96,20 +103,47 @@ def get_resource(container_resources, resource_name):
limit = _get_resource(
container_resources, resource_name, field_name="limits")
resource = min(request, limit)
# float("inf") value means the resource wasn't detected in either
# requests or limits
return 0 if resource == float("inf") else int(resource)
def _get_resource(container_resources, resource_name, field_name):
if (field_name in container_resources
and resource_name in container_resources[field_name]):
return _parse_resource(container_resources[field_name][resource_name])
else:
"""Returns the resource quantity.
The amount of resource is rounded up to nearest integer.
Returns float("inf") if the resource is not present.
Args:
container_resources (dict): Container's resource field.
resource_name (str): One of 'cpu' or 'gpu'.
field_name (str): One of 'requests' or 'limits'.
Returns:
Union[int, float]: Detected resource quantity.
"""
if field_name not in container_resources:
# No limit/resource field.
return float("inf")
resources = container_resources[field_name]
# Look for keys containing the resource_name. For example,
# the key 'nvidia.com/gpu' contains the key 'gpu'.
matching_keys = [key for key in resources if resource_name in key.lower()]
if len(matching_keys) == 0:
return float("inf")
if len(matching_keys) > 1:
# Should have only one match -- mostly relevant for gpu.
raise ValueError(f"Multiple {resource_name} types not supported.")
# E.g. 'nvidia.com/gpu' or 'cpu'.
resource_key = matching_keys.pop()
resource_quantity = resources[resource_key]
return _parse_resource(resource_quantity)
def _parse_resource(resource):
resource_str = str(resource)
if resource_str[-1] == "m":
# For example, '500m' rounds up to 1.
return math.ceil(int(resource_str[:-1]) / 1000)
else:
return int(resource_str)
@@ -0,0 +1,35 @@
import logging
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS_LEGACY
from ray.experimental.internal_kv import _internal_kv_put, \
_internal_kv_initialized
"""This file provides legacy support for the old info string in order to
ensure the dashboard's `api/cluster_status` does not break backwards
compatibility.
"""
logger = logging.getLogger(__name__)
def legacy_log_info_string(autoscaler, nodes):
    """Build the old-style cluster status string, store it and log it.

    The string has three newline-separated sections: node count info, the
    load metrics info string, and the resource demand scheduler's debug
    string. When the internal KV store is initialized, the string is
    written under ``DEBUG_AUTOSCALING_STATUS_LEGACY``. It is always logged
    at debug level.

    Args:
        autoscaler: Object providing ``load_metrics``,
            ``resource_demand_scheduler``, ``pending_launches``,
            ``updaters`` and ``num_failed_updates``.
        nodes: Collection of node ids to report on.
    """
    sections = [
        "Cluster status: " + info_string(autoscaler, nodes),
        autoscaler.load_metrics.info_string(),
        autoscaler.resource_demand_scheduler.debug_string(
            nodes, autoscaler.pending_launches.breakdown(),
            autoscaler.load_metrics.get_resource_utilization()),
    ]
    message = "\n".join(sections)
    if _internal_kv_initialized():
        _internal_kv_put(
            DEBUG_AUTOSCALING_STATUS_LEGACY, message, overwrite=True)
    logger.debug(message)
def info_string(autoscaler, nodes):
    """Return a short node-count string with optional update annotations.

    Args:
        autoscaler: Object with sized ``updaters`` and
            ``num_failed_updates`` attributes.
        nodes: Sized collection of nodes.

    Returns:
        str: ``"<n> nodes"`` followed by ``" (<k> updating)"`` and/or
        ``" (<k> failed to update)"`` when the respective collection is
        non-empty.
    """
    annotations = []
    if autoscaler.updaters:
        annotations.append(" ({} updating)".format(len(autoscaler.updaters)))
    if autoscaler.num_failed_updates:
        annotations.append(" ({} failed to update)".format(
            len(autoscaler.num_failed_updates)))
    suffix = "".join(annotations)
    return "{} nodes{}".format(len(nodes), suffix)
+100 -16
View File
@@ -1,16 +1,26 @@
from collections import namedtuple
from functools import reduce
import logging
import time
from typing import Dict, List
import numpy as np
import ray._private.services as services
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES,\
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
from ray.autoscaler._private.util import add_resources, freq_of_dicts
from ray.gcs_utils import PlacementGroupTableData
from ray.autoscaler._private.resource_demand_scheduler import \
NodeIP, ResourceDict
from ray.core.generated.common_pb2 import PlacementStrategy
logger = logging.getLogger(__name__)
LoadMetricsSummary = namedtuple("LoadMetricsSummary", [
"head_ip", "usage", "resource_demand", "pg_demand", "request_demand",
"node_types"
])
class LoadMetrics:
"""Container for cluster load metrics.
@@ -31,6 +41,7 @@ class LoadMetrics:
self.waiting_bundles = []
self.infeasible_bundles = []
self.pending_placement_groups = []
self.resource_requests = []
def update(self,
ip: str,
@@ -72,34 +83,37 @@ class LoadMetrics:
def mark_active(self, ip):
assert ip is not None, "IP should be known at this time"
logger.info("Node {} is newly setup, treating as active".format(ip))
logger.debug("Node {} is newly setup, treating as active".format(ip))
self.last_heartbeat_time_by_ip[ip] = time.time()
def is_active(self, ip):
    """Return whether a heartbeat time is currently recorded for ``ip``."""
    return ip in self.last_heartbeat_time_by_ip
def prune_active_ips(self, active_ips):
active_ips = set(active_ips)
active_ips.add(self.local_ip)
def prune(mapping):
def prune(mapping, should_log):
unwanted = set(mapping) - active_ips
for unwanted_key in unwanted:
# TODO (Alex): Change this back to info after #12138.
logger.debug("LoadMetrics: "
"Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
if should_log:
logger.info("LoadMetrics: "
"Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
del mapping[unwanted_key]
if unwanted:
if unwanted and should_log:
# TODO (Alex): Change this back to info after #12138.
logger.debug(
logger.info(
"LoadMetrics: "
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
assert not (unwanted & set(mapping))
prune(self.last_used_time_by_ip)
prune(self.static_resources_by_ip)
prune(self.dynamic_resources_by_ip)
prune(self.resource_load_by_ip)
prune(self.last_heartbeat_time_by_ip)
prune(self.last_used_time_by_ip, should_log=True)
prune(self.static_resources_by_ip, should_log=False)
prune(self.dynamic_resources_by_ip, should_log=False)
prune(self.resource_load_by_ip, should_log=False)
prune(self.last_heartbeat_time_by_ip, should_log=False)
def get_node_resources(self):
"""Return a list of node resources (static resource sizes).
@@ -155,12 +169,82 @@ class LoadMetrics:
return resources_used, resources_total
def get_resource_demand_vector(self):
return self.waiting_bundles + self.infeasible_bundles
def get_resource_demand_vector(self, clip=True):
    """Return waiting bundles followed by infeasible bundles.

    Args:
        clip: When truthy, each of the two lists is truncated to
            ``AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE`` entries before
            they are concatenated.

    Returns:
        list: A new list; the stored bundle lists are not modified.
    """
    if not clip:
        return self.waiting_bundles + self.infeasible_bundles
    # Bound the total number of bundles to
    # 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. This guarantees the resource
    # demand scheduler bin packing algorithm takes a reasonable amount
    # of time to run.
    limit = AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
    clipped_waiting = self.waiting_bundles[:limit]
    clipped_infeasible = self.infeasible_bundles[:limit]
    return clipped_waiting + clipped_infeasible
def get_resource_requests(self):
    """Return the currently stored list of resource requests."""
    return self.resource_requests
def get_pending_placement_groups(self):
    """Return the currently stored list of pending placement groups."""
    return self.pending_placement_groups
def summary(self):
    """Aggregate the tracked load metrics into a ``LoadMetricsSummary``.

    Returns:
        LoadMetricsSummary with:
          * ``head_ip``: ``self.local_ip``.
          * ``usage``: for each resource in the summed static resources, a
            ``(used, total)`` tuple where ``used`` is the total minus the
            summed dynamic (available) amount.
          * ``resource_demand``: frequency summary of the unclipped
            resource demand vector.
          * ``pg_demand``: frequency summary of pending placement groups.
          * ``request_demand``: frequency summary of resource requests.
          * ``node_types``: frequency summary of per-node static resources.
    """
    available_resources = reduce(
        add_resources, self.dynamic_resources_by_ip.values()
    ) if self.dynamic_resources_by_ip else {}
    total_resources = reduce(
        add_resources, self.static_resources_by_ip.values()
    ) if self.static_resources_by_ip else {}

    usage_dict = {}
    for key, total in total_resources.items():
        # A resource may be absent from the aggregated available resources
        # (for instance when no dynamic resources are recorded at all).
        # Treat that as nothing available instead of raising KeyError.
        available = available_resources.get(key, 0)
        usage_dict[key] = (total - available, total)

    summarized_demand_vector = freq_of_dicts(
        self.get_resource_demand_vector(clip=False))
    summarized_resource_requests = freq_of_dicts(
        self.get_resource_requests())

    def placement_group_serializer(pg):
        # Convert a placement group into a hashable
        # (tuple of frozensets, strategy) pair.
        bundles = tuple(
            frozenset(bundle.unit_resources.items())
            for bundle in pg.bundles)
        return (bundles, pg.strategy)

    def placement_group_deserializer(pg_tuple):
        # We marshal this as a dictionary so that we can easily json.dumps
        # it later.
        # TODO (Alex): Would there be a benefit to properly
        # marshalling this (into a protobuf)?
        bundles = list(map(dict, pg_tuple[0]))
        return {
            "bundles": freq_of_dicts(bundles),
            "strategy": PlacementStrategy.Name(pg_tuple[1])
        }

    summarized_placement_groups = freq_of_dicts(
        self.get_pending_placement_groups(),
        serializer=placement_group_serializer,
        deserializer=placement_group_deserializer)
    nodes_summary = freq_of_dicts(self.static_resources_by_ip.values())

    return LoadMetricsSummary(
        head_ip=self.local_ip,
        usage=usage_dict,
        resource_demand=summarized_demand_vector,
        pg_demand=summarized_placement_groups,
        request_demand=summarized_resource_requests,
        node_types=nodes_summary)
def set_resource_requests(self, requested_resources):
# Store the non-empty entries of `requested_resources` on
# self.resource_requests; entries whose len() is 0 are dropped.
# A non-None argument must be a list -- the assert below enforces that and
# reports the offending value as the assertion message.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the assignment is nested under the `is not None` check. If it is
# not, passing None fails when the comprehension iterates it -- confirm
# against the real file before relying on None being accepted.
if requested_resources is not None:
assert isinstance(requested_resources, list), requested_resources
self.resource_requests = [
request for request in requested_resources if len(request) > 0
]
def info_string(self):
    """Render ``self._info()`` as a bulleted, multi-line string.

    Entries are sorted by key and each one is emitted as ``key: value``
    behind a `` - `` bullet.

    Returns:
        str: the bulleted listing; just the leading bullet when
            ``self._info()`` is empty.
    """
    entries = [
        f"{key}: {value}" for key, value in sorted(self._info().items())
    ]
    return " - " + "\n - ".join(entries)
@@ -47,16 +47,19 @@ class ResourceDemandScheduler:
provider: NodeProvider,
node_types: Dict[NodeType, NodeTypeConfigDict],
max_workers: int,
head_node_type: NodeType,
upscaling_speed: float = 1) -> None:
self.provider = provider
self.node_types = copy.deepcopy(node_types)
self.max_workers = max_workers
self.head_node_type = head_node_type
self.upscaling_speed = upscaling_speed
def reset_config(self,
provider: NodeProvider,
node_types: Dict[NodeType, NodeTypeConfigDict],
max_workers: int,
head_node_type: NodeType,
upscaling_speed: float = 1) -> None:
"""Updates the class state variables.
@@ -89,6 +92,7 @@ class ResourceDemandScheduler:
self.provider = provider
self.node_types = copy.deepcopy(final_node_types)
self.max_workers = max_workers
self.head_node_type = head_node_type
self.upscaling_speed = upscaling_speed
def is_legacy_yaml(self,
@@ -145,18 +149,18 @@ class ResourceDemandScheduler:
node_resources, node_type_counts = self.calculate_node_resources(
nodes, launching_nodes, unused_resources_by_ip)
logger.info("Cluster resources: {}".format(node_resources))
logger.info("Node counts: {}".format(node_type_counts))
logger.debug("Cluster resources: {}".format(node_resources))
logger.debug("Node counts: {}".format(node_type_counts))
# Step 2: add nodes to add to satisfy min_workers for each type
(node_resources,
node_type_counts,
adjusted_min_workers) = \
_add_min_workers_nodes(
node_resources, node_type_counts, self.node_types,
self.max_workers, ensure_min_cluster_size)
self.max_workers, self.head_node_type, ensure_min_cluster_size)
# Step 3: add nodes for strict spread groups
logger.info(f"Placement group demands: {pending_placement_groups}")
logger.debug(f"Placement group demands: {pending_placement_groups}")
placement_group_demand_vector, strict_spreads = \
placement_groups_to_resource_demands(pending_placement_groups)
resource_demands.extend(placement_group_demand_vector)
@@ -183,12 +187,13 @@ class ResourceDemandScheduler:
# groups
unfulfilled, _ = get_bin_pack_residual(node_resources,
resource_demands)
logger.info("Resource demands: {}".format(resource_demands))
logger.info("Unfulfilled demands: {}".format(unfulfilled))
logger.debug("Resource demands: {}".format(resource_demands))
logger.debug("Unfulfilled demands: {}".format(unfulfilled))
# Add 1 to account for the head node.
max_to_add = self.max_workers + 1 - sum(node_type_counts.values())
nodes_to_add_based_on_demand = get_nodes_for(
self.node_types, node_type_counts, max_to_add, unfulfilled)
self.node_types, node_type_counts, self.head_node_type, max_to_add,
unfulfilled)
# Merge nodes to add based on demand and nodes to add based on
# min_workers constraint. We add them because nodes to add based on
# demand was calculated after the min_workers constraint was respected.
@@ -206,7 +211,7 @@ class ResourceDemandScheduler:
total_nodes_to_add, unused_resources_by_ip.keys(), nodes,
launching_nodes, adjusted_min_workers)
logger.info("Node requests: {}".format(total_nodes_to_add))
logger.debug("Node requests: {}".format(total_nodes_to_add))
return total_nodes_to_add
def _legacy_worker_node_to_launch(
@@ -443,6 +448,7 @@ class ResourceDemandScheduler:
to_launch = get_nodes_for(
self.node_types,
node_type_counts,
self.head_node_type,
max_to_add,
unfulfilled,
strict_spread=True)
@@ -490,7 +496,7 @@ def _add_min_workers_nodes(
node_resources: List[ResourceDict],
node_type_counts: Dict[NodeType, int],
node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int,
ensure_min_cluster_size: List[ResourceDict]
head_node_type: NodeType, ensure_min_cluster_size: List[ResourceDict]
) -> (List[ResourceDict], Dict[NodeType, int], Dict[NodeType, int]):
"""Updates resource demands to respect the min_workers and
request_resources() constraints.
@@ -515,6 +521,9 @@ def _add_min_workers_nodes(
existing = node_type_counts.get(node_type, 0)
target = min(
config.get("min_workers", 0), config.get("max_workers", 0))
if node_type == head_node_type:
# Add 1 to account for head node.
target = target + 1
if existing < target:
total_nodes_to_add_dict[node_type] = target - existing
node_type_counts[node_type] = target
@@ -537,7 +546,7 @@ def _add_min_workers_nodes(
max_node_resources, ensure_min_cluster_size)
# Get the nodes to meet the unfulfilled.
nodes_to_add_request_resources = get_nodes_for(
node_types, node_type_counts, max_to_add,
node_types, node_type_counts, head_node_type, max_to_add,
resource_requests_unfulfilled)
# Update the resources, counts and total nodes to add.
for node_type in nodes_to_add_request_resources:
@@ -558,6 +567,7 @@ def _add_min_workers_nodes(
def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
existing_nodes: Dict[NodeType, int],
head_node_type: NodeType,
max_to_add: int,
resources: List[ResourceDict],
strict_spread: bool = False) -> Dict[NodeType, int]:
@@ -581,9 +591,13 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
while resources and sum(nodes_to_add.values()) < max_to_add:
utilization_scores = []
for node_type in node_types:
max_workers_of_node_type = node_types[node_type].get(
"max_workers", 0)
if head_node_type == node_type:
# Add 1 to account for head node.
max_workers_of_node_type = max_workers_of_node_type + 1
if (existing_nodes.get(node_type, 0) + nodes_to_add.get(
node_type, 0) >= node_types[node_type].get(
"max_workers", 0)):
node_type, 0) >= max_workers_of_node_type):
continue
node_resources = node_types[node_type]["resources"]
if strict_spread:
@@ -601,8 +615,14 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
# starts up because placement groups are scheduled via custom
# resources. This will behave properly with the current utilization
# score heuristic, but it's a little dangerous and misleading.
logger.info(
"No feasible node type to add for {}".format(resources))
logger.warning(
f"The autoscaler could not find a node type to satisfy the"
f"request: {resources}. If this request is related to "
f"placement groups the resource request will resolve itself, "
f"otherwise please specify a node type with the necessary "
f"resource "
f"https://docs.ray.io/en/master/cluster/autoscaling.html#multiple-node-type-autoscaling." # noqa: E501
)
break
utilization_scores = sorted(utilization_scores, reverse=True)
+174 -1
View File
@@ -1,13 +1,15 @@
import collections
from datetime import datetime
import logging
import hashlib
import json
import jsonschema
import os
import threading
from typing import Any, Dict
from typing import Any, Dict, List
import ray
import ray.ray_constants
import ray._private.services as services
from ray.autoscaler._private.providers import _get_default_config
from ray.autoscaler._private.docker import validate_docker_config
@@ -20,6 +22,7 @@ RAY_SCHEMA_PATH = os.path.join(
# Internal kv keys for storing debug status.
DEBUG_AUTOSCALING_ERROR = "__autoscaling_error"
DEBUG_AUTOSCALING_STATUS = "__autoscaling_status"
DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy"
logger = logging.getLogger(__name__)
@@ -246,6 +249,47 @@ def hash_runtime_conf(file_mounts,
return (_hash_cache[conf_str], file_mounts_contents_hash)
def add_resources(dict1: Dict[str, float],
                  dict2: Dict[str, float]) -> Dict[str, float]:
    """Sum two resource dictionaries key by key.

    A key present in only one input keeps its value; a key present in both
    maps to the sum of the two values.

    Returns:
        dict: A new dictionary (inputs remain unmodified).
    """
    combined = dict1.copy()
    combined.update(
        (name, amount + dict1.get(name, 0)) for name, amount in dict2.items())
    return combined
def freq_of_dicts(dicts: List[Dict],
                  serializer=lambda d: frozenset(d.items()),
                  deserializer=dict):
    """Count how often each entry occurs in a list of unhashable items.

    Dictionaries (and other mutable structures) cannot be used as counter
    keys directly, so every entry is first converted into a hashable form,
    counted, and converted back for the result.

    Args:
        dicts (List[D]): The entries to be counted.
        serializer (D -> S): Converts an entry into a hashable value of
            type S. The default turns a dictionary into a frozenset of its
            key/value pairs.
        deserializer (S -> U): Converts a counted key of type S back into
            the form reported to the caller. For dictionaries U := D.

    Returns:
        List[Tuple[U, int]]: One tuple per distinct entry, holding the
            deserialized entry and its frequency, in order of first
            appearance in `dicts`.
    """
    tally = collections.Counter(serializer(entry) for entry in dicts)
    return [(deserializer(key), occurrences)
            for key, occurrences in tally.items()]
def add_prefix(info_string, prefix):
"""Prefixes each line of info_string, except the first, by prefix."""
lines = info_string.split("\n")
@@ -255,3 +299,132 @@ def add_prefix(info_string, prefix):
prefixed_lines.append(prefixed_line)
prefixed_info_string = "\n".join(prefixed_lines)
return prefixed_info_string
def format_pg(pg):
    """Format one summarized placement group as a single line of text.

    Args:
        pg: Mapping with a "strategy" entry and a "bundles" entry, the
            latter an iterable of ``(bundle, count)`` pairs.

    Returns:
        str: Comma-separated ``bundle * count`` shapes followed by the
            strategy in parentheses.
    """
    strategy = pg["strategy"]
    bundles_str = ", ".join(
        f"{bundle} * {count}" for bundle, count in pg["bundles"])
    return f"{bundles_str} ({strategy})"
def get_usage_report(lm_summary):
    """Format per-resource usage as one ``used/total resource`` line each.

    The "memory" and "object_store_memory" resources are scaled by
    ``MEMORY_RESOURCE_UNIT_BYTES / 2**30`` and printed in GiB with fixed
    precision; every other resource is printed as-is.

    Args:
        lm_summary: Object whose ``usage`` attribute maps a resource name
            to a ``(used, total)`` pair.

    Returns:
        str: The usage lines joined by newlines ("" when there are none).
    """
    memory_resources = ("memory", "object_store_memory")
    report_lines = []
    for resource, (used, total) in lm_summary.usage.items():
        if resource in memory_resources:
            to_GiB = ray.ray_constants.MEMORY_RESOURCE_UNIT_BYTES / 2**30
            used_gib = used * to_GiB
            total_gib = total * to_GiB
            report_lines.append(
                f" {used_gib:.2f}/{total_gib:.3f} GiB {resource}")
        else:
            report_lines.append(f" {used}/{total} {resource}")
    return "\n".join(report_lines)
def get_demand_report(lm_summary):
    """Format outstanding resource demands, one line per distinct shape.

    Lines are emitted in three groups, in this order: pending tasks/actors
    (``resource_demand``), pending placement groups (``pg_demand``, rendered
    through ``format_pg``) and explicit ``request_resources()`` demands
    (``request_demand``). Each source is an iterable of ``(item, count)``
    pairs.

    Returns:
        str: The demand lines joined by newlines, or a placeholder line
            when there are no demands at all.
    """
    lines = [
        f" {bundle}: {count}+ pending tasks/actors"
        for bundle, count in lm_summary.resource_demand
    ]
    lines.extend(f" {format_pg(pg)}: {count}+ pending placement groups"
                 for pg, count in lm_summary.pg_demand)
    lines.extend(f" {bundle}: {count}+ from request_resources()"
                 for bundle, count in lm_summary.request_demand)
    if not lines:
        return " (no resource demands)"
    return "\n".join(lines)
def format_info_string(lm_summary, autoscaler_summary, time=None):
    """Render the autoscaler status report as a multi-line string.

    Args:
        lm_summary: Load metrics summary, passed to ``get_usage_report``
            and ``get_demand_report``.
        autoscaler_summary: Object providing ``active_nodes`` and
            ``pending_launches`` (mappings of node type to count), plus
            ``pending_nodes`` and ``failed_nodes`` (iterables of
            ``(ip, node_type)`` pairs).
        time: Timestamp shown in the header; defaults to
            ``datetime.now()``.

    Returns:
        str: The report with node status (healthy, pending, recent
            failures) followed by resource usage and demands.
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Autoscaler status: {time} " + "=" * 8
    separator = "-" * len(header)

    available_node_report_lines = []
    for node_type, count in autoscaler_summary.active_nodes.items():
        line = f" {count} {node_type}"
        available_node_report_lines.append(line)
    available_node_report = "\n".join(available_node_report_lines)

    pending_lines = []
    for node_type, count in autoscaler_summary.pending_launches.items():
        line = f" {node_type}, {count} launching"
        pending_lines.append(line)
    for ip, node_type in autoscaler_summary.pending_nodes:
        line = f" {ip}: {node_type}, setting up"
        pending_lines.append(line)
    if pending_lines:
        pending_report = "\n".join(pending_lines)
    else:
        pending_report = " (no pending nodes)"

    failure_lines = []
    for ip, node_type in autoscaler_summary.failed_nodes:
        line = f" {ip}: {node_type}"
        # Each failure line must be collected; without this append the
        # report always claims "(no failures)" regardless of failed_nodes.
        failure_lines.append(line)
    failure_report = "Recent failures:\n"
    if failure_lines:
        failure_report += "\n".join(failure_lines)
    else:
        failure_report += " (no failures)"

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    formatted_output = f"""{header}
Node status
{separator}
Healthy:
{available_node_report}
Pending:
{pending_report}
{failure_report}
Resources
{separator}
Usage:
{usage_report}
Demands:
{demand_report}"""
    return formatted_output
def format_info_string_no_node_types(lm_summary, time=None):
    """Render the cluster status report when node types are not known.

    Nodes are listed by their resource shape rather than by node type.

    Args:
        lm_summary: Load metrics summary. ``node_types`` is iterated as
            ``(resources, count)`` pairs; the object is also passed to
            ``get_usage_report`` and ``get_demand_report``.
        time: Timestamp shown in the header; defaults to
            ``datetime.now()``.

    Returns:
        str: The report with node status followed by resource usage and
            demands.
    """
    time = datetime.now() if time is None else time
    banner = "=" * 8
    header = banner + f" Cluster status: {time} " + banner
    separator = "-" * len(header)

    node_report = "\n".join(f" {count} node(s) with resources: {node_type}"
                            for node_type, count in lm_summary.node_types)
    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    return f"""{header}
Node status
{separator}
{node_report}
Resources
{separator}
Usage:
{usage_report}
Demands:
{demand_report}"""
@@ -250,9 +250,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# Files or directories to copy from the head node to the worker nodes. The format is a
# list of paths. The same path on the head node will be copied to the worker node.
@@ -250,9 +250,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# Files or directories to copy from the head node to the worker nodes. The format is a
# list of paths. The same path on the head node will be copied to the worker node.
@@ -286,9 +286,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# List of commands that will be run before `setup_commands`. If docker is
# enabled, these commands will run outside the container and before docker
+4 -2
View File
@@ -142,7 +142,8 @@ class WorkerCrashedError(RayError):
"""Indicates that the worker died unexpectedly while executing a task."""
def __str__(self):
return "The worker died unexpectedly while executing this task."
return ("The worker died unexpectedly while executing this task. "
"Check python-core-worker-*.log files for more information.")
class RayActorError(RayError):
@@ -153,7 +154,8 @@ class RayActorError(RayError):
"""
def __str__(self):
return "The actor died unexpectedly before finishing this task."
return ("The actor died unexpectedly before finishing this task. "
"Check python-core-worker-*.log files for more information.")
class RaySystemError(RayError):
-2
View File
@@ -1,6 +1,4 @@
from .dynamic_resources import set_resource
from .object_spilling import force_spill_objects
__all__ = [
"set_resource",
"force_spill_objects",
]
+82 -99
View File
@@ -1,117 +1,100 @@
from ray.experimental.client.api import ClientAPI
from ray.experimental.client.api import APIImpl
from typing import Optional, List, Tuple
from contextlib import contextmanager
from typing import List, Tuple
import logging
logger = logging.getLogger(__name__)
# About these global variables: Ray 1.0 uses exported module functions to
# provide its API, and we need to match that. However, we want different
# behaviors depending on where, exactly, in the client stack this is running.
#
# The reason for these differences depends on what's being pickled and passed
# to functions, or functions inside functions. So there are three cases to care
# about
#
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
#
# * _client_api should be set if we're inside the client
# * _server_api should be set if we're inside the clientserver
# * Both will be set if we're running both (as in a test)
# * Neither should be set if we're inside the raylet (but we still need to shim
# from the client API surface to the Ray API)
#
# The job of RayAPIStub (below) is to delegate to the appropriate one of these
# depending on what's set or not. Then, all users importing the ray object
# from this package get the stub which routes them to the appropriate APIImpl.
_client_api: Optional[APIImpl] = None
_server_api: Optional[APIImpl] = None
# The reason for _is_server is a hack around the above comment while running
# tests. If we have both a client and a server trying to control these static
# variables then we need a way to decide which to use. In this case, both
# _client_api and _server_api are set.
# This boolean flips between the two
_is_server: bool = False
@contextmanager
def stash_api_for_tests(in_test: bool):
    """Temporarily route API lookups to the server API during a test.

    When ``in_test`` is true, ``_is_server`` is forced to True for the
    duration of the ``with`` body and then put back to its previous value.
    When ``in_test`` is false the flag is left untouched.

    Args:
        in_test: Whether the caller is running inside a test that has both
            a client and a server API installed.

    Yields:
        The current ``_server_api``.
    """
    global _is_server
    previous = _is_server
    if in_test:
        _is_server = True
    try:
        yield _server_api
    finally:
        # Restore even when the with-body raises; otherwise a failing test
        # would leave the module stuck in server mode.
        if in_test:
            _is_server = previous
def _set_client_api(val: Optional[APIImpl]):
    """Install ``val`` as the module-wide client API.

    Also switches ``_is_server`` off, so subsequent lookups resolve to the
    client API.

    Raises:
        Exception: if a client API is already installed.
    """
    global _client_api, _is_server
    already_installed = _client_api is not None
    if already_installed:
        raise Exception("Trying to set more than one client API")
    _client_api = val
    _is_server = False
def _set_server_api(val: Optional[APIImpl]):
    """Install ``val`` as the module-wide server API.

    Also switches ``_is_server`` on, so subsequent lookups resolve to the
    server API.

    Raises:
        Exception: if a server API is already installed.
    """
    global _server_api, _is_server
    already_installed = _server_api is not None
    if already_installed:
        raise Exception("Trying to set more than one server API")
    _server_api = val
    _is_server = True
def reset_api():
    """Clear both installed APIs and return to client-side lookup.

    Sets ``_client_api`` and ``_server_api`` to None and ``_is_server`` to
    False.
    """
    global _client_api, _server_api, _is_server
    _client_api = _server_api = None
    _is_server = False
def _get_client_api() -> APIImpl:
    """Return the API implementation that calls should be routed to.

    Picks ``_server_api`` when ``_is_server`` is set and ``_client_api``
    otherwise. When the selected one is not installed, a fresh
    ``CoreRayAPI`` is returned instead.
    """
    selected = _server_api if _is_server else _client_api
    if selected is not None:
        return selected
    # We're inside a raylet worker
    from ray.experimental.client.server.core_ray_api import CoreRayAPI
    return CoreRayAPI()
class RayAPIStub:
"""This class stands in as the replacement API for the `import ray` module.
Much like the ray module, this mostly delegates the work to the
_client_worker. As parts of the ray API are covered, they are piped through
here or on the client worker API.
"""
def __init__(self):
    """Create a stub with a worker-less ``ClientAPI`` and no connection."""
    # Imported here rather than at module import time.
    from ray.experimental.client.api import ClientAPI

    # Connection state: nothing is connected or served yet.
    self.client_worker = None
    self._server = None
    # Bookkeeping flags, both off until set elsewhere.
    self._connected_with_init = False
    self._inside_client_test = False
    # API facade; it is constructed without a worker.
    self.api = ClientAPI()
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
metadata: List[Tuple[str, str]] = None) -> None:
"""Connect the Ray Client to a server.
Args:
conn_str: Connection string, in the form "[host]:port"
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
"""
# Delay imports until connect to avoid circular imports.
from ray.experimental.client.worker import Worker
_client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub)
_set_client_api(ClientAPI(_client_worker))
import ray._private.client_mode_hook
if self.client_worker is not None:
if self._connected_with_init:
return
raise Exception(
"ray.connect() called, but ray client is already connected")
if not self._inside_client_test:
# If we're calling a client connect specifically and we're not
# currently in client mode, ensure we are.
ray._private.client_mode_hook._explicitly_enable_client_mode()
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
self.api.worker = self.client_worker
def disconnect(self):
global _client_api
if _client_api is not None:
_client_api.close()
_client_api = None
"""Disconnect the Ray Client.
"""
if self.client_worker is not None:
self.client_worker.close()
self.client_worker = None
# remote can be called outside of a connection, which is why it
# exists on the same API layer as connect() itself.
def remote(self, *args, **kwargs):
    """Stand-in for `ray.remote`, forwarded to ``self.api.remote``.

    Used as the decorator that sets up remote functions or actors; it
    does not execute them.

    Args:
        args: opaque arguments
        kwargs: opaque keyword arguments
    """
    decorate = self.api.remote
    return decorate(*args, **kwargs)
def __getattr__(self, key: str):
global _get_client_api
api = _get_client_api()
return getattr(api, key)
if not self.is_connected():
raise Exception("Ray Client is not connected. "
"Please connect by calling `ray.connect`.")
return getattr(self.api, key)
def is_connected(self) -> bool:
    """Report whether a client worker connection is currently held.

    ``self.api`` is created in ``__init__`` and never reset, so testing it
    would always report "connected". The worker handle, assigned by
    ``connect()`` and cleared by ``disconnect()``, is what actually tracks
    the connection.

    Returns:
        bool: True when ``self.client_worker`` is set, False otherwise.
    """
    return self.client_worker is not None
def init(self, *args, **kwargs):
    """Start a client server on localhost and connect this stub to it.

    Args:
        args: opaque arguments forwarded to ``init_and_serve``
        kwargs: opaque keyword arguments forwarded to ``init_and_serve``

    Returns:
        The address info returned by ``init_and_serve``.

    Raises:
        Exception: if this stub has already started a server.
    """
    if self._server is not None:
        raise Exception("Trying to start two instances of ray via client")
    import ray.experimental.client.server.server as ray_client_server

    local_address = "localhost:50051"
    server_handle, address_info = ray_client_server.init_and_serve(
        local_address, *args, **kwargs)
    self._server = server_handle
    self.connect(local_address)
    # Record that this connection was made through init().
    self._connected_with_init = True
    return address_info
def shutdown(self, _exiting_interpreter=False):
    """Disconnect the client and stop the server started by ``init()``.

    The client is always disconnected; the server is only shut down when
    this stub holds one.

    Args:
        _exiting_interpreter: forwarded to ``shutdown_with_server``.
    """
    self.disconnect()
    import ray.experimental.client.server.server as ray_client_server

    server = self._server
    if server is not None:
        ray_client_server.shutdown_with_server(server, _exiting_interpreter)
        self._server = None
ray = RayAPIStub()
+82 -96
View File
@@ -1,74 +1,51 @@
# This file defines an interface and client-side API stub
# for referring either to the core Ray API or the same interface
# from the Ray client.
#
# In tandem with __init__.py, we want to expose an API that's
# close to `python/ray/__init__.py` but with more than one implementation.
# The stubs in __init__ should call into a well-defined interface.
# Only the core Ray API implementation should actually `import ray`
# (and thus import all the raylet worker C bindings and such).
# But to make sure that we're matching these calls, we define this API.
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Union, Optional
import ray.core.generated.ray_client_pb2 as ray_client_pb2
"""This file defines the interface between the ray client worker
and the overall ray module API.
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientStub
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientObjectRef
from ray._raylet import ObjectRef
# Use the imports for type checking. This is a python 3.6 limitation.
# See https://www.python.org/dev/peps/pep-0563/
PutType = Union[ClientObjectRef, ObjectRef]
class APIImpl(ABC):
"""
APIImpl is the interface to implement for whichever version of the core
Ray API that needs abstracting when run in client mode.
class ClientAPI:
"""The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
@abstractmethod
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
"""
get is the hook stub passed on to replace `ray.get`
def __init__(self, worker=None):
self.worker = worker
def get(self, vals, *, timeout=None):
"""get is the hook stub passed on to replace `ray.get`
Args:
vals: [Client]ObjectRef or list of these refs to retrieve.
timeout: Optional timeout in milliseconds
"""
pass
return self.worker.get(vals, timeout=timeout)
@abstractmethod
def put(self, vals: Any, *args,
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
"""
put is the hook stub passed on to replace `ray.put`
def put(self, *args, **kwargs):
"""put is the hook stub passed on to replace `ray.put`
Args:
vals: The value or list of values to `put`.
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.put(*args, **kwargs)
@abstractmethod
def wait(self, *args, **kwargs):
"""
wait is the hook stub passed on to replace `ray.wait`
"""wait is the hook stub passed on to replace `ray.wait`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.wait(*args, **kwargs)
@abstractmethod
def remote(self, *args, **kwargs):
"""
remote is the hook stub passed on to replace `ray.remote`.
"""remote is the hook stub passed on to replace `ray.remote`.
This sets up remote functions or actors, as the decorator,
but does not execute them.
@@ -77,12 +54,24 @@ class APIImpl(ABC):
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
# Delayed import to avoid a cyclic import
from ray.experimental.client.common import remote_decorator
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return remote_decorator(options=None)(args[0])
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
"'@ray.remote(num_returns=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
return remote_decorator(options=kwargs)
@abstractmethod
def call_remote(self, instance: "ClientStub", *args, **kwargs):
"""
call_remote is called by stub objects to execute them remotely.
"""call_remote is called by stub objects to execute them remotely.
This is used by stub objects in situations where they're called
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
@@ -95,31 +84,57 @@ class APIImpl(ABC):
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.call_remote(instance, *args, **kwargs)
@abstractmethod
def close(self) -> None:
def call_release(self, id: bytes) -> None:
"""Attempts to release an object reference.
When client references are destructed, they release their reference,
which can opportunistically send a notification through the datachannel
to release the reference being held for that object on the server.
Args:
id: The id of the reference to release on the server side.
"""
close cleans up an API connection by closing any channels or
return self.worker.call_release(id)
def call_retain(self, id: bytes) -> None:
"""Attempts to retain a client object reference.
Increments the reference count on the client side, to prevent
the client worker from attempting to release the server reference.
Args:
id: The id of the reference to retain on the client side.
"""
return self.worker.call_retain(id)
def close(self) -> None:
"""close cleans up an API connection by closing any channels or
shutting down any servers gracefully.
"""
pass
return self.worker.close()
@abstractmethod
def kill(self, actor, *, no_restart=True):
def get_actor(self, name: str) -> "ClientActorHandle":
"""Returns a handle to an actor by name.
Args:
name: The name passed to this actor by
Actor.options(name="name").remote()
"""
kill forcibly stops an actor running in the cluster
return self.worker.get_actor(name)
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
"""kill forcibly stops an actor running in the cluster
Args:
no_restart: Whether this actor should be restarted if it's a
restartable actor.
"""
pass
return self.worker.terminate_actor(actor, no_restart)
@abstractmethod
def cancel(self, obj, *, force=False, recursive=True):
"""
Cancels a task on the cluster.
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
"""Cancels a task on the cluster.
If the specified task is pending execution, it will not be executed. If
the task is currently executing, the behavior depends on the ``force``
@@ -136,46 +151,11 @@ class APIImpl(ABC):
recursive (boolean): Whether to try to cancel tasks submitted by
the task specified.
"""
pass
class ClientAPI(APIImpl):
"""
The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
def __init__(self, worker):
self.worker = worker
def get(self, vals, *, timeout=None):
return self.worker.get(vals, timeout=timeout)
def put(self, *args, **kwargs):
return self.worker.put(*args, **kwargs)
def wait(self, *args, **kwargs):
return self.worker.wait(*args, **kwargs)
def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs)
def call_remote(self, instance: "ClientStub", *args, **kwargs):
return self.worker.call_remote(instance, *args, **kwargs)
def close(self) -> None:
return self.worker.close()
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
return self.worker.terminate_actor(actor, no_restart)
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
return self.worker.terminate_task(obj, force, recursive)
# Various metadata methods for the client that are defined in the protocol.
def is_initialized(self) -> bool:
""" True if our client is connected, and if the server is initialized.
"""True if our client is connected, and if the server is initialized.
Returns:
A boolean determining if the client is connected and
server initialized.
@@ -188,6 +168,8 @@ class ClientAPI(APIImpl):
Returns:
Information about the Ray clients in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.NODES)
@@ -201,6 +183,8 @@ class ClientAPI(APIImpl):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
@@ -216,6 +200,8 @@ class ClientAPI(APIImpl):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
@@ -0,0 +1,182 @@
"""Implements the client side of the client/server pickling protocol.
All ray client client/server data transfer happens through this pickling
protocol. The model is as follows:
* All Client objects (eg ClientObjectRef) always live on the client and
are never represented in the server
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
never returned to the client
* In order to translate between these two references, PickleStub tuples
are generated as persistent ids in the data blobs during the pickling
and unpickling of these objects.
The PickleStubs have just enough information to find or generate their
associated partner object on either side.
This also has the advantage of avoiding predefined pickle behavior for ray
objects, which may include ray internal reference counting.
ClientPickler dumps things from the client into the appropriate stubs
ServerUnpickler loads stubs from the server into their client counterparts.
"""
import cloudpickle
import io
import sys
from typing import NamedTuple
from typing import Any
from typing import Dict
from typing import Optional
from ray.experimental.client import RayAPIStub
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import OptionWrapper
from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray._private.client_mode_hook import disable_client_hook
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
# NOTE(barakmich): These PickleStubs are really close to
# the data for an execution, with no arguments. Combine the two?
#
# Fields, as filled in by ClientPickler.persistent_id:
#   type: tag naming the kind of object the stub stands for ("Ray",
#       "Object", "Actor", "RemoteFunc", "RemoteActor", "RemoteMethod", ...).
#   client_id: the client_id the producing pickler was constructed with.
#   ref_id: id bytes of the referenced object, or b"" when there is none.
#   name: the method name for "RemoteMethod" stubs, otherwise None.
#   baseline_options: the `_options` of a remote function / actor class,
#       otherwise None.
PickleStub = NamedTuple("PickleStub",
                        [("type", str), ("client_id", str), ("ref_id", bytes),
                         ("name", Optional[str]),
                         ("baseline_options", Optional[Dict])])
class ClientPickler(cloudpickle.CloudPickler):
    """Pickler that replaces client-side Ray objects with ``PickleStub`` ids.

    Objects that are not one of the recognised client types are pickled
    normally (``persistent_id`` returns None for them).
    """

    def __init__(self, client_id, *args, **kwargs):
        """Create the pickler.

        Args:
            client_id: id recorded in every stub this pickler emits.
            *args, **kwargs: forwarded to ``cloudpickle.CloudPickler``.
        """
        super().__init__(*args, **kwargs)
        self.client_id = client_id

    def _stub(self, kind, ref_id=b"", name=None, baseline_options=None):
        """Build a PickleStub of the given kind stamped with our client id."""
        return PickleStub(
            type=kind,
            client_id=self.client_id,
            ref_id=ref_id,
            name=name,
            baseline_options=baseline_options,
        )

    def _lazy_ref_stub(self, obj, kind, self_reference_kind):
        """Build the stub for a remote function / actor class.

        The object's ``_ref`` is created on demand. If it currently holds the
        in-progress sentinel, a self-reference stub is emitted instead.
        """
        # TODO(barakmich): This is going to have trouble with mutually
        # recursive functions that haven't, as yet, been executed. It's
        # relatively doable (keep track of intermediate refs in progress
        # with ensure_ref and return appropriately) But punting for now.
        if obj._ref is None:
            obj._ensure_ref()
        if type(obj._ref) == SelfReferenceSentinel:
            return self._stub(self_reference_kind)
        return self._stub(
            kind, ref_id=obj._ref.id, baseline_options=obj._options)

    def persistent_id(self, obj):
        """Return the PickleStub standing in for ``obj``, or None.

        Raises:
            NotImplementedError: if ``obj`` is an OptionWrapper.
        """
        if isinstance(obj, RayAPIStub):
            return self._stub("Ray")
        if isinstance(obj, ClientObjectRef):
            return self._stub("Object", ref_id=obj.id)
        if isinstance(obj, ClientActorHandle):
            return self._stub("Actor", ref_id=obj._actor_id)
        if isinstance(obj, ClientRemoteFunc):
            return self._lazy_ref_stub(obj, "RemoteFunc",
                                       "RemoteFuncSelfReference")
        if isinstance(obj, ClientActorClass):
            return self._lazy_ref_stub(obj, "RemoteActor",
                                       "RemoteActorSelfReference")
        if isinstance(obj, ClientRemoteMethod):
            return self._stub(
                "RemoteMethod",
                ref_id=obj.actor_handle.actor_ref.id,
                name=obj.method_name)
        if isinstance(obj, OptionWrapper):
            raise NotImplementedError(
                "Sending a partial option is unimplemented")
        return None
class ServerUnpickler(pickle.Unpickler):
    """Unpickler that turns PickleStubs from the server into client objects."""

    def persistent_load(self, pid):
        """Resolve a persistent id into its client-side counterpart.

        Only "Object" and "Actor" stubs are understood.

        Raises:
            NotImplementedError: for any other stub type.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Object":
            return ClientObjectRef(id=pid.ref_id)
        if kind == "Actor":
            actor_ref = ClientActorRef(id=pid.ref_id)
            return ClientActorHandle(actor_ref)
        raise NotImplementedError("Being passed back an unknown stub")
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
    """Serialize ``obj`` with ClientPickler and return the resulting bytes.

    The client-mode hook is disabled while pickling.
    """
    with disable_client_hook(), io.BytesIO() as buffer:
        ClientPickler(client_id, buffer, protocol=protocol).dump(obj)
        return buffer.getvalue()
def loads_from_server(data: bytes,
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Deserialize bytes produced by the server using ServerUnpickler.

    Args:
        data: the pickled payload.
        fix_imports, encoding, errors: forwarded to the unpickler.

    Raises:
        TypeError: if ``data`` is a ``str``.
    """
    if isinstance(data, str):
        raise TypeError("Can't load pickle from unicode string")
    unpickler = ServerUnpickler(
        io.BytesIO(data),
        fix_imports=fix_imports,
        encoding=encoding,
        errors=errors)
    return unpickler.load()
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
    """Wrap ``val`` in an INTERNED Arg whose data is the client-pickled value."""
    arg = ray_client_pb2.Arg()
    arg.data = dumps_from_client(val, client_id)
    arg.local = ray_client_pb2.Arg.Locality.INTERNED
    return arg
+151 -130
View File
@@ -1,16 +1,31 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from typing import Any
from typing import Dict
from ray import cloudpickle
from ray.experimental.client.options import validate_options
import base64
import inspect
from ray.util.inspect import is_cython
import json
import threading
from typing import Any
from typing import List
from typing import Dict
from typing import Optional
from typing import Union
class ClientBaseRef:
def __init__(self, id, handle=None):
self.id = id
self.handle = handle
def __init__(self, id: bytes):
self.id = None
if not isinstance(id, bytes):
raise TypeError("ClientRefs must be created with bytes IDs")
self.id: bytes = id
ray.call_retain(id)
def binary(self):
return self.id
def __eq__(self, other):
return self.id == other.id
def __repr__(self):
return "%s(%s)" % (
@@ -18,20 +33,16 @@ class ClientBaseRef:
self.id.hex(),
)
def __eq__(self, other):
return self.id == other.id
def __hash__(self):
return hash(self.id)
def binary(self):
return self.id
@classmethod
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
return cls(id=ref.id, handle=ref.handle)
def __del__(self):
if ray.is_connected() and self.id is not None:
ray.call_release(self.id)
class ClientObjectRef(ClientBaseRef):
def _unpack_ref(self):
return cloudpickle.loads(self.handle)
pass
class ClientActorRef(ClientBaseRef):
@@ -43,8 +54,7 @@ class ClientStub:
class ClientRemoteFunc(ClientStub):
"""
A stub created on the Ray Client to represent a remote
"""A stub created on the Ray Client to represent a remote
function that can be executed on the cluster.
This class is allowed to be passed around between remote functions.
@@ -53,55 +63,57 @@ class ClientRemoteFunc(ClientStub):
_func: The actual function to execute remotely
_name: The original name of the function
_ref: The ClientObjectRef of the pickled code of the function, _func
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
for this object
"""
def __init__(self, f):
def __init__(self, f, options=None):
self._lock = threading.Lock()
self._func = f
self._name = f.__name__
self.id = None
# self._ref can be lazily instantiated. Rather than eagerly creating
# function data objects in the server we can put them just before we
# execute the function, especially in cases where many @ray.remote
# functions exist in a library and only a handful are ever executed by
# a user of the library.
#
# TODO(barakmich): This ref might actually be better as a serialized
# ObjectRef. This requires being able to serialize the ref without
# pinning it (as the lifetime of the ref is tied with the server, not
# the client)
self._ref = None
self._raylet_remote = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
    """Reject direct invocation of a remote function.

    Raises:
        TypeError: always; the message names the ``.remote`` entry point.
    """
    # The f-prefix must sit on the literal that contains the placeholder;
    # adjacent literals are concatenated but only f-prefixed ones are
    # interpolated.
    raise TypeError("Remote function cannot be called directly. "
                    f"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return return_refs(ray.call_remote(self, *args, **kwargs))
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _ensure_ref(self):
with self._lock:
if self._ref is None:
# While calling ray.put() on our function, if
# our function is recursive, it will attempt to
# encode the ClientRemoteFunc -- itself -- and
# infinitely recurse on _ensure_ref.
#
# So we set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self._func)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self._func)
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
class ClientActorClass(ClientStub):
""" A stub created on the Ray Client to represent an actor class.
"""A stub created on the Ray Client to represent an actor class.
It is wrapped by ray.remote and can be executed on the cluster.
@@ -109,39 +121,40 @@ class ClientActorClass(ClientStub):
actor_cls: The actual class to execute remotely
_name: The original name of the class
_ref: The ClientObjectRef of the pickled `actor_cls`
_raylet_remote: The Raylet-side ray.ActorClass for this object
"""
def __init__(self, actor_cls):
def __init__(self, actor_cls, options=None):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
    """Reject direct instantiation of a remote actor class.

    Raises:
        TypeError: always; the message names the ``.remote()`` entry point.
    """
    # The f-prefix must sit on the literal that contains the placeholder;
    # adjacent literals are concatenated but only f-prefixed ones are
    # interpolated.
    raise TypeError("Remote actor cannot be instantiated directly. "
                    f"Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def _ensure_ref(self):
if self._ref is None:
# As before, set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self.actor_cls)
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs):
def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor
ref = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]))
def options(self, **kwargs):
return ActorOptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key):
if key not in self.__dict__:
@@ -149,12 +162,12 @@ class ClientActorClass(ClientStub):
raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
@@ -174,29 +187,12 @@ class ClientActorHandle(ClientStub):
ray.actor.ActorHandle contained in the actor_id ref.
"""
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass):
def __init__(self, actor_ref: ClientActorRef):
self.actor_ref = actor_ref
self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_ref": self.actor_ref,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_ref = state["actor_ref"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __del__(self) -> None:
if ray.is_connected():
ray.call_release(self.actor_ref.id)
@property
def _actor_id(self):
@@ -226,65 +222,90 @@ class ClientRemoteMethod(ClientStub):
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_handle = state["actor_handle"]
self.method_name = state["method_name"]
f"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return return_refs(ray.call_remote(self, *args, **kwargs))
def __repr__(self):
name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
return "ClientRemoteMethod(%s, %s)" % (name,
self.actor_handle.actor_id)
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
self.actor_handle)
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.handle
task.payload_id = self.actor_handle.actor_ref.id
return task
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return ClientObjectRef(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
class OptionWrapper:
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
self.remote_stub = stub
self.options = validate_options(options)
raise Exception("convert_from_arg: Uncovered locality enum")
def remote(self, *args, **kwargs):
return return_refs(ray.call_remote(self, *args, **kwargs))
def __getattr__(self, key):
return getattr(self.remote_stub, key)
def _prepare_client_task(self):
task = self.remote_stub._prepare_client_task()
set_task_options(task, self.options)
return task
def convert_to_arg(val):
    """Convert ``val`` into an Arg message.

    A ClientObjectRef is passed by REFERENCE using its id; anything else is
    cloudpickled and sent INTERNED.
    """
    arg = ray_client_pb2.Arg()
    if not isinstance(val, ClientObjectRef):
        arg.local = ray_client_pb2.Arg.Locality.INTERNED
        arg.data = cloudpickle.dumps(val)
        return arg
    arg.local = ray_client_pb2.Arg.Locality.REFERENCE
    arg.reference_id = val.id
    return arg
class ActorOptionWrapper(OptionWrapper):
    """OptionWrapper variant whose ``remote`` returns an actor handle."""

    def remote(self, *args, **kwargs):
        """Schedule the wrapped stub and wrap the single returned id."""
        returned_ids = ray.call_remote(self, *args, **kwargs)
        assert len(returned_ids) == 1
        actor_ref = ClientActorRef(returned_ids[0])
        return ClientActorHandle(actor_ref)
def encode_exception(exception) -> str:
    """Cloudpickle ``exception`` and return it as standard base64 text."""
    pickled = cloudpickle.dumps(exception)
    encoded = base64.standard_b64encode(pickled)
    return encoded.decode()
def set_task_options(task: ray_client_pb2.ClientTask,
                     options: Optional[Dict[str, Any]],
                     field: str = "options") -> None:
    """Store ``options`` on ``task`` as JSON, or clear the field.

    Args:
        task: the task message to modify in place.
        options: option dict to encode; None clears ``field``.
        field: name of the task sub-message that carries ``json_options``.
    """
    if options is not None:
        encoded = json.dumps(options)
        getattr(task, field).json_options = encoded
    else:
        task.ClearField(field)
def decode_exception(data) -> Exception:
    """Inverse of the base64 + cloudpickle exception encoding.

    Args:
        data: standard-base64 payload containing a pickled object.
    """
    # NOTE(review): this unpickles whatever the peer sent; unpickling can
    # execute arbitrary code, so the payload must come from a trusted peer.
    raw = base64.standard_b64decode(data)
    return cloudpickle.loads(raw)
def return_refs(ids: List[bytes]
                ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
    """Wrap returned ids in ClientObjectRefs.

    Returns:
        None for no ids, a single ClientObjectRef for exactly one id, and a
        list of ClientObjectRefs otherwise.
    """
    count = len(ids)
    if count == 0:
        return None
    if count == 1:
        return ClientObjectRef(ids[0])
    return [ClientObjectRef(ref_id) for ref_id in ids]
class DataEncodingSentinel:
    """Base class for marker objects; its repr is just the class name."""

    def __repr__(self) -> str:
        return type(self).__name__
class SelfReferenceSentinel(DataEncodingSentinel):
    # Marker stored in a stub's `_ref` while the stub's own payload is being
    # put, so that encoding the stub recursively can be detected.
    pass
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Build the decorator that turns a function or class into a client stub.

    Args:
        options: options forwarded to the created stub.

    Returns:
        A decorator producing a ClientRemoteFunc for functions (including
        Cython ones) or a ClientActorClass for classes.
    """

    def decorator(function_or_class) -> ClientStub:
        looks_like_function = (inspect.isfunction(function_or_class)
                               or is_cython(function_or_class))
        if looks_like_function:
            return ClientRemoteFunc(function_or_class, options=options)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    return decorator
@@ -0,0 +1,108 @@
"""This file implements a threaded stream controller to abstract a data stream
back to the ray clientserver.
"""
import logging
import queue
import threading
import grpc
from typing import Any
from typing import Dict
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
# The maximum field value for request_id -- which is also the maximum
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1
class DataClient:
    """Request/response datapath over a Ray Client gRPC stream.

    Requests are pushed onto ``request_queue`` and streamed to the server by
    a daemon thread. Responses are matched to callers by ``req_id`` through
    ``ready_data``, guarded by the condition variable ``cv``.
    """

    def __init__(self, channel: "grpc._channel.Channel", client_id: str):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
            client_id: the generated ID representing this client
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.data_thread = self._start_datathread()
        self.ready_data: Dict[int, Any] = {}
        self.cv = threading.Condition()
        self._req_id = 0
        self._client_id = client_id
        # Started last so every attribute the thread reads already exists.
        self.data_thread.start()

    def _next_id(self) -> int:
        """Return the next request id, wrapping from INT32_MAX back to 1.

        This is a read-modify-write on ``_req_id``; callers must hold
        ``self.cv`` so that concurrent senders never receive the same id.
        """
        self._req_id += 1
        if self._req_id > INT32_MAX:
            self._req_id = 1
        # Responses that aren't tracked (like opportunistic releases)
        # have req_id=0, so make sure we never mint such an id.
        assert self._req_id != 0
        return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        return threading.Thread(target=self._data_main, args=(), daemon=True)

    def _data_main(self) -> None:
        """Stream queued requests to the server and publish the responses.

        The request iterator ends when ``None`` is put on the queue.
        Responses with ``req_id == 0`` are not awaited by anyone and are only
        logged.
        """
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
        resp_stream = stub.Datapath(
            iter(self.request_queue.get, None),
            metadata=(("client_id", self._client_id), ))
        try:
            for response in resp_stream:
                if response.req_id == 0:
                    # This is not being waited for.
                    logger.debug(f"Got unawaited response {response}")
                    continue
                with self.cv:
                    self.ready_data[response.req_id] = response
                    self.cv.notify_all()
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Gracefully shutting down
                logger.info("Cancelling data channel")
            else:
                logger.error(
                    f"Got Error from data channel -- shutting down: {e}")
                raise e

    def close(self) -> None:
        """Signal the stream to end and wait for the data thread to exit."""
        if self.request_queue is not None:
            self.request_queue.put(None)
        if self.data_thread is not None:
            self.data_thread.join()

    def _blocking_send(self, req: ray_client_pb2.DataRequest
                       ) -> ray_client_pb2.DataResponse:
        """Send ``req`` and block until the response with its id arrives."""
        with self.cv:
            # Allocate the id and enqueue while holding the lock: without it
            # two threads could both observe the same counter value and then
            # wait on the same req_id.
            req_id = self._next_id()
            req.req_id = req_id
            self.request_queue.put(req)
            # wait_for releases the lock while waiting, so the data thread
            # can publish responses.
            self.cv.wait_for(lambda: req_id in self.ready_data)
            return self.ready_data.pop(req_id)

    def GetObject(self, request: ray_client_pb2.GetRequest,
                  context=None) -> ray_client_pb2.GetResponse:
        """Send a get request and return the ``get`` part of the response."""
        datareq = ray_client_pb2.DataRequest(get=request, )
        resp = self._blocking_send(datareq)
        return resp.get

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        """Send a put request and return the ``put`` part of the response."""
        datareq = ray_client_pb2.DataRequest(put=request, )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(self,
                      request: ray_client_pb2.ReleaseRequest,
                      context=None) -> None:
        """Queue a release request without waiting for any response.

        The request keeps the default ``req_id`` (no id is assigned here).
        """
        datareq = ray_client_pb2.DataRequest(release=request, )
        self.request_queue.put(datareq)
@@ -0,0 +1,7 @@
from ray.experimental.client import ray
from ray.tune import tune
# Connect the client-side `ray` stub to a Ray Client server at
# localhost:50051 (presumably one must already be listening there).
ray.connect("localhost:50051")

# Run Tune on the trainable registered as "PG", configured with the
# "CartPole-v0" environment.
tune.run("PG", config={"env": "CartPole-v0"})
@@ -0,0 +1,86 @@
"""This file implements a threaded stream controller to return logs back from
the ray clientserver.
"""
import sys
import logging
import queue
import threading
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
# TODO(barakmich): Running a logger in a logger causes loopback.
# The client logger needs its own root -- possibly this one.
# For the moment, let's just not propagate beyond this point.
logger.propagate = False
class LogstreamClient:
    """Receives log records from the server over a gRPC stream.

    Settings requests are pushed onto ``request_queue``; a daemon thread
    streams them to the server and dispatches every record that comes back.
    """

    def __init__(self, channel: "grpc._channel.Channel"):
        """Initializes a thread-safe log stream over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.log_thread = self._start_logthread()
        self.log_thread.start()

    def _start_logthread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread reading the stream."""
        worker = threading.Thread(
            target=self._log_main, args=(), daemon=True)
        return worker

    def _log_main(self) -> None:
        """Thread body: dispatch every record received from the server.

        Records with a negative level are passed to ``stdstream`` first;
        every record is passed to ``log``.
        """
        stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
        # The request stream ends when None is put on the queue.
        outgoing = iter(self.request_queue.get, None)
        log_stream = stub.Logstream(outgoing)
        try:
            for record in log_stream:
                level, msg = record.level, record.msg
                if level < 0:
                    self.stdstream(level=level, msg=msg)
                self.log(level=level, msg=msg)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.CANCELLED:
                # Just shutting down normally
                return
            logger.error(
                f"Got Error from logger channel -- shutting down: {e}")
            raise e

    def log(self, level: int, msg: str):
        """Log the message from the log stream.

        By default, calls logger.log but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        logger.log(level=level, msg=msg)

    def stdstream(self, level: int, msg: str):
        """Log the stdout/stderr entry from the log stream.

        By default, calls print but this can be overridden. Level -2 goes to
        stderr; anything else goes to stdout.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        if level == -2:
            target = sys.stderr
        else:
            target = sys.stdout
        print(msg, file=target)

    def set_logstream_level(self, level: int):
        """Set the local logger level and ask the server for that level."""
        logger.setLevel(level)
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = True
        settings.loglevel = level
        self.request_queue.put(settings)

    def close(self) -> None:
        """End the request stream and wait for the log thread to finish."""
        self.request_queue.put(None)
        worker = self.log_thread
        if worker is not None:
            worker.join()

    def disable_logs(self) -> None:
        """Ask the server to stop sending logs."""
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = False
        self.request_queue.put(settings)
+54
View File
@@ -0,0 +1,54 @@
from typing import Any
from typing import Dict
from typing import Optional
# Validation table for the keyword options accepted by remote().
# Each value is either an empty tuple (accepted without checks) or a
# (required type, predicate, error message) triple.
options = {
    "num_returns": (int, lambda x: x >= 0,
                    "The keyword 'num_returns' only accepts 0 "
                    "or a positive integer"),
    "num_cpus": (),
    "num_gpus": (),
    "resources": (),
    "accelerator_type": (),
    "max_calls": (int, lambda x: x >= 0,
                  "The keyword 'max_calls' only accepts 0 "
                  "or a positive integer"),
    "max_restarts": (int, lambda x: x >= -1,
                     "The keyword 'max_restarts' only accepts -1, 0 "
                     "or a positive integer"),
    "max_task_retries": (int, lambda x: x >= -1,
                         "The keyword 'max_task_retries' only accepts -1, 0 "
                         "or a positive integer"),
    "max_retries": (int, lambda x: x >= -1,
                    "The keyword 'max_retries' only accepts 0, -1 "
                    "or a positive integer"),
    "max_concurrency": (),
    "name": (),
    "lifetime": (),
    "memory": (),
    "object_store_memory": (),
    "placement_group": (),
    "placement_group_bundle_index": (),
    "placement_group_capture_child_tasks": (),
    "override_environment_variables": (),
}


def validate_options(
        kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Check remote() options against the ``options`` table.

    Args:
        kwargs_dict: option name -> value, or None.

    Returns:
        A new dict with the same items, or None when ``kwargs_dict`` is None
        or empty.

    Raises:
        TypeError: if an option name is not in the table.
        ValueError: if a value has the wrong type or fails its predicate.
    """
    if kwargs_dict is None or len(kwargs_dict) == 0:
        return None
    validated = {}
    for key, value in kwargs_dict.items():
        if key not in options:
            raise TypeError(f"Invalid option passed to remote(): {key}")
        rule = options[key]
        if rule:
            expected_type, predicate, message = rule
            # The type check runs first so the predicate never sees a value
            # it cannot compare.
            if not (isinstance(value, expected_type) and predicate(value)):
                raise ValueError(message)
        validated[key] = value
    return validated
@@ -0,0 +1,17 @@
from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray
@contextmanager
def ray_start_client_server():
    """Start a local client server, connect to it, and yield the ``ray`` stub.

    On exit the test flag is cleared, the client is disconnected and the
    server is stopped. Cleanup also runs when starting the server or
    connecting fails, and only undoes the steps that actually succeeded.
    """
    ray._inside_client_test = True
    server = None
    connected = False
    try:
        server = ray_client_server.serve("localhost:50051")
        ray.connect("localhost:50051")
        connected = True
        yield ray
    finally:
        ray._inside_client_test = False
        try:
            if connected:
                ray.disconnect()
        finally:
            # Stop the server even if disconnecting raised.
            if server is not None:
                server.stop(0)
@@ -1,101 +0,0 @@
# Along with `api.py` this is the stub that interfaces with
# the real (C-binding, raylet) ray core.
#
# Ideally, the first import line is the only time we actually
# import ray in this library (excluding the main function for the server)
#
# While the stub is trivial, it allows us to check that the calls we're
# making into the core-ray module are contained and well-defined.
from typing import Any
from typing import Optional
from typing import Union
import ray
from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientStub
class CoreRayAPI(APIImpl):
    """
    Implements the equivalent client-side Ray API by simply passing along to
    the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
    to core ray when passed client stubs.
    """

    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Fetch values via ``ray.get``, unpacking any ClientObjectRefs first.

        Args:
            vals: a single ref or a list of refs; ClientObjectRefs are
                replaced by the result of their ``_unpack_ref()``.
            timeout: forwarded to ``ray.get``.
        """
        if isinstance(vals, list):
            # Unpack element by element: an empty list no longer raises
            # IndexError, and client refs are unpacked wherever they appear.
            vals = [
                val._unpack_ref() if isinstance(val, ClientObjectRef) else val
                for val in vals
            ]
        elif isinstance(vals, ClientObjectRef):
            vals = vals._unpack_ref()
        return ray.get(vals, timeout=timeout)

    def put(self, vals: Any, *args,
            **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
        """Forward to ``ray.put``."""
        return ray.put(vals, *args, **kwargs)

    def wait(self, *args, **kwargs):
        """Forward to ``ray.wait``."""
        return ray.wait(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Forward to ``ray.remote``."""
        return ray.remote(*args, **kwargs)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Invoke ``.remote`` on the core-ray object behind a client stub."""
        return instance._get_ray_remote_impl().remote(*args, **kwargs)

    def close(self) -> None:
        """No-op; always returns None."""
        return None

    def kill(self, actor, *, no_restart=True):
        """Forward to ``ray.kill``."""
        return ray.kill(actor, no_restart=no_restart)

    def cancel(self, obj, *, force=False, recursive=True):
        """Forward to ``ray.cancel``."""
        return ray.cancel(obj, force=force, recursive=recursive)

    def is_initialized(self) -> bool:
        """Forward to ``ray.is_initialized``."""
        return ray.is_initialized()

    # Allow for generic fallback to ray.* in remote methods. This allows calls
    # like ray.nodes() to be run in remote functions even though the client
    # doesn't currently support them.
    def __getattr__(self, key: str):
        return getattr(ray, key)
class RayServerAPI(CoreRayAPI):
    """
    Ray Client server-side API shim. By default, simply calls the default Core
    Ray API calls, but also accepts scheduling calls from functions running
    inside of other remote functions that need to create more work.
    """

    def __init__(self, server_instance):
        self.server = server_instance

    def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
        """Put a value, or each element of a list, through the server.

        Returns a list of refs when ``vals`` is a list, otherwise one ref.
        """
        if isinstance(vals, list):
            return [self._put(item) for item in vals]
        return self._put(vals)

    def _put(self, val: Any):
        """Put one value via the server and wrap the returned id."""
        response = self.server._put_and_retain_obj(val)
        return ClientObjectRef(response.id)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Schedule the stub's task on the server and wrap its return id."""
        ticket = self.server.Schedule(
            instance._prepare_client_task(), prepared_args=args)
        return ClientObjectRef(ticket.return_id)
@@ -0,0 +1,54 @@
import logging
import grpc
from typing import TYPE_CHECKING
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """Serves the streaming datapath by delegating to the basic servicer."""

    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service

    def Datapath(self, request_iterator, context):
        """Answer each streamed get/put/release request for one client.

        The client id is read from the ``client_id`` invocation metadata. A
        missing or empty id is logged and the stream ends without serving
        anything. When the stream finishes, everything held for the client
        is released.
        """
        metadata = dict(context.invocation_metadata())
        # .get(): absent metadata is handled like an empty id instead of
        # escaping as a KeyError.
        client_id = metadata.get("client_id", "")
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = [
                        self.basic_service.release(client_id, rel_id)
                        for rel_id in req.release.ids
                    ]
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                # Echo the request id so the client can match the response.
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
@@ -0,0 +1,101 @@
"""This file responds to log stream requests and forwards logs
with its handler.
"""
import io
import threading
import queue
import logging
import grpc
import uuid
from ray.worker import print_worker_logs
from ray.ray_logging import global_worker_stdstream_dispatcher
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
class LogstreamHandler(logging.Handler):
    """Logging handler that forwards records onto a queue as LogData."""

    def __init__(self, queue, level):
        """Remember the destination queue and the handler level."""
        super().__init__()
        self.queue, self.level = queue, level

    def emit(self, record: logging.LogRecord):
        """Convert ``record`` to a LogData message and enqueue it."""
        message = ray_client_pb2.LogData()
        message.msg = record.getMessage()
        message.level = record.levelno
        message.name = record.name
        self.queue.put(message)
class StdStreamHandler:
    """Forwards worker stdout/stderr data onto a queue as LogData."""

    def __init__(self, queue):
        self.queue = queue
        # Random id used as the registration key with the global dispatcher.
        self.id = str(uuid.uuid4())

    def handle(self, data):
        """Render ``data`` with print_worker_logs and enqueue the text.

        Level/name are -2/"stderr" when ``data["is_err"]`` is truthy and
        -1/"stdout" otherwise.
        """
        is_err = data["is_err"]
        entry = ray_client_pb2.LogData()
        entry.level = -2 if is_err else -1
        entry.name = "stderr" if is_err else "stdout"
        with io.StringIO() as rendered:
            print_worker_logs(data, rendered)
            entry.msg = rendered.getvalue()
        self.queue.put(entry)

    def register_global(self):
        """Register ``handle`` with the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)

    def unregister_global(self):
        """Remove this handler from the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.remove_handler(self.id)
def log_status_change_thread(log_queue, request_iterator):
    """Apply each log-settings request from the client until the stream ends.

    Every request first removes the currently installed handlers (if any).
    An enabled request then installs a LogstreamHandler on the "ray" logger
    at the requested level and registers the stdout/stderr handler. When the
    request stream ends, the handlers are removed and ``None`` is put on
    ``log_queue`` to signal the end.
    """
    std_handler = StdStreamHandler(log_queue)
    ray_logger = logging.getLogger("ray")
    original_level = ray_logger.getEffectiveLevel()
    active_handler = None

    def uninstall(handler):
        # Restore the logger level and remove both handlers, if installed.
        if handler is None:
            return
        ray_logger.setLevel(original_level)
        ray_logger.removeHandler(handler)
        std_handler.unregister_global()

    try:
        for req in request_iterator:
            uninstall(active_handler)
            if not req.enabled:
                active_handler = None
                continue
            active_handler = LogstreamHandler(log_queue, req.loglevel)
            std_handler.register_global()
            ray_logger.addHandler(active_handler)
            ray_logger.setLevel(req.loglevel)
    except grpc.RpcError as e:
        logger.debug(f"closing log thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        uninstall(active_handler)
        log_queue.put(None)
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """Streams log records to a client, driven by its settings requests."""

    def Logstream(self, request_iterator, context):
        """Yield queued log records until the settings thread signals the end.

        A daemon thread consumes ``request_iterator`` and fills the queue; it
        puts ``None`` on the queue when it finishes. The thread is joined
        before this generator exits.
        """
        logger.info("New logs connection")
        log_queue = queue.Queue()
        thread = threading.Thread(
            target=log_status_change_thread,
            args=(log_queue, request_iterator),
            daemon=True)
        thread.start()
        try:
            # iter(callable, sentinel) stops by itself when get() returns
            # None, so no explicit None check is needed in the loop.
            for record in iter(log_queue.get, None):
                yield record
        except grpc.RpcError as e:
            logger.debug(f"Closing log channel: {e}")
        finally:
            thread.join()
+305 -150
View File
@@ -1,6 +1,14 @@
import logging
from concurrent import futures
import grpc
import base64
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import Set
from typing import Optional
from ray import cloudpickle
import ray
import ray.state
@@ -9,29 +17,34 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import inspect
import json
from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import encode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.server_pickler import convert_from_arg
from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.logservicer import LogstreamServicer
from ray.experimental.client.server.server_stubs import current_remote
from ray._private.client_mode_hook import disable_client_hook
logger = logging.getLogger(__name__)
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, test_mode=False):
self.object_refs = {}
def __init__(self):
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
dict)
self.function_refs = {}
self.actor_refs = {}
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
self.registered_actor_classes = {}
self._test_mode = test_mode
self._current_function_stub = None
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
resp = ray_client_pb2.ClusterInfoResponse()
resp.type = request.type
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
resources = ray.cluster_resources()
with disable_client_hook():
resources = ray.cluster_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
@@ -40,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
table=float_resources))
elif request.type == \
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
resources = ray.available_resources()
with disable_client_hook():
resources = ray.available_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
@@ -48,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ray_client_pb2.ClusterInfoResponse.ResourceTable(
table=float_resources))
else:
resp.json = self._return_debug_cluster_info(request, context)
with disable_client_hook():
resp.json = self._return_debug_cluster_info(request, context)
return resp
def _return_debug_cluster_info(self, request, context=None) -> str:
@@ -61,20 +76,61 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
def Terminate(self, request, context=None):
if request.WhichOneof("terminate_type") == "task_object":
def release(self, client_id: str, id: bytes) -> bool:
    """Release one reference that the server holds for a client.

    Object references are checked first, then actors owned by the client.

    Args:
        client_id: Client on whose behalf the reference is tracked.
        id: Binary ID of the object or actor to release.

    Returns:
        True if a tracked object or actor was released, False otherwise.
    """
    held_objects = self.object_refs.get(client_id)
    if held_objects is not None and id in held_objects:
        logger.debug(f"Releasing object {id.hex()} for {client_id}")
        del held_objects[id]
        return True

    owned_actors = self.actor_owners.get(client_id)
    if owned_actors is not None and id in owned_actors:
        logger.debug(f"Releasing actor {id.hex()} for {client_id}")
        del self.actor_refs[id]
        owned_actors.remove(id)
        return True

    return False
def release_all(self, client_id):
    """Release every object reference and actor tracked for a client.

    Objects are released before actors.
    """
    for drop in (self._release_objects, self._release_actors):
        drop(client_id)
def _release_objects(self, client_id):
    """Forget all object references held on behalf of ``client_id``.

    Logs and does nothing when the client has no tracked references.
    """
    if client_id in self.object_refs:
        count = len(self.object_refs.pop(client_id))
        logger.debug(f"Released all {count} objects for client {client_id}")
        return
    logger.debug(f"Releasing client with no references: {client_id}")
def _release_actors(self, client_id):
    """Drop every actor handle owned by ``client_id``.

    Removes each owned actor from ``self.actor_refs`` and then removes the
    client's entry from ``self.actor_owners``. Logs and returns early when
    the client owns no actors.
    """
    if client_id not in self.actor_owners:
        logger.debug(f"Releasing client with no actors: {client_id}")
        # Return here, as _release_objects does. Falling through would
        # index actor_owners for an unknown client and log a bogus
        # "Released all 0 actors" message.
        return

    owned = self.actor_owners[client_id]
    for id_bytes in owned:
        del self.actor_refs[id_bytes]
    count = len(owned)
    del self.actor_owners[client_id]
    logger.debug(f"Released all {count} actors for client: {client_id}")
def Terminate(self, req, context=None):
if req.WhichOneof("terminate_type") == "task_object":
try:
object_ref = cloudpickle.loads(request.task_object.handle)
ray.cancel(
object_ref,
force=request.task_object.force,
recursive=request.task_object.recursive)
object_ref = \
self.object_refs[req.client_id][req.task_object.id]
with disable_client_hook():
ray.cancel(
object_ref,
force=req.task_object.force,
recursive=req.task_object.recursive)
except Exception as e:
return_exception_in_context(e, context)
elif request.WhichOneof("terminate_type") == "actor":
elif req.WhichOneof("terminate_type") == "actor":
try:
actor_ref = cloudpickle.loads(request.actor.handle)
ray.kill(actor_ref, no_restart=request.actor.no_restart)
actor_ref = self.actor_refs[req.actor.id]
with disable_client_hook():
ray.kill(actor_ref, no_restart=req.actor.no_restart)
except Exception as e:
return_exception_in_context(e, context)
else:
@@ -84,166 +140,221 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return ray_client_pb2.TerminateResponse(ok=True)
def GetObject(self, request, context=None):
request_ref = cloudpickle.loads(request.handle)
if request_ref.binary() not in self.object_refs:
return self._get_object(request, "", context)
def _get_object(self, request, client_id: str, context=None):
if request.id not in self.object_refs[client_id]:
return ray_client_pb2.GetResponse(valid=False)
objectref = self.object_refs[request_ref.binary()]
logger.info("get: %s" % objectref)
objectref = self.object_refs[client_id][request.id]
logger.debug("get: %s" % objectref)
try:
item = ray.get(objectref, timeout=request.timeout)
with disable_client_hook():
item = ray.get(objectref, timeout=request.timeout)
except Exception as e:
return_exception_in_context(e, context)
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(
valid=False, error=cloudpickle.dumps(e))
item_ser = dumps_from_server(item, client_id, self)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
    """Unpickle the request payload, store it, and return a remote ref."""
    payload = cloudpickle.loads(request.data)
    stored_ref = self._put_and_retain_obj(payload)
    handle = cloudpickle.dumps(stored_ref)
    remote_ref = make_remote_ref(stored_ref.binary(), handle)
    return ray_client_pb2.PutResponse(ref=remote_ref)
def PutObject(self, request: ray_client_pb2.PutRequest,
              context=None) -> ray_client_pb2.PutResponse:
    """gRPC entrypoint for unary PutObject.

    Delegates to ``_put_object`` with the empty string as the client ID.
    """
    unary_client_id = ""
    return self._put_object(request, unary_client_id, context)
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
    """Put ``obj`` with ray.put() and keep the ref keyed by its binary ID."""
    stored_ref = ray.put(obj)
    key = stored_ref.binary()
    self.object_refs[key] = stored_ref
    logger.info("put: %s" % stored_ref)
    return stored_ref
def _put_object(self,
                request: ray_client_pb2.PutRequest,
                client_id: str,
                context=None):
    """Put an object in the cluster with ray.put() via gRPC.

    Args:
        request: PutRequest with pickled data.
        client_id: The client who owns this data, for tracking when to
            delete this reference.
        context: gRPC context.

    Returns:
        A PutResponse whose ``id`` is the binary ID of the new object ref.
    """
    payload = loads_from_client(request.data, self)
    # Call the real ray.put rather than the client-mode delegate.
    with disable_client_hook():
        stored_ref = ray.put(payload)
    ref_id = stored_ref.binary()
    self.object_refs[client_id][ref_id] = stored_ref
    logger.debug("put: %s" % stored_ref)
    return ray_client_pb2.PutResponse(id=ref_id)
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
object_refs = []
for id in request.object_ids:
if id not in self.object_refs[request.client_id]:
raise Exception(
"Asking for a ref not associated with this client: %s" %
str(id))
object_refs.append(self.object_refs[request.client_id][id])
num_returns = request.num_returns
timeout = request.timeout
object_refs_ids = []
for object_ref in object_refs:
if object_ref.binary() not in self.object_refs:
return ray_client_pb2.WaitResponse(valid=False)
object_refs_ids.append(self.object_refs[object_ref.binary()])
try:
ready_object_refs, remaining_object_refs = ray.wait(
object_refs_ids,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None)
except Exception:
with disable_client_hook():
ready_object_refs, remaining_object_refs = ray.wait(
object_refs,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None,
)
except Exception as e:
# TODO(ameer): improve exception messages.
logger.error(f"Exception {e}")
return ray_client_pb2.WaitResponse(valid=False)
logger.info("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
logger.debug("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
ready_object_ids = [
make_remote_ref(
id=ready_object_ref.binary(),
handle=cloudpickle.dumps(ready_object_ref),
) for ready_object_ref in ready_object_refs
ready_object_ref.binary() for ready_object_ref in ready_object_refs
]
remaining_object_ids = [
make_remote_ref(
id=remaining_object_ref.binary(),
handle=cloudpickle.dumps(remaining_object_ref),
) for remaining_object_ref in remaining_object_refs
remaining_object_ref.binary()
for remaining_object_ref in remaining_object_refs
]
return ray_client_pb2.WaitResponse(
valid=True,
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None,
             prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
    """Route a ClientTask to the scheduling helper matching its type.

    Raises:
        NotImplementedError: If ``task.type`` is not FUNCTION, ACTOR or
            METHOD.
    """
    task_kinds = ray_client_pb2.ClientTask
    logger.info("schedule: %s %s" %
                (task.name, task_kinds.RemoteExecType.Name(task.type)))
    if task.type == task_kinds.FUNCTION:
        scheduler = self._schedule_function
    elif task.type == task_kinds.ACTOR:
        scheduler = self._schedule_actor
    elif task.type == task_kinds.METHOD:
        scheduler = self._schedule_method
    else:
        raise NotImplementedError(
            "Unimplemented Schedule task type: %s" %
            task_kinds.RemoteExecType.Name(task.type))
    return scheduler(task, context, prepared_args)
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
    """gRPC entrypoint that schedules a function, actor, method or named actor.

    The helpers run with the client-mode hook disabled.

    Args:
        task: ClientTask to schedule; ``task.type`` selects the helper.
        context: gRPC context, forwarded to the helper.

    Returns:
        A ClientTaskTicket. On success ``valid`` is True and the ticket is
        the one produced by the helper. On failure ``valid`` is False and
        ``error`` carries the cloudpickled exception.
    """
    logger.debug(
        "schedule: %s %s" % (task.name,
                             ray_client_pb2.ClientTask.RemoteExecType.Name(
                                 task.type)))
    try:
        with disable_client_hook():
            if task.type == ray_client_pb2.ClientTask.FUNCTION:
                result = self._schedule_function(task, context)
            elif task.type == ray_client_pb2.ClientTask.ACTOR:
                result = self._schedule_actor(task, context)
            elif task.type == ray_client_pb2.ClientTask.METHOD:
                result = self._schedule_method(task, context)
            elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
                result = self._schedule_named_actor(task, context)
            else:
                raise NotImplementedError(
                    "Unimplemented Schedule task type: %s" %
                    ray_client_pb2.ClientTask.RemoteExecType.Name(
                        task.type))
            result.valid = True
            return result
    except Exception as e:
        logger.error(f"Caught schedule exception {e}")
        # Report the failure in-band so the client can unpickle and
        # re-raise the original exception. The earlier `raise e` made
        # this return unreachable.
        return ray_client_pb2.ClientTaskTicket(
            valid=False, error=cloudpickle.dumps(e))
def _schedule_method(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_method(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args, prepared_args)
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
pickled_ref = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_ref))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
method = getattr(actor_handle, task.name)
opts = decode_options(task.options)
if opts is not None:
method = method.options(**opts)
output = method.remote(*arglist, **kwargs)
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.registered_actor_classes:
actor_class_ref = self.object_refs[payload_ref.binary()]
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
                    context=None) -> ray_client_pb2.ClientTaskTicket:
    """Instantiate an actor for a client and record who owns it.

    Returns:
        A ClientTaskTicket whose only return ID is the actor's binary ID.
    """
    baseline = decode_options(task.baseline_options)
    actor_class = self.lookup_or_register_actor(task.payload_id,
                                                task.client_id, baseline)
    arglist, kwargs = self._convert_args(task.args, task.kwargs)
    overrides = decode_options(task.options)
    if overrides is not None:
        actor_class = actor_class.options(**overrides)
    with current_remote(actor_class):
        handle = actor_class.remote(*arglist, **kwargs)
    actor_id = handle._actor_id.binary()
    self.actor_refs[actor_id] = handle
    self.actor_owners[task.client_id].add(actor_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=[actor_id])
def _schedule_function(self, task: ray_client_pb2.ClientTask,
                       context=None) -> ray_client_pb2.ClientTaskTicket:
    """Invoke a registered remote function and track its outputs.

    Returns:
        A ClientTaskTicket holding the IDs returned by
        ``unify_and_track_outputs``.
    """
    baseline = decode_options(task.baseline_options)
    target = self.lookup_or_register_func(task.payload_id, task.client_id,
                                          baseline)
    arglist, kwargs = self._convert_args(task.args, task.kwargs)
    overrides = decode_options(task.options)
    if overrides is not None:
        target = target.options(**overrides)
    with current_remote(target):
        output = target.remote(*arglist, **kwargs)
    tracked_ids = self.unify_and_track_outputs(output, task.client_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=tracked_ids)
def _schedule_named_actor(self,
                          task: ray_client_pb2.ClientTask,
                          context=None) -> ray_client_pb2.ClientTaskTicket:
    """Look up an existing actor by ``task.name`` and track it for the client.

    Returns:
        A ClientTaskTicket whose only return ID is the actor's binary ID.
    """
    assert len(task.payload_id) == 0
    handle = ray.get_actor(task.name)
    actor_id = handle._actor_id.binary()
    self.actor_refs[actor_id] = handle
    self.actor_owners[task.client_id].add(actor_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=[actor_id])
def _convert_args(self, arg_list, kwarg_map):
    """Convert protobuf args and kwargs with ``convert_from_arg``.

    Returns:
        A ``(list, dict)`` pair of converted positional and keyword values.
    """
    positional = [convert_from_arg(arg, self) for arg in arg_list]
    keyword = {
        key: convert_from_arg(kwarg_map[key], self)
        for key in kwarg_map
    }
    return positional, keyword
def lookup_or_register_func(
        self, id: bytes, client_id: str,
        options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
    """Return the remote function cached under ``id``, registering it first
    if needed.

    On a cache miss the function is fetched from the client's object refs
    and wrapped with ``ray.remote``, passing ``options`` when non-empty.

    Raises:
        Exception: If the fetched object is not a plain function.
    """
    with disable_client_hook():
        if id not in self.function_refs:
            func = ray.get(self.object_refs[client_id][id])
            if not inspect.isfunction(func):
                raise Exception("Attempting to register function that "
                                "isn't a function.")
            if options:
                registered = ray.remote(**options)(func)
            else:
                registered = ray.remote(func)
            self.function_refs[id] = registered
    return self.function_refs[id]
def lookup_or_register_actor(self, id: bytes, client_id: str,
options: Optional[Dict]):
with disable_client_hook():
if id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[client_id][id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[payload_ref.binary()] = reg_class
remote_class = self.registered_actor_classes[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actorhandle] = actor
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
if options is None or len(options) == 0:
reg_class = ray.remote(actor_class)
else:
reg_class = ray.remote(**options)(actor_class)
self.registered_actor_classes[id] = reg_class
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.function_refs:
funcref = self.object_refs[payload_ref.binary()]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that "
"isn't a function.")
self.function_refs[payload_ref.binary()] = ray.remote(func)
remote_func = self.function_refs[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output
pickled_output = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_output))
return self.registered_actor_classes[id]
def _convert_args(arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
for arg in arg_list:
t = convert_from_arg(arg)
if isinstance(t, ClientObjectRef):
out.append(t._unpack_ref())
def unify_and_track_outputs(self, output, client_id):
if output is None:
outputs = []
elif isinstance(output, list):
outputs = output
else:
out.append(t)
return out
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
    """Build a RemoteRef message from a binary ID and a pickled handle."""
    fields = {"id": id, "handle": handle}
    return ray_client_pb2.RemoteRef(**fields)
outputs = [output]
for out in outputs:
if out.binary() in self.object_refs[client_id]:
logger.warning(f"Already saw object_ref {out}")
self.object_refs[client_id][out.binary()] = out
return [out.binary() for out in outputs]
def return_exception_in_context(err, context):
@@ -252,17 +363,61 @@ def return_exception_in_context(err, context):
context.set_code(grpc.StatusCode.INTERNAL)
def serve(connection_str, test_mode=False):
def encode_exception(exception) -> str:
    """Cloudpickle ``exception`` and return it as base64 text."""
    pickled = cloudpickle.dumps(exception)
    encoded = base64.b64encode(pickled)
    return encoded.decode()
def decode_options(
        options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
    """Parse the JSON-encoded options carried by a TaskOptions message.

    Returns:
        None when ``json_options`` is the empty string, otherwise the
        decoded dict.
    """
    raw = options.json_options
    if raw == "":
        return None
    decoded = json.loads(raw)
    assert isinstance(decoded, dict)
    return decoded
_current_servicer: Optional[RayletServicer] = None
def _get_current_servicer():
    """Return the module-level ``_current_servicer``.

    Used by tests to peek inside the servicer.
    """
    # Only reads the module global, so no `global` declaration is needed.
    return _current_servicer
def serve(connection_str):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
_set_server_api(RayServerAPI(task_servicer))
task_servicer = RayletServicer()
data_servicer = DataServicer(task_servicer)
logs_servicer = LogstreamServicer()
global _current_servicer
_current_servicer = task_servicer
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
data_servicer, server)
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
logs_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return server
def init_and_serve(connection_str, *args, **kwargs):
    """Initialize Ray and start the client server on ``connection_str``.

    Extra positional and keyword arguments are forwarded to ``ray.init``.

    Returns:
        A ``(server, info)`` tuple: the started server and the value
        returned by ``ray.init``.
    """
    # Disable client mode inside the worker's environment
    with disable_client_hook():
        init_info = ray.init(*args, **kwargs)
    grpc_server = serve(connection_str)
    return (grpc_server, init_info)
def shutdown_with_server(server, _exiting_interpreter=False):
    """Stop the server, then shut Ray down with the client hook disabled."""
    grace = 1
    server.stop(grace)
    with disable_client_hook():
        ray.shutdown(_exiting_interpreter)
if __name__ == "__main__":
logging.basicConfig(level="INFO")
# TODO(barakmich): Perhaps wrap ray init
@@ -0,0 +1,135 @@
"""Implements the server side of the client/server pickling protocol.
These picklers are aware of the server internals and can find the
references held for the client within the server.
More discussion about the client/server pickling protocol can be found in:
ray/experimental/client/client_pickler.py
ServerPickler dumps ray objects from the server into the appropriate stubs.
ClientUnpickler loads stubs from the client and finds their associated handle
in the server instance.
"""
import cloudpickle
import io
import sys
import ray
from typing import Any
from typing import TYPE_CHECKING
from ray._private.client_mode_hook import disable_client_hook
from ray.experimental.client.client_pickler import PickleStub
from ray.experimental.client.server.server_stubs import (
ServerSelfReferenceSentinel)
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
import ray.core.generated.ray_client_pb2 as ray_client_pb2
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
class ServerPickler(cloudpickle.CloudPickler):
    """Pickler that replaces Ray references with client-side stubs.

    While dumping, every ``ray.ObjectRef`` and ``ray.actor.ActorHandle`` is
    emitted as a ``PickleStub`` persistent ID. Each one is also recorded on
    the server for ``client_id``, so the server keeps holding it.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args,
                 **kwargs):
        """
        Args:
            client_id: Client the pickled data is being sent to.
            server: Servicer whose reference tables are updated.
            *args, **kwargs: Forwarded to ``cloudpickle.CloudPickler``.
        """
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for Ray references, or None otherwise.

        Returning None makes the pickler serialize ``obj`` normally.
        """
        if isinstance(obj, ray.ObjectRef):
            obj_id = obj.binary()
            client_objects = self.server.object_refs[self.client_id]
            if obj_id not in client_objects:
                # We're passing back a reference, probably inside a reference.
                # Let's hold onto it.
                client_objects[obj_id] = obj
            return PickleStub(
                type="Object",
                client_id=self.client_id,
                ref_id=obj_id,
                name=None,
                baseline_options=None,
            )
        elif isinstance(obj, ray.actor.ActorHandle):
            actor_id = obj._actor_id.binary()
            # The tracking tables live on the servicer, not on the pickler.
            # Using `self.actor_refs` here raised AttributeError.
            if actor_id not in self.server.actor_refs:
                # We're passing back a handle, probably inside a reference.
                self.server.actor_refs[actor_id] = obj
            # set.add is idempotent, so no membership test is needed.
            self.server.actor_owners[self.client_id].add(actor_id)
            return PickleStub(
                type="Actor",
                client_id=self.client_id,
                ref_id=actor_id,
                name=None,
                baseline_options=None,
            )
        return None
class ClientUnpickler(pickle.Unpickler):
    """Unpickler that resolves client ``PickleStub``s against server state."""

    def __init__(self, server, *args, **kwargs):
        """Keep ``server`` and forward the rest to ``pickle.Unpickler``."""
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Map a ``PickleStub`` to the server-side object it stands for.

        Raises:
            NotImplementedError: If ``pid.type`` is not a known stub type.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Ray":
            return ray
        if kind == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        if kind == "Actor":
            return self.server.actor_refs[pid.ref_id]
        if kind in ("RemoteFuncSelfReference", "RemoteActorSelfReference"):
            return ServerSelfReferenceSentinel()
        if kind == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteActor":
            return self.server.lookup_or_register_actor(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteMethod":
            owner = self.server.actor_refs[pid.ref_id]
            return getattr(owner, pid.name)
        raise NotImplementedError("Uncovered client data type")
def dumps_from_server(obj: Any,
                      client_id: str,
                      server_instance: "RayletServicer",
                      protocol=None) -> bytes:
    """Serialize ``obj`` for ``client_id`` using a ``ServerPickler``.

    Returns:
        The pickled bytes.
    """
    buffer = io.BytesIO()
    try:
        pickler = ServerPickler(
            client_id, server_instance, buffer, protocol=protocol)
        pickler.dump(obj)
        return buffer.getvalue()
    finally:
        buffer.close()
def loads_from_client(data: bytes,
                      server_instance: "RayletServicer",
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Deserialize client-pickled bytes, resolving stubs via the server.

    Loading runs with the client-mode hook disabled.

    Args:
        data: Pickled bytes received from the client.
        server_instance: Servicer used to resolve ``PickleStub``s.
        fix_imports: Forwarded to the unpickler.
        encoding: Forwarded to the unpickler.
        errors: Forwarded to the unpickler.

    Returns:
        The reconstructed object.

    Raises:
        TypeError: If ``data`` is a ``str`` instead of bytes.
    """
    # NOTE(review): this unpickles data supplied by a remote client;
    # only expose the server to trusted clients.
    with disable_client_hook():
        if isinstance(data, str):
            raise TypeError("Can't load pickle from unicode string")
        file = io.BytesIO(data)
        # `errors` was accepted but never forwarded, so a caller's choice
        # was silently ignored.
        return ClientUnpickler(
            server_instance,
            file,
            fix_imports=fix_imports,
            encoding=encoding,
            errors=errors).load()
def convert_from_arg(pb: "ray_client_pb2.Arg",
                     server: "RayletServicer") -> Any:
    """Unpickle the ``data`` field of a protobuf Arg via the server."""
    serialized = pb.data
    return loads_from_client(serialized, server)
@@ -0,0 +1,29 @@
from contextlib import contextmanager
_current_remote_obj = None
@contextmanager
def current_remote(r):
    """Set ``_current_remote_obj`` to ``r`` for the duration of the block.

    The previous value is restored on exit, even if the body raises.
    """
    global _current_remote_obj
    previous = _current_remote_obj
    _current_remote_obj = r
    try:
        yield
    finally:
        _current_remote_obj = previous
class ServerSelfReferenceSentinel:
    """Placeholder whose pickled form depends on ``_current_remote_obj``.

    When ``_current_remote_obj`` is set, pickling reduces to
    ``identity(_current_remote_obj)``. Otherwise it reduces to a fresh
    sentinel.
    """

    def __init__(self):
        pass

    def __reduce__(self):
        active = _current_remote_obj
        if active is not None:
            return (identity, (active, ))
        return (ServerSelfReferenceSentinel, ())
def identity(x):
    """Return ``x`` unchanged."""
    result = x
    return result
+124 -63
View File
@@ -2,27 +2,30 @@
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import inspect
import base64
import json
import logging
import uuid
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Optional
import ray.cloudpickle as cloudpickle
from ray.util.inspect import is_cython
import grpc
from ray.exceptions import TaskCancelledError
import ray.cloudpickle as cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import decode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.client_pickler import convert_to_arg
from ray.experimental.client.client_pickler import dumps_from_client
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.dataclient import DataClient
from ray.experimental.client.logsclient import LogstreamClient
logger = logging.getLogger(__name__)
@@ -31,34 +34,37 @@ class Worker:
def __init__(self,
conn_str: str = "",
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
metadata: List[Tuple[str, str]] = None):
"""Initializes the worker side grpc client.
Args:
stub: custom grpc stub.
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
self.channel = None
if stub is None:
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self._client_id = make_client_id()
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.server = stub
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self.data_client = DataClient(self.channel, self._client_id)
self.reference_count: Dict[bytes, int] = defaultdict(int)
self.log_client = LogstreamClient(self.channel)
self.log_client.set_logstream_level(logging.INFO)
self.closed = False
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
if isinstance(vals, list):
to_get = [x.handle for x in vals]
to_get = vals
elif isinstance(vals, ClientObjectRef):
to_get = [vals.handle]
to_get = [vals]
single = True
else:
raise Exception("Can't get something that's not a "
@@ -70,15 +76,17 @@ class Worker:
out = out[0]
return out
def _get(self, handle: bytes, timeout: float):
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
def _get(self, ref: ClientObjectRef, timeout: float):
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
try:
data = self.server.GetObject(req, metadata=self.metadata)
data = self.data_client.GetObject(req)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise e.details()
if not data.valid:
raise TaskCancelledError(handle)
return cloudpickle.loads(data.data)
err = cloudpickle.loads(data.error)
logger.error(err)
raise err
return loads_from_server(data.data)
def put(self, vals):
to_put = []
@@ -95,26 +103,37 @@ class Worker:
return out
def _put(self, val):
data = cloudpickle.dumps(val)
if isinstance(val, ClientObjectRef):
raise TypeError(
"Calling 'put' on an ObjectRef is not allowed "
"(similarly, returning an ObjectRef from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ObjectRef in a list and "
"call 'put' on it (or return it).")
data = dumps_from_client(val, self._client_id)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(resp.ref)
resp = self.data_client.PutObject(req)
return ClientObjectRef(resp.id)
def wait(self,
object_refs: List[ClientObjectRef],
*,
num_returns: int = 1,
timeout: float = None
timeout: float = None,
fetch_local: bool = True
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list)
if not isinstance(object_refs, list):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got {type(object_refs)}")
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
if not isinstance(ref, ClientObjectRef):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got list containing {type(ref)}")
data = {
"object_handles": [
object_ref.handle for object_ref in object_refs
],
"object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
"timeout": timeout if timeout else -1,
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
@@ -122,41 +141,69 @@ class Worker:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.ready_object_ids
ClientObjectRef(ref) for ref in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.remaining_object_ids
ClientObjectRef(ref) for ref in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, function_or_class, *args, **kwargs):
    """Wrap a function or class in its client-side remote stub.

    Raises:
        TypeError: If ``function_or_class`` is neither a function (plain or
            Cython) nor a class.
    """
    # TODO(barakmich): Arguments to ray.remote
    # get captured here.
    target = function_or_class
    if inspect.isfunction(target) or is_cython(target):
        return ClientRemoteFunc(target)
    if inspect.isclass(target):
        return ClientActorClass(target)
    raise TypeError("The @ray.remote decorator must be applied to "
                    "either a function or to a class.")
def call_remote(self, instance, *args, **kwargs):
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg)
pb_arg = convert_to_arg(arg, self._client_id)
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(ticket.return_ref)
for k, v in kwargs.items():
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
return self._call_schedule_for_task(task)
def _call_schedule_for_task(
        self, task: ray_client_pb2.ClientTask) -> List[bytes]:
    """Send a ClientTask to the server and return the IDs it produced.

    The task is stamped with this worker's client ID before being sent.

    Args:
        task: The prepared ClientTask.

    Returns:
        The ``return_ids`` of the server's ticket.

    Raises:
        Exception: The decoded server exception when the RPC fails, or
            the unpickled ``ticket.error`` when the ticket is not valid.
    """
    logger.debug("Scheduling %s" % task)
    task.client_id = self._client_id
    try:
        ticket = self.server.Schedule(task, metadata=self.metadata)
    except grpc.RpcError as e:
        # `details` is a method. Passing the bound method itself made
        # decode_exception fail instead of surfacing the server error.
        raise decode_exception(e.details())
    if not ticket.valid:
        raise cloudpickle.loads(ticket.error)
    return ticket.return_ids
def call_release(self, id: bytes) -> None:
    """Decrement the local count for ``id``; release it when it hits zero.

    Does nothing once the worker is closed.
    """
    if self.closed:
        return
    remaining = self.reference_count[id] - 1
    self.reference_count[id] = remaining
    if remaining != 0:
        return
    self._release_server(id)
    del self.reference_count[id]
def _release_server(self, id: bytes) -> None:
    """Send a ReleaseRequest for ``id`` through the data client, if any."""
    if self.data_client is None:
        return
    logger.debug(f"Releasing {id}")
    request = ray_client_pb2.ReleaseRequest(ids=[id])
    self.data_client.ReleaseObject(request)
def call_retain(self, id: bytes) -> None:
    """Increment the local reference count for ``id``."""
    logger.debug(f"Retaining {id.hex()}")
    self.reference_count[id] = self.reference_count[id] + 1
def close(self):
self.server = None
self.log_client.close()
self.data_client.close()
if self.channel:
self.channel.close()
self.channel = None
self.server = None
self.closed = True
def get_actor(self, name: str) -> ClientActorHandle:
    """Ask the server for the actor named ``name`` and wrap it in a handle.

    Sends a NAMED_ACTOR task and expects exactly one ID back.
    """
    lookup = ray_client_pb2.ClientTask()
    lookup.type = ray_client_pb2.ClientTask.NAMED_ACTOR
    lookup.name = name
    returned_ids = self._call_schedule_for_task(lookup)
    assert len(returned_ids) == 1
    actor_ref = ClientActorRef(returned_ids[0])
    return ClientActorHandle(actor_ref)
def terminate_actor(self, actor: ClientActorHandle,
no_restart: bool) -> None:
@@ -164,10 +211,11 @@ class Worker:
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
term_actor.handle = actor.actor_ref.handle
term_actor.id = actor.actor_ref.id
term_actor.no_restart = no_restart
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@@ -179,11 +227,12 @@ class Worker:
"ray.cancel() only supported for non-actor object refs. "
f"Got: {type(obj)}.")
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
term_object.handle = obj.handle
term_object.id = obj.id
term_object.force = force
term_object.recursive = recursive
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@@ -193,7 +242,9 @@ class Worker:
req.type = type
resp = self.server.ClusterInfo(req)
if resp.WhichOneof("response_type") == "resource_table":
return resp.resource_table.table
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
return output_dict
return json.loads(resp.json)
def is_initialized(self) -> bool:
@@ -201,3 +252,13 @@ class Worker:
return self.get_cluster_info(
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
def make_client_id() -> str:
    """Return a fresh random client ID: the hex form of a UUID4."""
    return uuid.uuid4().hex
def decode_exception(data) -> Exception:
    """Base64-decode ``data`` and unpickle it with ``loads_from_server``."""
    pickled = base64.b64decode(data)
    return loads_from_server(pickled)
@@ -1,18 +0,0 @@
import ray
def force_spill_objects(object_refs):
    """Force spilling objects to external storage.

    Args:
        object_refs: Object refs of the objects to be spilled.

    Returns:
        The result of the core worker's ``force_spill_objects`` call.

    Raises:
        TypeError: If any element of ``object_refs`` is not a
            ``ray.ObjectRef``.
    """
    worker_core = ray.worker.global_worker.core_worker
    # Validate every entry before handing the batch to the core worker.
    for object_ref in object_refs:
        if isinstance(object_ref, ray.ObjectRef):
            continue
        raise TypeError(
            f"Attempting to call `force_spill_objects` on the "
            f"value {object_ref}, which is not an ray.ObjectRef.")
    return worker_core.force_spill_objects(object_refs)
+12 -3
View File
@@ -157,12 +157,15 @@ class ExternalStorage(metaclass=abc.ABCMeta):
@abc.abstractmethod
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
url_with_offset_list: List[str]) -> int:
"""Restore objects from the external storage.
Args:
object_refs: List of object IDs (note that it is not ref).
url_with_offset_list: List of url_with_offset.
Returns:
The total number of bytes restored.
"""
@abc.abstractmethod
@@ -215,6 +218,7 @@ class FileSystemStorage(ExternalStorage):
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
total = 0
for i in range(len(object_refs)):
object_ref = object_refs[i]
url_with_offset = url_with_offset_list[i].decode()
@@ -228,9 +232,11 @@ class FileSystemStorage(ExternalStorage):
metadata_len = int.from_bytes(f.read(8), byteorder="little")
buf_len = int.from_bytes(f.read(8), byteorder="little")
self._size_check(metadata_len, buf_len, parsed_result.size)
total += buf_len
metadata = f.read(metadata_len)
# read remaining data to our buffer
self._put_object_to_store(metadata, buf_len, f, object_ref)
return total
def delete_spilled_objects(self, urls: List[str]):
for url in urls:
@@ -297,6 +303,7 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
from smart_open import open
total = 0
for i in range(len(object_refs)):
object_ref = object_refs[i]
url_with_offset = url_with_offset_list[i].decode()
@@ -315,9 +322,11 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
metadata_len = int.from_bytes(f.read(8), byteorder="little")
buf_len = int.from_bytes(f.read(8), byteorder="little")
self._size_check(metadata_len, buf_len, parsed_result.size)
total += buf_len
metadata = f.read(metadata_len)
# read remaining data to our buffer
self._put_object_to_store(metadata, buf_len, f, object_ref)
return total
def delete_spilled_objects(self, urls: List[str]):
pass
@@ -367,8 +376,8 @@ def restore_spilled_objects(object_refs: List[ObjectRef],
object_refs: List of object IDs (note that it is not ref).
url_with_offset_list: List of url_with_offset.
"""
_external_storage.restore_spilled_objects(object_refs,
url_with_offset_list)
return _external_storage.restore_spilled_objects(object_refs,
url_with_offset_list)
def delete_spilled_objects(urls: List[str]):
+9 -4
View File
@@ -12,6 +12,7 @@ import hashlib
import cython
import inspect
import uuid
import ray.ray_constants as ray_constants
ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &)
@@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
function_name = function.__name__
class_name = ""
pickled_function_hash = hashlib.sha1(pickled_function).hexdigest()
pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest(
ray_constants.ID_SIZE)
return cls(module_name, function_name, class_name,
pickled_function_hash)
@@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
module_name = target_class.__module__
class_name = target_class.__name__
# Use a random uuid as function hash to solve actor name conflict.
return cls(module_name, "__init__", class_name, str(uuid.uuid4()))
return cls(
module_name, "__init__", class_name,
hashlib.shake_128(
uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE))
@property
def module_name(self):
@@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
Returns:
ray.ObjectRef to represent the function descriptor.
"""
function_id_hash = hashlib.sha1()
function_id_hash = hashlib.shake_128()
# Include the function module and name in the hash.
function_id_hash.update(self.typed_descriptor.ModuleName())
function_id_hash.update(self.typed_descriptor.FunctionName())
function_id_hash.update(self.typed_descriptor.ClassName())
function_id_hash.update(self.typed_descriptor.FunctionHash())
# Compute the function ID.
function_id = function_id_hash.digest()
function_id = function_id_hash.digest(ray_constants.ID_SIZE)
return ray.FunctionID(function_id)
def is_actor_method(self):
+4 -3
View File
@@ -179,9 +179,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool plasma_objects_only)
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
int64_t timeout_ms, c_vector[c_bool] *results)
int64_t timeout_ms, c_vector[c_bool] *results,
c_bool fetch_local)
CRayStatus Delete(const c_vector[CObjectID] &object_ids,
c_bool local_only, c_bool delete_creating_tasks)
c_bool local_only)
CRayStatus TriggerGlobalGC()
c_string MemoryUsageString()
@@ -232,7 +233,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(CRayStatus() nogil) check_signals
(void() nogil) gc_collect
(c_vector[c_string](const c_vector[CObjectID] &) nogil) spill_objects
(void(
(int64_t(
const c_vector[CObjectID] &,
const c_vector[c_string] &) nogil) restore_spilled_objects
(void(
+1 -1
View File
@@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize):
raise TypeError("Unsupported type: " + str(type(b)))
if len(b) != size:
raise ValueError("ID string needs to have length " +
str(size))
str(size) + ", got " + str(len(b)))
cdef extern from "ray/common/constants.h" nogil:
+2 -5
View File
@@ -37,7 +37,7 @@ def memory_summary():
return reply.memory_summary
def free(object_refs, local_only=False, delete_creating_tasks=False):
def free(object_refs, local_only=False):
"""Free a list of IDs from the in-process and plasma object stores.
This function is a low-level API which should be used in restricted
@@ -59,8 +59,6 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
object_refs (List[ObjectRef]): List of object refs to delete.
local_only (bool): Whether only deleting the list of objects in local
object store or all object stores.
delete_creating_tasks (bool): Whether also delete the object creating
tasks.
"""
worker = ray.worker.global_worker
@@ -83,5 +81,4 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
if len(object_refs) == 0:
return
worker.core_worker.free_objects(object_refs, local_only,
delete_creating_tasks)
worker.core_worker.free_objects(object_refs, local_only)
+1 -1
View File
@@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger
logger = logging.getLogger(__name__)
# The groups are worker id, job id, and pid.
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)")
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)")
class LogFileInfo:
+19 -15
View File
@@ -15,11 +15,14 @@ from ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S
from ray.autoscaler._private.load_metrics import LoadMetrics
from ray.autoscaler._private.constants import \
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS
import ray.gcs_utils
import ray.utils
import ray.ray_constants as ray_constants
from ray.ray_logging import setup_component_logger
from ray._raylet import GlobalStateAccessor
from ray.experimental.internal_kv import _internal_kv_put, \
_internal_kv_initialized
import redis
@@ -65,11 +68,7 @@ def parse_resource_demands(resource_load_by_shape):
except Exception:
logger.exception("Failed to parse resource demands.")
# Bound the total number of bundles to 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE.
# This guarantees the resource demand scheduler bin packing algorithm takes
# a reasonable amount of time to run.
return waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE], \
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
return waiting_bundles, infeasible_bundles
class Monitor:
@@ -184,14 +183,8 @@ class Monitor:
data: a resource request as JSON, e.g. {"CPU": 1}
"""
if not self.autoscaler:
return
try:
self.autoscaler.request_resources(json.loads(data))
except Exception:
# We don't want this to kill the monitor.
traceback.print_exc()
resource_request = json.loads(data)
self.load_metrics.set_resource_requests(resource_request)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -257,12 +250,23 @@ class Monitor:
# Handle messages from the subscription channels.
while True:
self.update_raylet_map()
self.update_load_metrics()
status = {
"load_metrics_report": self.load_metrics.summary()._asdict()
}
# Process autoscaling actions
if self.autoscaler:
# Only used to update the load metrics for the autoscaler.
self.update_raylet_map()
self.update_load_metrics()
self.autoscaler.update()
status[
"autoscaler_report"] = self.autoscaler.summary()._asdict()
as_json = json.dumps(status)
if _internal_kv_initialized():
_internal_kv_put(
DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True)
# Process a round of messages.
self.process_messages()
+1 -1
View File
@@ -19,7 +19,7 @@ def env_bool(key, default):
return default
ID_SIZE = 20
ID_SIZE = 28
# The default maximum number of bytes to allocate to the object store unless
# overridden by the user.
+27
View File
@@ -1,8 +1,11 @@
import logging
import os
import sys
import threading
from logging.handlers import RotatingFileHandler
from typing import Callable
import ray
from ray.utils import binary_to_hex
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
# logger to add a newline at the end of string.
handler.terminator = ""
return logger
class WorkerStandardStreamDispatcher:
    """Dispatch emitted data to a set of named handler callables.

    Handlers are stored as ``(name, handler)`` pairs and are invoked in
    registration order. All access to the handler list is serialized by a
    lock.
    """

    def __init__(self):
        self.handlers = []
        # Re-entrant so that a handler running inside emit() may itself call
        # add_handler()/remove_handler() on the same thread without
        # deadlocking.
        self._lock = threading.RLock()

    def add_handler(self, name: str, handler: Callable) -> None:
        """Register ``handler`` under ``name``; duplicates are allowed."""
        with self._lock:
            self.handlers.append((name, handler))

    def remove_handler(self, name: str) -> None:
        """Remove every handler registered under ``name`` (no-op if none)."""
        with self._lock:
            self.handlers = [
                pair for pair in self.handlers if pair[0] != name
            ]

    def emit(self, data):
        """Call every registered handler with ``data``.

        The handler list is snapshotted first, so handlers added or removed
        by a handler during this call only take effect on the next emit.
        """
        with self._lock:
            for _, handle in list(self.handlers):
                handle(data)
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
+6 -26
View File
@@ -2,17 +2,15 @@
import asyncio
import logging
import os
import time
from ray._private.ray_microbenchmark_helpers import timeit
from ray._private.ray_client_microbenchmark import (main as
client_microbenchmark_main)
import numpy as np
import multiprocessing
import ray
logger = logging.getLogger(__name__)
# Only run tests matching this filter pattern.
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
@ray.remote(num_cpus=0)
class Actor:
@@ -71,27 +69,6 @@ def small_value_batch(n):
return 0
def timeit(name, fn, multiplier=1):
if filter_pattern not in name:
return
# warmup
start = time.time()
while time.time() - start < 1:
fn()
# real run
stats = []
for _ in range(4):
start = time.time()
count = 0
while time.time() - start < 2:
fn()
count += 1
end = time.time()
stats.append(multiplier * count / (end - start))
print(name, "per second", round(np.mean(stats), 2), "+-",
round(np.std(stats), 2))
def check_optimized_build():
if not ray._raylet.OPTIMIZED:
msg = ("WARNING: Unoptimized build! "
@@ -277,6 +254,9 @@ def main():
ray.get([async_actor_work.remote(a) for _ in range(m)])
timeit("n:n async-actor calls async", async_actor_multi, m * n)
ray.shutdown()
client_microbenchmark_main()
if __name__ == "__main__":
+2 -5
View File
@@ -6,7 +6,6 @@ import logging
import os
import subprocess
import sys
from telnetlib import Telnet
import time
import urllib
import urllib.parse
@@ -172,8 +171,7 @@ def continue_debug_session():
ray.experimental.internal_kv._internal_kv_del(key)
return
host, port = session["pdb_address"].split(":")
with Telnet(host, int(port)) as tn:
tn.interact()
ray.util.rpdb.connect_pdb_client(host, int(port))
ray.experimental.internal_kv._internal_kv_del(key)
continue_debug_session()
return
@@ -215,8 +213,7 @@ def debug(address):
ray.experimental.internal_kv._internal_kv_get(
active_sessions[index]))
host, port = session["pdb_address"].split(":")
with Telnet(host, int(port)) as tn:
tn.interact()
ray.util.rpdb.connect_pdb_client(host, int(port))
@cli.command()
+3 -2
View File
@@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
new_class_id = pickle.dumps(pickle.loads(class_id))
if new_class_id == class_id:
# We appear to have reached a fix point, so use this as the ID.
return hashlib.sha1(new_class_id).digest()
return hashlib.shake_128(new_class_id).digest(
ray_constants.ID_SIZE)
class_id = new_class_id
# We have not reached a fixed point, so we may end up with a different
@@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
# same class definition being exported many many times.
logger.warning(
f"WARNING: Could not produce a deterministic class ID for class {cls}")
return hashlib.sha1(new_class_id).digest()
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
def object_ref_deserializer(reduced_obj_ref, owner_address):
+8
View File
@@ -119,6 +119,14 @@ py_test(
deps = [":serve_lib"],
)
py_test(
name = "test_imported_backend",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the controller.
# TODO(simon): Tests are disabled until #11683 is fixed.
+93 -28
View File
@@ -4,22 +4,41 @@ import time
from functools import wraps
import os
from uuid import UUID
import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from ray.serve.context import TaskContext
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger, get_conda_env_dir)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
HTTPConfig)
from ray.serve.env import CondaEnv
from ray.serve.router import RequestMetadata, Router
from ray.actor import ActorHandle
from typing import Any, Callable, Dict, List, Optional, Type, Union
_INTERNAL_CONTROLLER_NAME = None
# Process-wide event loop shared by all sync callers; created lazily below.
global_async_loop = None


def create_or_get_async_loop_in_thread():
    """Return the shared background event loop, starting it on first use.

    The first call creates a new asyncio event loop and runs it forever on a
    daemon thread; every later call returns that same loop object.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    runner = threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    )
    runner.start()
    return global_async_loop
def _set_internal_controller_name(name):
global _INTERNAL_CONTROLLER_NAME
@@ -36,6 +55,36 @@ def _ensure_connected(f: Callable) -> Callable:
return check
class ThreadProxiedRouter:
    """Wrap a Router and schedule its setup on an asyncio event loop.

    With ``sync=True`` the loop is the shared one returned by
    ``create_or_get_async_loop_in_thread``; otherwise it is the loop
    returned by ``asyncio.get_event_loop()`` at construction time.
    """

    def __init__(self, controller_handle, sync: bool):
        self.router = Router(controller_handle)

        if not sync:
            # Async handles: schedule setup as a task on the caller's loop.
            self.async_loop = asyncio.get_event_loop()
            self.async_loop.create_task(self.router.setup_in_async_loop())
            return

        # Sync handles: submit setup to the background loop thread-safely.
        self.async_loop = create_or_get_async_loop_in_thread()
        asyncio.run_coroutine_threadsafe(
            self.router.setup_in_async_loop(),
            self.async_loop,
        )

    def _remote(self, endpoint_name, handle_options, request_data,
                kwargs) -> Coroutine:
        """Build request metadata and return the router's assignment coroutine.

        The coroutine is returned un-awaited; the caller decides on which
        loop to run it.
        """
        metadata = RequestMetadata(
            get_random_letters(10),  # Used for debugging.
            endpoint_name,
            TaskContext.Python,
            call_method=handle_options.method_name,
            shard_key=handle_options.shard_key,
            http_method=handle_options.http_method,
            http_headers=handle_options.http_headers,
        )
        return self.router.assign_request(metadata, request_data, **kwargs)
class Client:
def __init__(self,
controller: ActorHandle,
@@ -45,15 +94,10 @@ class Client:
self._controller_name = controller_name
self._detached = detached
self._shutdown = False
self._http_host, self._http_port = ray.get(
controller.get_http_config.remote())
self._http_config = ray.get(controller.get_http_config.remote())
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
# mostly grow in size, it will only shrink when user calls the
# .remove_endpoint method. This is fine because we expect the number of
# endpoints to be fairly small. However, in case this dictionary does
# grow very big, we can replace it with a LRU cache instead.
self._handle_cache: Dict[str, ActorHandle] = dict()
self._sync_proxied_router = None
self._async_proxied_router = None
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
@@ -65,6 +109,18 @@ class Client:
atexit.register(shutdown_serve_client)
def _get_proxied_router(self, sync: bool):
    """Return the cached ThreadProxiedRouter for the requested mode.

    One router is kept per mode (sync and async); each is created on first
    request and reused afterwards.
    """
    if not sync:
        if self._async_proxied_router is None:
            self._async_proxied_router = ThreadProxiedRouter(
                self._controller, sync=False)
        return self._async_proxied_router

    if self._sync_proxied_router is None:
        self._sync_proxied_router = ThreadProxiedRouter(
            self._controller, sync=True)
    return self._sync_proxied_router
def __del__(self):
if not self._detached:
logger.debug("Shutting down Ray Serve because client went out of "
@@ -181,8 +237,8 @@ class Client:
num_cpus=0, resources={
node_id: 0.01
}).remote(
"http://{}:{}/-/routes".format(self._http_host,
self._http_port),
"http://{}:{}/-/routes".format(self._http_config.host,
self._http_config.port),
check_ready=check_ready,
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
@@ -198,8 +254,6 @@ class Client:
Does not delete any associated backends.
"""
if endpoint in self._handle_cache:
del self._handle_cache[endpoint]
self._get_result(self._controller.delete_endpoint.remote(endpoint))
@_ensure_connected
@@ -410,10 +464,11 @@ class Client:
proportion))
@_ensure_connected
def get_handle(self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> RayServeHandle:
def get_handle(
self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
@@ -433,14 +488,26 @@ class Client:
if asyncio.get_event_loop().is_running() and sync:
logger.warning(
"You are retrieving a ServeHandle inside an asyncio loop. "
"You are retrieving a sync handle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance.")
"performance. Learn more at https://docs.ray.io/en/master/"
"serve/advanced.html#sync-and-async-handles")
if endpoint_name not in self._handle_cache:
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
self._handle_cache[endpoint_name] = handle
return self._handle_cache[endpoint_name]
if not asyncio.get_event_loop().is_running() and not sync:
logger.warning(
"You are retrieving an async handle outside an asyncio loop. "
"You should make sure client.get_handle is called inside a "
"running event loop. Or call client.get_handle(.., sync=True) "
"to create sync handle. Learn more at https://docs.ray.io/en/"
"master/serve/advanced.html#sync-and-async-handles")
if sync:
handle = RayServeSyncHandle(
self._get_proxied_router(sync=sync), endpoint_name)
else:
handle = RayServeHandle(
self._get_proxied_router(sync=sync), endpoint_name)
return handle
def start(detached: bool = False,
@@ -492,9 +559,7 @@ def start(detached: bool = False,
max_task_retries=-1,
).remote(
controller_name,
http_host,
http_port,
http_middlewares,
HTTPConfig(http_host, http_port, http_middlewares),
detached=detached)
if http_host is not None:
+10 -10
View File
@@ -186,10 +186,10 @@ class RayServeReplica:
"backend_replica_starts",
description=("The number of time this replica "
"has been restarted due to failure."),
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.restart_counter.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.queuing_latency_tracker = metrics.Histogram(
@@ -198,39 +198,39 @@ class RayServeReplica:
"The latency for queries waiting in the replica's queue "
"waiting to be processed or batched."),
boundaries=DEFAULT_LATENCY_BUCKET_MS,
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.queuing_latency_tracker.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.processing_latency_tracker = metrics.Histogram(
"backend_processing_latency_ms",
description="The latency for queries to be processed",
boundaries=DEFAULT_LATENCY_BUCKET_MS,
tag_keys=("backend", "replica_tag", "batch_size"))
tag_keys=("backend", "replica", "batch_size"))
self.processing_latency_tracker.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.num_queued_items = metrics.Gauge(
"replica_queued_queries",
description=("Current number of queries queued in the "
"the backend replicas"),
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.num_queued_items.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.num_processing_items = metrics.Gauge(
"replica_processing_queries",
description="Current number of queries being processed",
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.num_processing_items.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.restart_counter.record(1)
+33
View File
@@ -0,0 +1,33 @@
from ray.serve.utils import import_class
class ImportedBackend:
    """Factory for a class that will dynamically import a backend class.

    This is intended to be used when the source code for a backend is
    installed in the worker environment but not the driver.

    Intended usage:
    >>> client = serve.connect()
    >>> client.create_backend("b", ImportedBackend("module.Class"), *args)

    This will import module.Class on the worker and proxy all relevant methods
    to it.
    """

    def __new__(cls, class_path):
        # Calling ImportedBackend(path) does not build an instance of this
        # outer class; it returns a new proxy class bound to `class_path`.
        class ImportedBackend:
            def __init__(self, *args, **kwargs):
                # The import is deferred until the proxy is instantiated.
                self.wrapped = import_class(class_path)(*args, **kwargs)

            def reconfigure(self, *args, **kwargs):
                # NOTE(edoakes): we check that the reconfigure method is
                # present if the user specifies a user_config, so we need to
                # proxy it manually.
                return self.wrapped.reconfigure(*args, **kwargs)

            def __getattr__(self, attr):
                """Proxy all other attributes to the wrapped instance."""
                # __getattr__ only runs after normal lookup has failed, so
                # being asked for "wrapped" means it was never set (e.g. the
                # instance was created without __init__). Raise instead of
                # recursing forever through self.wrapped.
                if attr == "wrapped":
                    raise AttributeError(attr)
                return getattr(self.wrapped, attr)

        return ImportedBackend
+9 -2
View File
@@ -2,8 +2,8 @@ import inspect
from pydantic import BaseModel, PositiveInt, validator
from ray.serve.constants import ASYNC_CONCURRENCY
from typing import Optional, Dict, Any
from dataclasses import dataclass
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
def _callable_accepts_batch(func_or_class):
@@ -191,3 +191,10 @@ class ReplicaConfig:
raise TypeError(
"resources in ray_actor_options must be a dictionary.")
self.resource_dict.update(custom_resources)
@dataclass
class HTTPConfig:
    """HTTP settings handed to the controller: host, port and middlewares.

    All three fields are required and are accepted positionally in this
    order.
    """
    # None is a valid value: callers test `host is None` / `http_host is not
    # None` to decide whether HTTP proxies are used at all.
    host: Optional[str]
    port: int
    middlewares: List[Any]
+283 -224
View File
@@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes, get_all_node_ids)
from ray.serve.config import BackendConfig, ReplicaConfig
from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
from ray.serve.long_poll import LongPollHost
from ray.actor import ActorHandle
@@ -80,6 +80,77 @@ class TrafficPolicy:
return f"<Traffic {self.traffic_dict}; Shadow {self.shadow_dict}>"
class HTTPState:
    """Keep one HTTP proxy actor per cluster node.

    Proxies are looked up by name first and only created when no actor with
    that name exists; proxies whose node has disappeared are killed.
    """

    def __init__(self, controller_name: str, detached: bool,
                 config: HTTPConfig):
        self._controller_name = controller_name
        self._detached = detached
        self._config = config
        self._proxy_actors: Dict[NodeId, ActorHandle] = dict()

        # Will populate self._proxy_actors with existing actors.
        self._start_proxies_if_needed()

    def get_config(self) -> HTTPConfig:
        """Return the HTTPConfig this state was constructed with."""
        return self._config

    def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]:
        """Return the live node-id -> proxy-handle mapping (not a copy)."""
        return self._proxy_actors

    def update(self) -> None:
        """Reconcile proxies with the current node set: start, then stop."""
        self._start_proxies_if_needed()
        self._stop_proxies_if_needed()

    def _start_proxies_if_needed(self) -> None:
        """Start a proxy on every node if it doesn't already exist."""
        # A host of None means no proxies are started at all.
        if self._config.host is None:
            return

        for node_id, node_resource in get_all_node_ids():
            if node_id in self._proxy_actors:
                continue

            name = format_actor_name(SERVE_PROXY_NAME, self._controller_name,
                                     node_id)
            try:
                # Reuse an already-running proxy with this name if present.
                proxy = ray.get_actor(name)
            except ValueError:
                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
                            "listening on '{}:{}'".format(
                                name, node_id, self._config.host,
                                self._config.port))
                proxy = HTTPProxyActor.options(
                    name=name,
                    lifetime="detached" if self._detached else None,
                    max_concurrency=ASYNC_CONCURRENCY,
                    max_restarts=-1,
                    max_task_retries=-1,
                    # Requesting a sliver of the node's own resource pins
                    # the actor to that node.
                    resources={
                        node_resource: 0.01
                    },
                ).remote(
                    self._config.host,
                    self._config.port,
                    controller_name=self._controller_name,
                    http_middlewares=self._config.middlewares)

            self._proxy_actors[node_id] = proxy

    def _stop_proxies_if_needed(self) -> None:
        """Removes proxy actors from any nodes that no longer exist."""
        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
        # Collect first: the dict must not change size while iterating it.
        to_stop = []
        for node_id in self._proxy_actors:
            if node_id not in all_node_ids:
                logger.info("Removing HTTP proxy on removed node '{}'.".format(
                    node_id))
                to_stop.append(node_id)

        for node_id in to_stop:
            proxy = self._proxy_actors.pop(node_id)
            ray.kill(proxy, no_restart=True)
class BackendInfo(BaseModel):
# TODO(architkulkarni): Add type hint for worker_class after upgrading
# cloudpickle and adding types to RayServeWrappedReplica
@@ -93,17 +164,15 @@ class BackendInfo(BaseModel):
arbitrary_types_allowed = True
@dataclass
class SystemState:
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict)
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field(
default_factory=dict)
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
default_factory=dict)
class BackendState:
def __init__(self, checkpoint: bytes = None):
self.backends: Dict[BackendTag, BackendInfo] = dict()
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict)
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
if checkpoint is not None:
self.backends = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps(self.backends)
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
return {
@@ -119,7 +188,18 @@ class SystemState:
backend_info: BackendInfo,
goal_id: GoalId = 0) -> None:
self.backends[backend_tag] = backend_info
self.backend_goal_ids = goal_id
class EndpointState:
def __init__(self, checkpoint: bytes = None):
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
if checkpoint is not None:
self.routes, self.traffic_policies = pickle.loads(checkpoint)
def checkpoint(self):
return pickle.dumps((self.routes, self.traffic_policies))
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
endpoints = {}
@@ -146,7 +226,6 @@ class ActorStateReconciler:
controller_name: str = field(init=True)
detached: bool = field(init=True)
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
default_factory=lambda: defaultdict(dict))
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
@@ -154,13 +233,27 @@ class ActorStateReconciler:
backend_replicas_to_stop: Dict[BackendTag, List[ReplicaTag]] = field(
default_factory=lambda: defaultdict(list))
backends_to_remove: List[BackendTag] = field(default_factory=list)
endpoints_to_remove: List[EndpointTag] = field(default_factory=list)
# NOTE(ilr): These are not checkpointed, but will be recreated by
# `_enqueue_pending_scale_changes_loop`.
currently_starting_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag, ActorHandle]] = field(default_factory=dict)
currently_stopping_replicas: Dict[asyncio.Future, Tuple[
BackendTag, ReplicaTag]] = field(default_factory=dict)
def __getstate__(self):
state = self.__dict__.copy()
del state["currently_stopping_replicas"]
del state["currently_starting_replicas"]
return state
def __setstate__(self, state):
self.__dict__.update(state)
self.currently_stopping_replicas = {}
self.currently_starting_replicas = {}
# TODO(edoakes): consider removing this and just using the names.
def http_proxy_handles(self) -> List[ActorHandle]:
return list(self.http_proxy_cache.values())
def get_replica_handles(self) -> List[ActorHandle]:
return list(
chain.from_iterable([
@@ -175,43 +268,7 @@ class ActorStateReconciler:
for replica_dict in self.backend_replicas.values()
]))
async def _start_pending_backend_replicas(
self, current_state: SystemState) -> None:
"""Starts the pending backend replicas in self.backend_replicas_to_start.
Waits for replicas to start up, then removes them from
self.backend_replicas_to_start.
"""
fut_to_replica_info = {}
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
items():
for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica(
current_state, backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future()
fut_to_replica_info[ready_future] = (backend_tag, replica_tag,
replica_handle)
start = time.time()
prev_warning = start
while fut_to_replica_info:
if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
prev_warning = time.time()
logger.warning("Waited {:.2f}s for replicas to start up. Make "
"sure there are enough resources to create the "
"replicas.".format(time.time() - start))
done, pending = await asyncio.wait(
list(fut_to_replica_info.keys()), timeout=1)
for fut in done:
(backend_tag, replica_tag,
replica_handle) = fut_to_replica_info.pop(fut)
self.backend_replicas[backend_tag][
replica_tag] = replica_handle
self.backend_replicas_to_start.clear()
async def _start_backend_replica(self, current_state: SystemState,
async def _start_backend_replica(self, backend_state: BackendState,
backend_tag: BackendTag,
replica_tag: ReplicaTag) -> ActorHandle:
"""Start a replica and return its actor handle.
@@ -229,7 +286,7 @@ class ActorStateReconciler:
except ValueError:
logger.debug("Starting replica '{}' for backend '{}'.".format(
replica_tag, backend_tag))
backend_info = current_state.get_backend(backend_tag)
backend_info = backend_state.get_backend(backend_tag)
replica_handle = ray.remote(backend_info.worker_class).options(
name=replica_name,
@@ -255,6 +312,7 @@ class ActorStateReconciler:
intended replicas. This avoids inconsistencies with starting/stopping a
replica and then crashing before writing a checkpoint.
"""
logger.debug("Scaling backend '{}' to {} replicas".format(
backend_tag, num_replicas))
assert (backend_tag in backends
@@ -301,97 +359,104 @@ class ActorStateReconciler:
self.backend_replicas_to_stop[backend_tag].append(replica_tag)
async def _stop_pending_backend_replicas(self) -> None:
"""Stops the pending backend replicas in self.backend_replicas_to_stop.
async def _enqueue_pending_scale_changes_loop(self,
backend_state: BackendState):
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
items():
for replica_tag in replicas_to_create:
replica_handle = await self._start_backend_replica(
backend_state, backend_tag, replica_tag)
ready_future = replica_handle.ready.remote().as_future()
self.currently_starting_replicas[ready_future] = (
backend_tag, replica_tag, replica_handle)
Removes backend_replicas from the http_proxy, kills them, and clears
self.backend_replicas_to_stop.
"""
for backend_tag, replicas_list in self.backend_replicas_to_stop.items(
):
for replica_tag in replicas_list:
# NOTE(edoakes): the replicas may already be stopped if we
# failed after stopping them but before writing a checkpoint.
for backend_tag, replicas_to_stop in self.backend_replicas_to_stop.\
items():
for replica_tag in replicas_to_stop:
replica_name = format_actor_name(replica_tag,
self.controller_name)
try:
replica = ray.get_actor(replica_name)
except ValueError:
continue
# TODO(edoakes): this logic isn't ideal because there may be
# pending tasks still executing on the replica. However, if we
# use replica.__ray_terminate__, we may send it while the
# replica is being restarted and there's no way to tell if it
# successfully killed the worker or not.
ray.kill(replica, no_restart=True)
async def kill_actor(replica_name_to_use):
# NOTE: the replicas may already be stopped if we failed
# after stopping them but before writing a checkpoint.
try:
replica = ray.get_actor(replica_name_to_use)
except ValueError:
return
self.backend_replicas_to_stop.clear()
# TODO(edoakes): this logic isn't ideal because there may
# be pending tasks still executing on the replica. However,
# if we use replica.__ray_terminate__, we may send it while
# the replica is being restarted and there's no way to tell
# if it successfully killed the worker or not.
ray.kill(replica, no_restart=True)
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
http_middlewares: List[Any]) -> None:
"""Start an HTTP proxy on every node if it doesn't already exist."""
if http_host is None:
return
self.currently_stopping_replicas[asyncio.ensure_future(
kill_actor(replica_name))] = (backend_tag, replica_tag)
for node_id, node_resource in get_all_node_ids():
if node_id in self.http_proxy_cache:
continue
async def _check_currently_starting_replicas(self) -> bool:
"""Returns a boolean specifying if there are more replicas to start"""
in_flight = list()
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
node_id)
try:
proxy = ray.get_actor(name)
except ValueError:
logger.info("Starting HTTP proxy with name '{}' on node '{}' "
"listening on '{}:{}'".format(
name, node_id, http_host, http_port))
proxy = HTTPProxyActor.options(
name=name,
lifetime="detached" if self.detached else None,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
resources={
node_resource: 0.01
},
).remote(
http_host,
http_port,
controller_name=self.controller_name,
http_middlewares=http_middlewares)
if self.currently_starting_replicas:
done, in_flight = await asyncio.wait(
list(self.currently_starting_replicas.keys()), timeout=0)
for fut in done:
(backend_tag, replica_tag,
replica_handle) = self.currently_starting_replicas.pop(fut)
self.backend_replicas[backend_tag][
replica_tag] = replica_handle
self.http_proxy_cache[node_id] = proxy
backend = self.backend_replicas_to_start.get(backend_tag)
if backend:
try:
backend.remove(replica_tag)
except ValueError:
pass
if len(backend) == 0:
del self.backend_replicas_to_start[backend_tag]
return len(in_flight) > 0
def _stop_http_proxies_if_needed(self) -> bool:
"""Removes HTTP proxy actors from any nodes that no longer exist.
async def _check_currently_stopping_replicas(self) -> bool:
"""Returns a boolean specifying if there are more replicas to stop"""
in_flight = list()
if self.currently_stopping_replicas:
done_stoppping, in_flight = await asyncio.wait(
list(self.currently_stopping_replicas.keys()), timeout=0)
for fut in done_stoppping:
(backend_tag,
replica_tag) = self.currently_stopping_replicas.pop(fut)
Returns whether or not any actors were removed (a checkpoint should
be taken).
"""
actor_stopped = False
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
to_stop = []
for node_id in self.http_proxy_cache:
if node_id not in all_node_ids:
logger.info("Removing HTTP proxy on removed node '{}'.".format(
node_id))
to_stop.append(node_id)
backend = self.backend_replicas_to_stop.get(backend_tag)
for node_id in to_stop:
proxy = self.http_proxy_cache.pop(node_id)
ray.kill(proxy, no_restart=True)
actor_stopped = True
if backend:
try:
backend.remove(replica_tag)
except ValueError:
pass
if len(backend) == 0:
del self.backend_replicas_to_stop[backend_tag]
return actor_stopped
return len(in_flight) > 0
async def backend_control_loop(self):
    """Poll in-flight replica starts and stops until none remain.

    Each pass awaits ``_check_currently_starting_replicas`` and, only when
    that reports nothing in flight, ``_check_currently_stopping_replicas``
    (the ``or`` short-circuits). While work is still pending, the loop
    yields to the event loop for one second between passes.

    A warning is logged whenever more than
    ``REPLICA_STARTUP_TIME_WARNING_S`` seconds have elapsed since the
    previous warning (or since the loop began).
    """
    start = time.time()
    prev_warning = start
    need_to_continue = True
    while need_to_continue:
        if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
            prev_warning = time.time()
            logger.warning("Waited {:.2f}s for replicas to start up. Make "
                           "sure there are enough resources to create the "
                           "replicas.".format(time.time() - start))

        need_to_continue = (
            await self._check_currently_starting_replicas()
            or await self._check_currently_stopping_replicas())

        if need_to_continue:
            # The sleep must be awaited: a bare ``asyncio.sleep(1)`` only
            # creates a coroutine object that never runs, which turned this
            # loop into a busy spin. Skip the pause once everything is
            # done so callers are not delayed for no reason.
            await asyncio.sleep(1)
def _recover_actor_handles(self) -> None:
# Refresh the RouterCache
for node_id in self.http_proxy_cache.keys():
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
node_id)
self.http_proxy_cache[node_id] = ray.get_actor(name)
# Fetch actor handles for all of the backend replicas in the system.
# All of these backend_replicas are guaranteed to already exist because
# they would not be written to a checkpoint in self.backend_replicas
@@ -404,20 +469,20 @@ class ActorStateReconciler:
replica_tag] = ray.get_actor(replica_name)
async def _recover_from_checkpoint(
self, current_state: SystemState, controller: "ServeController"
self, backend_state: BackendState, controller: "ServeController"
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
self._recover_actor_handles()
autoscaling_policies = dict()
for backend, info in current_state.backends.items():
for backend, info in backend_state.backends.items():
metadata = info.backend_config.internal_metadata
if metadata.autoscaling_config is not None:
autoscaling_policies[backend] = BasicAutoscalingPolicy(
backend, metadata.autoscaling_config)
# Start/stop any pending backend replicas.
await self._start_pending_backend_replicas(current_state)
await self._stop_pending_backend_replicas()
await self._enqueue_pending_scale_changes_loop(backend_state)
await self.backend_control_loop()
return autoscaling_policies
@@ -430,8 +495,8 @@ class FutureResult:
@dataclass
class Checkpoint:
goal_state: SystemState
current_state: SystemState
endpoint_state_checkpoint: bytes
backend_state_checkpoint: bytes
reconciler: ActorStateReconciler
# TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult]
@@ -465,19 +530,10 @@ class ServeController:
async def __init__(self,
controller_name: str,
http_host: str,
http_port: str,
http_middlewares: List[Any],
http_config: HTTPConfig,
detached: bool = False):
# Used to read/write checkpoints.
self.kv_store = RayInternalKVStore(namespace=controller_name)
# Current State
self.current_state = SystemState()
# Goal State
# TODO(ilr) This is currently *unused* until the refactor of the serve
# controller.
self.goal_state = SystemState()
# ActorStateReconciler
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
# backend -> AutoscalingPolicy
@@ -490,24 +546,25 @@ class ServeController:
# at any given time.
self.write_lock = asyncio.Lock()
self.http_host = http_host
self.http_port = http_port
self.http_middlewares = http_middlewares
# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
# Map of awaiting results
# TODO(ilr): Checkpoint this once this becomes asynchronous
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
checkpoint = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint is None:
# HTTP state doesn't currently require a checkpoint.
self.http_state = HTTPState(controller_name, detached, http_config)
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState()
self.endpoint_state = EndpointState()
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)
await self._recover_from_checkpoint(checkpoint)
# NOTE(simon): Currently we do all-to-all broadcast. This means
@@ -566,17 +623,17 @@ class ServeController:
def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self.current_state.traffic_policies,
self.endpoint_state.traffic_policies,
)
def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.BACKEND_CONFIGS,
self.current_state.get_backend_configs())
self.backend_state.get_backend_configs())
def notify_route_table_changed(self):
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self.current_state.routes)
self.endpoint_state.routes)
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy long pull client's listen request.
@@ -589,9 +646,9 @@ class ServeController:
return await (
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
def get_http_proxies(self) -> Dict[str, ActorHandle]:
def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
"""Returns a dictionary of node ID to http_proxy actor handles."""
return self.actor_reconciler.http_proxy_cache
return self.http_state.get_http_proxy_handles()
def _checkpoint(self) -> None:
"""Checkpoint internal state and write it to the KV store."""
@@ -600,19 +657,19 @@ class ServeController:
start = time.time()
checkpoint = pickle.dumps(
Checkpoint(self.goal_state, self.current_state,
self.actor_reconciler,
Checkpoint(self.endpoint_state.checkpoint(),
self.backend_state.checkpoint(), self.actor_reconciler,
self._serializable_inflight_results))
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))
if random.random(
) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached:
logger.warning("Intentionally crashing after checkpoint")
os._exit(0)
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
"""Recover the instance state from the provided checkpoint.
This should be called in the constructor to ensure that the internal
@@ -627,12 +684,9 @@ class ServeController:
start = time.time()
logger.info("Recovering from checkpoint")
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.current_state = restored_checkpoint.current_state
self.actor_reconciler = checkpoint.reconciler
self.actor_reconciler = restored_checkpoint.reconciler
self._serializable_inflight_results = restored_checkpoint.inflight_reqs
self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items():
self._create_event_with_result(fut_result.requested_goal, uuid)
@@ -652,7 +706,7 @@ class ServeController:
async def finish_recover_from_checkpoint():
assert self.write_lock.locked()
self.autoscaling_policies = await self.actor_reconciler.\
_recover_from_checkpoint(self.current_state, self)
_recover_from_checkpoint(self.backend_state, self)
self.write_lock.release()
logger.info(
"Recovered from checkpoint in {:.3f}s".format(time.time() -
@@ -662,7 +716,7 @@ class ServeController:
asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
async def do_autoscale(self) -> None:
for backend, info in self.current_state.backends.items():
for backend, info in self.backend_state.backends.items():
if backend not in self.autoscaling_policies:
continue
@@ -672,16 +726,14 @@ class ServeController:
await self.update_backend_config(
backend, BackendConfig(num_replicas=new_num_replicas))
async def reconcile_current_and_goal_backends(self):
    """Placeholder coroutine: performs no work and resolves to ``None``."""
    return None
async def run_control_loop(self) -> None:
while True:
await self.do_autoscale()
async with self.write_lock:
self.actor_reconciler._start_http_proxies_if_needed(
self.http_host, self.http_port, self.http_middlewares)
checkpoint_required = self.actor_reconciler.\
_stop_http_proxies_if_needed()
if checkpoint_required:
self._checkpoint()
self.http_state.update()
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
@@ -692,15 +744,15 @@ class ServeController:
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
"""Returns a dictionary of backend tag to backend config."""
return self.current_state.get_backend_configs()
return self.backend_state.get_backend_configs()
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
"""Returns a dictionary of backend tag to backend config."""
return self.current_state.get_endpoints()
return self.endpoint_state.get_endpoints()
async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
if endpoint_name not in self.current_state.get_endpoints():
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))
@@ -708,13 +760,13 @@ class ServeController:
dict), "Traffic policy must be a dictionary."
for backend in traffic_dict:
if self.current_state.get_backend(backend) is None:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))
traffic_policy = TrafficPolicy(traffic_dict)
self.current_state.traffic_policies[endpoint_name] = traffic_policy
self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
@@ -737,20 +789,21 @@ class ServeController:
proportion: float) -> UUID:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.current_state.get_endpoints():
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint_name))
if self.current_state.get_backend(backend_tag) is None:
if self.backend_state.get_backend(backend_tag) is None:
raise ValueError(
"Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag))
self.current_state.traffic_policies[endpoint_name].set_shadow(
self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion)
traffic_policy = self.current_state.traffic_policies[endpoint_name]
traffic_policy = self.endpoint_state.traffic_policies[
endpoint_name]
return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
@@ -781,10 +834,10 @@ class ServeController:
# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.current_state.routes:
if route in self.endpoint_state.routes:
# Ensures this method is idempotent
if self.current_state.routes[route] == (endpoint, methods):
if self.endpoint_state.routes[route] == (endpoint, methods):
return
else:
@@ -792,7 +845,7 @@ class ServeController:
"{} Route '{}' is already registered.".format(
err_prefix, route))
if endpoint in self.current_state.get_endpoints():
if endpoint in self.endpoint_state.get_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
@@ -801,7 +854,7 @@ class ServeController:
"Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods))
self.current_state.routes[route] = (endpoint, methods)
self.endpoint_state.routes[route] = (endpoint, methods)
# NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict)
@@ -818,7 +871,7 @@ class ServeController:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint,
_) in self.current_state.routes.items():
_) in self.endpoint_state.routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
@@ -827,13 +880,11 @@ class ServeController:
return
# Remove the routing entry.
del self.current_state.routes[route_to_delete]
del self.endpoint_state.routes[route_to_delete]
# Remove the traffic policy entry if it exists.
if endpoint in self.current_state.traffic_policies:
del self.current_state.traffic_policies[endpoint]
self.actor_reconciler.endpoints_to_remove.append(endpoint)
if endpoint in self.endpoint_state.traffic_policies:
del self.endpoint_state.traffic_policies[endpoint]
return_uuid = self._create_event_with_result({
route_to_delete: None,
@@ -852,7 +903,7 @@ class ServeController:
"""Register a new backend under the specified tag."""
async with self.write_lock:
# Ensures this method is idempotent.
backend_info = self.current_state.get_backend(backend_tag)
backend_info = self.backend_state.get_backend(backend_tag)
if backend_info is not None:
if (backend_info.backend_config == backend_config
and backend_info.replica_config == replica_config):
@@ -867,7 +918,7 @@ class ServeController:
worker_class=backend_replica,
backend_config=backend_config,
replica_config=replica_config)
self.current_state.add_backend(backend_tag, backend_info)
self.backend_state.add_backend(backend_tag, backend_info)
metadata = backend_config.internal_metadata
if metadata.autoscaling_config is not None:
self.autoscaling_policies[
@@ -875,11 +926,12 @@ class ServeController:
backend_tag, metadata.autoscaling_config)
try:
# This call should be to run control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag,
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
except RayServeException as e:
del self.current_state.backends[backend_tag]
del self.backend_state.backends[backend_tag]
raise e
return_uuid = self._create_event_with_result({
@@ -889,8 +941,9 @@ class ServeController:
# or pushing the updated config to avoid inconsistent state if we
# crash while making the change.
self._checkpoint()
await self.actor_reconciler._start_pending_backend_replicas(
self.current_state)
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
@@ -903,11 +956,11 @@ class ServeController:
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified backend exists on the client.
if self.current_state.get_backend(backend_tag) is None:
if self.backend_state.get_backend(backend_tag) is None:
return
# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.current_state.\
for endpoint, traffic_policy in self.endpoint_state.\
traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
@@ -917,13 +970,15 @@ class ServeController:
"again.".format(backend_tag, endpoint))
# Scale its replicas down to 0. This will also remove the backend
# from self.current_state.backends and
# from self.backend_state.backends and
# self.actor_reconciler.backend_replicas.
# This should be a call to the control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag, 0)
self.backend_state.backends, backend_tag, 0)
# Remove the backend's metadata.
del self.current_state.backends[backend_tag]
del self.backend_state.backends[backend_tag]
if backend_tag in self.autoscaling_policies:
del self.autoscaling_policies[backend_tag]
@@ -935,7 +990,9 @@ class ServeController:
# backend from the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
await self.actor_reconciler._stop_pending_backend_replicas()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
return return_uuid
@@ -944,22 +1001,24 @@ class ServeController:
config_options: BackendConfig) -> UUID:
"""Set the config for the specified backend."""
async with self.write_lock:
assert (self.current_state.get_backend(backend_tag)
assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
assert isinstance(config_options, BackendConfig)
stored_backend_config = self.current_state.get_backend(
stored_backend_config = self.backend_state.get_backend(
backend_tag).backend_config
backend_config = stored_backend_config.copy(
update=config_options.dict(exclude_unset=True))
backend_config._validate_complete()
self.current_state.get_backend(
self.backend_state.get_backend(
backend_tag).backend_config = backend_config
backend_info = self.current_state.get_backend(backend_tag)
backend_info = self.backend_state.get_backend(backend_tag)
# Scale the replicas with the new configuration.
# This should be to run the control loop
self.actor_reconciler._scale_backend_replicas(
self.current_state.backends, backend_tag,
self.backend_state.backends, backend_tag,
backend_config.num_replicas)
return_uuid = self._create_event_with_result({
@@ -973,9 +1032,9 @@ class ServeController:
# Inform the routers about change in configuration
# (particularly for setting max_batch_size).
await self.actor_reconciler._start_pending_backend_replicas(
self.current_state)
await self.actor_reconciler._stop_pending_backend_replicas()
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
self.backend_state)
await self.actor_reconciler.backend_control_loop()
self.notify_replica_handles_changed()
self.notify_backend_configs_changed()
@@ -983,19 +1042,19 @@ class ServeController:
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
"""Get the current config for the specified backend."""
assert (self.current_state.get_backend(backend_tag)
assert (self.backend_state.get_backend(backend_tag)
), "Backend {} is not registered.".format(backend_tag)
return self.current_state.get_backend(backend_tag).backend_config
return self.backend_state.get_backend(backend_tag).backend_config
def get_http_config(self):
"""Return the HTTP proxy configuration."""
return self.http_host, self.http_port
return self.http_state.get_config()
async def shutdown(self) -> None:
"""Shuts down the serve instance completely."""
async with self.write_lock:
for http_proxy in self.actor_reconciler.http_proxy_handles():
ray.kill(http_proxy, no_restart=True)
for proxy in self.http_state.get_http_proxy_handles().values():
ray.kill(proxy, no_restart=True)
for replica in self.actor_reconciler.get_replica_handles():
ray.kill(replica, no_restart=True)
self.kv_store.delete(CHECKPOINT_KEY)
@@ -1,10 +1,8 @@
import requests
import ray
from ray import serve
from ray.serve import CondaEnv
import tensorflow as tf
ray.init()
client = serve.start()
@@ -0,0 +1,12 @@
import requests
from ray import serve
from ray.serve.backends import ImportedBackend
# Start Serve and keep the client used for the calls below.
client = serve.start()

# Describe the backend by its import path instead of passing a class object,
# then register it under the tag "imported" with one extra argument.
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
client.create_backend("imported", backend_class, "input_arg")

# Expose the backend over HTTP at /imported.
client.create_endpoint("imported", backend="imported", route="/imported")

# Query the new route and print the response body.
response = requests.get("http://127.0.0.1:8000/imported")
print(response.text)
@@ -10,7 +10,7 @@ class Counter:
def __init__(self):
self.count = 0
def __call__(self, flask_request):
def __call__(self, starlette_request):
self.count += 1
return {"current_counter": self.count}
@@ -6,8 +6,8 @@ ray.init(num_cpus=8)
client = serve.start()
def echo(flask_request):
return "hello " + flask_request.args.get("name", "serve!")
def echo(starlette_request):
return "hello " + starlette_request.query_params.get("name", "serve!")
client.create_backend("hello", echo)
@@ -16,13 +16,13 @@ client = serve.start()
def model_one(request):
print("Model 1 called with data ", request.args.get("data"))
print("Model 1 called with data ", request.query_params.get("data"))
return random()
def model_two(request):
print("Model 2 called with data ", request.args.get("data"))
return request.args.get("data")
print("Model 2 called with data ", request.query_params.get("data"))
return request.query_params.get("data")
class ComposedModel:
@@ -32,8 +32,8 @@ class ComposedModel:
self.model_two = client.get_handle("model_two")
# This method can be called concurrently!
async def __call__(self, flask_request):
data = flask_request.data
async def __call__(self, starlette_request):
data = await starlette_request.body()
score = await self.model_one.remote(data=data)
if score > 0.5:
@@ -14,8 +14,10 @@ import requests
# __doc_define_servable_v0_begin__
@serve.accept_batch
def batch_adder_v0(flask_requests: List):
numbers = [int(request.args["number"]) for request in flask_requests]
def batch_adder_v0(starlette_requests: List):
numbers = [
int(request.query_params["number"]) for request in starlette_requests
]
input_array = np.array(numbers)
print("Our input array has shape:", input_array.shape)
@@ -58,7 +60,7 @@ print("Result returned:", results)
# __doc_define_servable_v1_begin__
@serve.accept_batch
def batch_adder_v1(requests: List):
numbers = [int(request.args["number"]) for request in requests]
numbers = [int(request.query_params["number"]) for request in requests]
input_array = np.array(numbers)
print("Our input array has shape:", input_array.shape)
# Sleep for 200ms, this could be performing CPU intensive computation
@@ -48,9 +48,9 @@ class BoostingModel:
with open("/tmp/iris_labels.json") as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -143,9 +143,9 @@ class BoostingModelv2:
with open("/tmp/iris_labels_2.json") as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -27,8 +27,8 @@ class ImageModel:
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def __call__(self, flask_request):
image_payload_bytes = flask_request.data
async def __call__(self, starlette_request):
image_payload_bytes = await starlette_request.body()
pil_image = Image.open(BytesIO(image_payload_bytes))
print("[1/3] Parsed image data: {}".format(pil_image))
@@ -54,9 +54,9 @@ class BoostingModel:
with open(LABEL_PATH) as f:
self.label_list = json.load(f)
def __call__(self, flask_request):
payload = flask_request.json
print("Worker: received flask request with data", payload)
async def __call__(self, starlette_request):
payload = await starlette_request.json()
print("Worker: received starlette request with data", payload)
input_vector = [
payload["sepal length"],
@@ -51,10 +51,10 @@ class TFMnistModel:
self.model_path = model_path
self.model = tf.keras.models.load_model(model_path)
def __call__(self, flask_request):
async def __call__(self, starlette_request):
# Step 1: transform HTTP request -> tensorflow input
# Here we define the request schema to be a json array.
input_array = np.array(flask_request.json["array"])
input_array = np.array((await starlette_request.json())["array"])
reshaped_array = input_array.reshape((1, 28, 28))
# Step 2: tensorflow input -> tensorflow output
+2 -2
View File
@@ -9,8 +9,8 @@ import requests
from ray import serve
def echo(flask_request):
return ["hello " + flask_request.args.get("name", "serve!")]
def echo(starlette_request):
return ["hello " + starlette_request.query_params.get("name", "serve!")]
client = serve.start()
+4 -3
View File
@@ -1,6 +1,6 @@
"""
Example actor that adds an increment to a number. This number can
come from either web (parsing Flask request) or python call.
come from either web (parsing Starlette request) or python call.
This actor can be called from HTTP as well as from Python.
"""
@@ -30,9 +30,10 @@ class MagicCounter:
def __init__(self, increment):
self.increment = increment
def __call__(self, flask_request, base_number=None):
def __call__(self, starlette_request, base_number=None):
if serve.context.web:
base_number = int(flask_request.args.get("base_number", "0"))
base_number = int(
starlette_request.query_params.get("base_number", "0"))
return base_number + self.increment
@@ -1,6 +1,6 @@
"""
Example actor that adds an increment to a number. This number can
come from either web (parsing Flask request) or python call.
come from either web (parsing Starlette request) or python call.
The queries incoming to this actor are batched.
This actor can be called from HTTP as well as from Python.
"""
@@ -31,12 +31,13 @@ class MagicCounter:
self.increment = increment
@serve.accept_batch
def __call__(self, flask_request_list, base_number=None):
def __call__(self, starlette_request_list, base_number=None):
# batch_size = serve.context.batch_size
if serve.context.web:
result = []
for flask_request in flask_request_list:
base_number = int(flask_request.args.get("base_number", "0"))
for starlette_request in starlette_request_list:
base_number = int(
starlette_request.query_params.get("base_number", "0"))
result.append(base_number)
return list(map(lambda x: x + self.increment, result))
else:
+1 -1
View File
@@ -11,7 +11,7 @@ class MagicCounter:
self.increment = increment
@serve.accept_batch
def __call__(self, flask_request, base_number=None):
def __call__(self, starlette_request, base_number=None):
# __call__ fn should preserve the batch size
# base_number is a python list
+3 -3
View File
@@ -12,8 +12,8 @@ client = serve.start()
# a backend can be a function or class.
# it can be made to be invoked from web as well as python.
def echo_v1(flask_request):
response = flask_request.args.get("response", "web")
def echo_v1(starlette_request):
response = starlette_request.query_params.get("response", "web")
return response
@@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
# We can also add a new backend and split the traffic.
def echo_v2(flask_request):
def echo_v2(starlette_request):
# magic, only from web.
return "something new"
+80 -115
View File
@@ -1,27 +1,23 @@
import asyncio
import concurrent.futures
import threading
from typing import Any, Coroutine, Dict, Optional, Union
import ray
from ray.serve.context import TaskContext
from ray.serve.router import RequestMetadata, Router
from ray.serve.utils import get_random_letters
from ray.serve.exceptions import RayServeException
global_async_loop = None
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union
from enum import Enum
def create_or_get_async_loop_in_thread():
    """Return the module-wide event loop, creating it on first use.

    On the first call a new event loop is created, stored in the
    ``global_async_loop`` module global, and driven by ``run_forever`` on a
    daemon thread. Later calls return that same loop.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    runner = threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    )
    runner.start()
    return global_async_loop
@dataclass(frozen=True)
class HandleOptions:
    """Options for each ServeHandle instance.

    The dataclass is frozen, so these fields are immutable once an instance
    has been constructed.
    """
    # Name of the method to invoke on the backend; defaults to "__call__".
    method_name: str = "__call__"
    # Optional shard key; None means no key is set.
    # NOTE(review): how the key is consumed is not visible here -- confirm.
    shard_key: Optional[str] = None
    # HTTP method associated with calls made through the handle.
    # NOTE(review): presumably surfaced to the backend as request metadata.
    http_method: str = "GET"
    # HTTP headers associated with calls; default_factory gives every
    # instance its own empty dict rather than a shared mutable default.
    http_headers: Dict[str, str] = field(default_factory=dict)
# Use a global singleton enum to emulate default options. We cannot use None
# for those options because None is a valid new value.
class DEFAULT(Enum):
    """Sentinel meaning "argument not supplied" for handle option setters."""
    VALUE = 1
class RayServeHandle:
@@ -31,75 +27,83 @@ class RayServeHandle:
an HTTP endpoint.
Example:
>>> handle = serve.get_handle("my_endpoint")
>>> handle = serve_client.get_handle("my_endpoint")
>>> handle
RayServeHandle(
Endpoint="my_endpoint",
Traffic=...
)
>>> handle.remote(my_request_content)
RayServeHandle(endpoint="my_endpoint")
>>> await handle.remote(my_request_content)
ObjectRef(...)
>>> ray.get(handle.remote(...))
>>> ray.get(await handle.remote(...))
# result
>>> ray.get(handle.remote(let_it_crash_request))
>>> ray.get(await handle.remote(let_it_crash_request))
# raises RayTaskError Exception
"""
def __init__(
self,
controller_handle,
endpoint_name,
sync: bool,
*,
method_name=None,
shard_key=None,
http_method=None,
http_headers=None,
):
self.controller_handle = controller_handle
def __init__(self,
router,
endpoint_name,
handle_options: Optional[HandleOptions] = None):
self.router = router
self.endpoint_name = endpoint_name
self.handle_options = handle_options or HandleOptions()
self.method_name = method_name
self.shard_key = shard_key
self.http_method = http_method
self.http_headers = http_headers
def options(self,
*,
method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
shard_key: Union[str, DEFAULT] = DEFAULT.VALUE,
http_method: Union[str, DEFAULT] = DEFAULT.VALUE,
http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE):
"""Set options for this handle.
self.router = Router(self.controller_handle)
self.sync = sync
# In the synchrounous mode, we create a new event loop in a separate
# thread and run the Router.setup in that loop. In the async mode, we
# can just use the current loop we are in right now.
if self.sync:
self.async_loop = create_or_get_async_loop_in_thread()
asyncio.run_coroutine_threadsafe(
self.router.setup_in_async_loop(),
self.async_loop,
)
else: # async
self.async_loop = asyncio.get_event_loop()
# create_task is not threadsafe.
self.async_loop.create_task(self.router.setup_in_async_loop())
Args:
method_name(str): The method to invoke on the backend.
http_method(str): The HTTP method to use for the request.
shard_key(str): A string to use to deterministically map this
request to a backend if there are multiple for this endpoint.
"""
new_options_dict = self.handle_options.__dict__.copy()
user_modified_options_dict = {
key: value
for key, value in
zip(["method_name", "shard_key", "http_method", "http_headers"],
[method_name, shard_key, http_method, http_headers])
if value != DEFAULT.VALUE
}
new_options_dict.update(user_modified_options_dict)
new_options = HandleOptions(**new_options_dict)
def _remote(self, request_data, kwargs) -> Coroutine:
request_metadata = RequestMetadata(
get_random_letters(10), # Used for debugging.
self.endpoint_name,
TaskContext.Python,
call_method=self.method_name or "__call__",
shard_key=self.shard_key,
http_method=self.http_method or "GET",
http_headers=self.http_headers or dict(),
)
coro = self.router.assign_request(request_metadata, request_data,
**kwargs)
return coro
return self.__class__(self.router, self.endpoint_name, new_options)
async def remote(self,
request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get (or ``await object_ref``), respectively.
Returns:
ray.ObjectRef
Args:
request_data(dict, Any): If it's a dictionary, the data will be
available in ``request.json()`` or ``request.form()``.
Otherwise, it will be available in ``request.body()``.
``**kwargs``: All keyword arguments will be available in
``request.query_params``.
"""
return await self.router._remote(
self.endpoint_name, self.handle_options, request_data, kwargs)
def __repr__(self):
return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')"
class RayServeSyncHandle(RayServeHandle):
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
**kwargs):
"""Issue an asynchrounous request to the endpoint.
Returns a Ray ObjectRef whose results can be waited for or retrieved
using ray.wait or ray.get, respectively.
using ray.wait or ray.get (or ``await object_ref``), respectively.
Returns:
ray.ObjectRef
@@ -110,47 +114,8 @@ class RayServeHandle:
``**kwargs``: All keyword arguments will be available in
``request.args``.
"""
if not self.sync:
raise RayServeException(
"You are trying to call handle.remote() with async handle. "
"Please use `await handle.remote_async()` instead.")
coro = self._remote(request_data, kwargs)
coro = self.router._remote(self.endpoint_name, self.handle_options,
request_data, kwargs)
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
coro, self.async_loop)
# Block until the result is ready.
coro, self.router.async_loop)
return future.result()
async def remote_async(self,
request_data: Optional[Union[Dict, Any]] = None,
**kwargs) -> ray.ObjectRef:
"""Experimental API for enqueue a request in async context."""
if not asyncio.get_event_loop().is_running():
raise RayServeException(
"remote_async must be called from a running event loop.")
return await self._remote(request_data, kwargs)
def options(self,
method_name: Optional[str] = None,
*,
shard_key: Optional[str] = None,
http_method: Optional[str] = None,
http_headers: Optional[Dict[str, str]] = None):
"""Set options for this handle.
Args:
method_name(str): The method to invoke on the backend.
http_method(str): The HTTP method to use for the request.
shard_key(str): A string to use to deterministically map this
request to a backend if there are multiple for this endpoint.
"""
# Don't override default non-null values.
self.method_name = self.method_name or method_name
self.shard_key = self.shard_key or shard_key
self.http_method = self.http_method or http_method
self.http_headers = self.http_headers or http_headers
return self
def __repr__(self):
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
+14 -65
View File
@@ -1,76 +1,25 @@
import io
import json
import flask
import starlette.requests
def build_flask_request(asgi_scope_dict, request_body):
"""Build and return a flask request from ASGI payload
def build_starlette_request(scope, serialized_body: bytes):
"""Build and return a Starlette Request from ASGI payload.
This function is indented to be used immediately before task invocation
happen.
This function is intended to be used immediately before task invocation
happens.
"""
wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body)
# We set populate_request=False to prevent self reference, which can lead
# to objects tracked by python garbage collector and memory growth. See
# https://github.com/ray-project/ray/issues/12395.
return flask.Request(wsgi_environ, populate_request=False)
# Simulates receiving HTTP body from TCP socket. In reality, the body has
# already been streamed in chunks and stored in serialized_body.
async def mock_receive():
return {
"body": serialized_body,
"type": "http.request",
"more_body": False
}
def build_wsgi_environ(scope, body):
"""
Builds a scope and request body into a WSGI environ object.
This code snippet is taken from https://github.com/django/asgiref/blob
/36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52
WSGI specification can be found at
https://www.python.org/dev/peps/pep-0333/
This function helps translate ASGI scope and body into a flask request.
"""
environ = {
"REQUEST_METHOD": scope["method"],
"SCRIPT_NAME": scope.get("root_path", ""),
"PATH_INFO": scope["path"],
"QUERY_STRING": scope["query_string"].decode("ascii"),
"SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]),
"wsgi.version": (1, 0),
"wsgi.url_scheme": scope.get("scheme", "http"),
"wsgi.input": body,
"wsgi.errors": io.BytesIO(),
"wsgi.multithread": True,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
# Get server name and port - required in WSGI, not in ASGI
environ["SERVER_NAME"] = scope["server"][0]
environ["SERVER_PORT"] = str(scope["server"][1])
environ["REMOTE_ADDR"] = scope["client"][0]
# Transforms headers into environ entries.
for name, value in scope.get("headers", []):
# name, values are both bytes, we need to decode them to string
name = name.decode("latin1")
value = value.decode("latin1")
# Handle name correction to conform to WSGI spec
# https://www.python.org/dev/peps/pep-0333/#environ-variables
if name == "content-length":
corrected_name = "CONTENT_LENGTH"
elif name == "content-type":
corrected_name = "CONTENT_TYPE"
else:
corrected_name = "HTTP_%s" % name.upper().replace("-", "_")
# If the header value repeated,
# we will just concatenate it to the field.
if corrected_name in environ:
value = environ[corrected_name] + "," + value
environ[corrected_name] = value
return environ
return starlette.requests.Request(scope, mock_receive)
class Response:
+26 -6
View File
@@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
def test_e2e(serve_instance):
client = serve_instance
def function(flask_request):
return {"method": flask_request.method}
def function(starlette_request):
return {"method": starlette_request.method}
client.create_backend("echo:v1", function)
client.create_endpoint(
@@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance):
def __init__(self):
self.count = 10
def __call__(self, flask_request):
def __call__(self, starlette_request):
return self.count, os.getpid()
def reconfigure(self, config):
@@ -146,7 +146,7 @@ def test_call_method(serve_instance):
# Test serve handle path.
handle = client.get_handle("endpoint")
assert ray.get(handle.options("method").remote()) == "hello"
assert ray.get(handle.options(method_name="method").remote()) == "hello"
def test_no_route(serve_instance):
@@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance):
client = serve_instance
@serve.accept_batch
def batcher(flask_requests):
return ["hello"] * len(flask_requests)
def batcher(starlette_requests):
return ["hello"] * len(starlette_requests)
client.create_backend("metrics", batcher)
client.create_endpoint("metrics", backend="metrics", route="/metrics")
@@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance):
verify_metrics()
def test_starlette_request(serve_instance):
    """POST a body larger than one message and expect it echoed back intact."""
    client = serve_instance

    # max bytes in one message
    UVICORN_HIGH_WATER_MARK = 65536

    async def echo_body(starlette_request):
        # The backend simply returns the raw request body.
        return await starlette_request.body()

    client.create_backend("echo:v1", echo_body)
    client.create_endpoint(
        "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])

    # Long string to test serialization of multiple messages.
    payload = "x" * (10 * UVICORN_HIGH_WATER_MARK)
    response = requests.post("http://127.0.0.1:8000/api", data=payload)
    assert response.text == payload
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
@@ -85,7 +85,7 @@ async def test_runner_wraps_error():
async def test_servable_function(serve_instance, router,
mock_controller_with_name):
def echo(request):
return request.args["i"]
return request.query_params["i"]
await add_servable_to_router(echo, router, mock_controller_with_name[0])
@@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router,
self.increment = inc
def __call__(self, request):
return request.args["i"] + self.increment
return request.query_params["i"] + self.increment
await add_servable_to_router(
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
@@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router,
def __init__(self):
self.reval = ""
def __call__(self, flask_request):
def __call__(self, starlette_request):
return self.retval
def reconfigure(self, config):
+38 -8
View File
@@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance):
client = serve_instance
class Endpoint1:
def __call__(self, flask_request):
def __call__(self, starlette_request):
return "hello"
class Endpoint2:
@@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance):
client = serve_instance
class Endpoint:
def __call__(self, request):
async def __call__(self, request):
return {
"args": dict(request.args),
"args": dict(request.query_params),
"headers": dict(request.headers),
"method": request.method,
"json": request.json
"json": await request.json()
}
client.create_backend("backend", Endpoint)
@@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance):
"arg2": "2"
},
"headers": {
"X-Custom-Header": "value"
"x-custom-header": "value"
},
"method": "POST",
"json": {
@@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance):
for resp in [resp_web, resp_handle]:
for field in ["args", "method", "json"]:
assert resp[field] == ground_truth[field]
resp["headers"]["X-Custom-Header"] == "value"
resp["headers"]["x-custom-header"] == "value"
def test_handle_inject_flask_request(serve_instance):
def test_handle_inject_starlette_request(serve_instance):
client = serve_instance
def echo_request_type(request):
@@ -103,7 +103,37 @@ def test_handle_inject_flask_request(serve_instance):
for route in ["/echo", "/wrapper"]:
resp = requests.get(f"http://127.0.0.1:8000{route}")
request_type = resp.text
assert request_type == "<class 'flask.wrappers.Request'>"
assert request_type == "<class 'starlette.requests.Request'>"
def test_handle_option_chaining(serve_instance):
    """Check that options() yields independent handles that can be re-chained.

    Regression test for:
    https://github.com/ray-project/ray/issues/12802
    https://github.com/ray-project/ray/issues/12798
    """
    client = serve_instance

    class MultiMethod:
        def __call__(self, _):
            return "__call__"

        def method_a(self, _):
            return "method_a"

        def method_b(self, _):
            return "method_b"

    client.create_backend("m", MultiMethod)
    client.create_endpoint("m", backend="m")

    # get_handle should give you a clean handle
    handle_a = client.get_handle("m").options(method_name="method_a")
    handle_default = client.get_handle("m")
    # options().options() override should work
    handle_b = handle_a.options(method_name="method_b")

    expectations = [
        (handle_a, "method_a"),
        (handle_default, "__call__"),
        (handle_b, "method_b"),
    ]
    for handle, expected in expectations:
        assert ray.get(handle.remote()) == expected
if __name__ == "__main__":
@@ -0,0 +1,29 @@
import ray
from ray.serve.backends import ImportedBackend
from ray.serve.config import BackendConfig
def test_imported_backend(serve_instance):
    """Exercise a backend that is referenced by import path.

    Covers the default call, a user-config update, and invoking a
    non-default method through handle options.
    """
    client = serve_instance
    name = "imported"

    client.create_backend(
        name,
        ImportedBackend("ray.serve.utils.MockImportedBackend"),
        "input_arg",
        config=BackendConfig(user_config="config"))
    client.create_endpoint(name, backend=name)

    # Basic sanity check.
    handle = client.get_handle(name)
    initial_result = ray.get(handle.remote())
    assert initial_result == {"arg": "input_arg", "config": "config"}

    # Check that updating backend config works.
    client.update_backend_config(
        name, BackendConfig(user_config="new_config"))
    updated_result = ray.get(handle.remote())
    assert updated_result == {"arg": "input_arg", "config": "new_config"}

    # Check that other call methods work.
    other_handle = handle.options(method_name="other_method")
    assert ray.get(other_handle.remote("hello")) == "hello"
+1 -1
View File
@@ -12,7 +12,7 @@ ray.init(address="{}")
from ray import serve
client = serve.connect()
def driver(flask_request):
def driver(starlette_request):
return "OK!"
client.create_backend("driver", driver)
+2 -2
View File
@@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance):
# in cloudpickle _from_numpy_buffer
def sum_model(request):
return np.sum(request.args["data"])
return np.sum(request.query_params["data"])
class ComposedModel:
def __init__(self):
@@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance):
# https://github.com/ray-project/ray/issues/12395
client = serve_instance
def gc_unreachable_objects(flask_request):
def gc_unreachable_objects(starlette_request):
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
return len(gc.garbage)
+17 -1
View File
@@ -6,9 +6,10 @@ from copy import deepcopy
import numpy as np
import pytest
import ray
from ray.serve.utils import (ServeEncoder, chain_future, unpack_future,
try_schedule_resources_on_nodes,
get_conda_env_dir)
get_conda_env_dir, import_class)
def test_bytes_encoder():
@@ -125,6 +126,21 @@ def test_get_conda_env_dir(tmp_path):
os.environ["CONDA_PREFIX"] = ""
def test_import_class():
    """import_class must resolve dotted paths to the very same class objects."""
    # Both the re-exported path and the defining-module path resolve to the
    # same Client class.
    for dotted_path in ("ray.serve.Client", "ray.serve.api.Client"):
        assert import_class(dotted_path) == ray.serve.api.Client

    traffic_policy_cls = import_class("ray.serve.controller.TrafficPolicy")
    assert traffic_policy_cls == ray.serve.controller.TrafficPolicy

    # The imported class must be usable like the directly imported one.
    policy = traffic_policy_cls({"endpoint1": 0.5, "endpoint2": 0.5})
    invalid_weights = {"endpoint1": 0.5, "endpoint2": 0.6}
    with pytest.raises(ValueError):
        policy.set_traffic_dict(invalid_weights)

    valid_weights = {"endpoint1": 0.4, "endpoint2": 0.6}
    policy.set_traffic_dict(valid_weights)

    print(repr(policy))
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", "-s", __file__]))
+57 -17
View File
@@ -1,5 +1,6 @@
import asyncio
from functools import singledispatch
import importlib
from itertools import groupby
import json
import logging
@@ -7,7 +8,6 @@ import random
import string
import time
from typing import List, Dict
import io
import os
from ray.serve.exceptions import RayServeException
from collections import UserDict
@@ -15,18 +15,18 @@ from collections import UserDict
import requests
import numpy as np
import pydantic
import flask
import starlette.requests
import ray
from ray.serve.constants import HTTP_PROXY_TIMEOUT
from ray.serve.context import TaskContext
from ray.serve.http_util import build_flask_request
from ray.serve.http_util import build_starlette_request
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
class ServeMultiDict(UserDict):
"""Compatible data structure to simulate Flask.Request.args API."""
"""Compatible data structure to simulate Starlette Request query_args."""
def getlist(self, key):
"""Return the list of items for a given key."""
@@ -34,11 +34,14 @@ class ServeMultiDict(UserDict):
class ServeRequest:
"""The request object used in Python context.
"""The request object used when passing arguments via ServeHandle.
ServeRequest is built to have similar API as Flask.Request. You only need
to write your model serving code once; it can be queried by both HTTP and
Python.
ServeRequest partially implements the API of Starlette Request. You only
need to write your model serving code once; it can be queried by both HTTP
and Python.
To use the full Starlette Request interface with ServeHandle, you may
instead directly pass in a Starlette Request object to the ServeHandle.
"""
def __init__(self, data, kwargs, headers, method):
@@ -58,28 +61,25 @@ class ServeRequest:
return self._method
@property
def args(self):
def query_params(self):
"""The keyword arguments from ``handle.remote(**kwargs)``."""
return self._kwargs
@property
def json(self):
async def json(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def form(self):
async def form(self):
"""The request dictionary, from ``handle.remote(dict)``."""
if not isinstance(self._data, dict):
raise RayServeException("Request data is not a dictionary. "
f"It is {type(self._data)}.")
return self._data
@property
def data(self):
async def body(self):
"""The request data from ``handle.remote(obj)``."""
return self._data
@@ -87,13 +87,13 @@ class ServeRequest:
def parse_request_item(request_item):
if request_item.metadata.request_context == TaskContext.Web:
asgi_scope, body_bytes = request_item.args
return build_flask_request(asgi_scope, io.BytesIO(body_bytes))
return build_starlette_request(asgi_scope, body_bytes)
else:
arg = request_item.args[0] if len(request_item.args) == 1 else None
# If the input data from handle is web request, we don't need to wrap
# it in ServeRequest.
if isinstance(arg, flask.Request):
if isinstance(arg, starlette.requests.Request):
return arg
return ServeRequest(
@@ -342,3 +342,43 @@ def get_node_id_for_actor(actor_handle):
"""Given an actor handle, return the node id it's placed on."""
return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]
def import_class(full_path: str):
    """Given a full import path to a class name, return the imported class.

    For example, the following are equivalent:
        MyClass = import_class("module.submodule.MyClass")
        from module.submodule import MyClass

    Args:
        full_path: Dotted path of the form ``"module.submodule.ClassName"``.

    Returns:
        Imported class

    Raises:
        ValueError: If ``full_path`` lacks a module part or a class part
            (for example ``"MyClass"``, ``".MyClass"`` or ``"module."``).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute named after the class.
    """
    # Split on the LAST period: everything before it is the module, the rest
    # is the attribute to fetch. Without this validation a path containing no
    # period would silently drop its final character and try to import the
    # truncated name as a module.
    module_name, _, class_name = full_path.rpartition(".")
    if not module_name or not class_name:
        raise ValueError(
            "Expected an import path of the form 'module.ClassName', "
            f"got {full_path!r}.")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
class MockImportedBackend:
    """Used for testing backends.ImportedBackend.

    This is necessary because we need the class to be installed in the worker
    processes. We could instead mock out importlib but doing so is messier and
    reduces confidence in the test (it isn't truly end-to-end).
    """

    def __init__(self, arg):
        # Constructor argument, echoed back by __call__.
        self.arg = arg
        # Stays None until reconfigure() is called.
        self.config = None

    def reconfigure(self, config):
        """Store the latest config so __call__ can report it."""
        self.config = config

    def __call__(self, *args):
        """Return the constructor argument and current config; ignores args."""
        return {"arg": self.arg, "config": self.config}

    async def other_method(self, request):
        """Echo the request payload back to the caller.

        ServeRequest no longer exposes a ``data`` property; the payload is
        obtained by awaiting ``request.body()``, so this method has to be a
        coroutine.
        """
        return await request.body()
+4
View File
@@ -9,6 +9,7 @@ import ray
from ray import gcs_utils
from google.protobuf.json_format import MessageToDict
from ray._private import services
from ray._private.client_mode_hook import client_mode_hook
from ray.utils import (decode, binary_to_hex, hex_to_binary)
from ray._raylet import GlobalStateAccessor
@@ -851,6 +852,7 @@ def jobs():
return state.job_table()
@client_mode_hook
def nodes():
"""Get a list of the nodes in the cluster (for debugging only).
@@ -964,6 +966,7 @@ def object_transfer_timeline(filename=None):
return state.chrome_tracing_object_transfer_dump(filename=filename)
@client_mode_hook
def cluster_resources():
"""Get the current total cluster resources.
@@ -977,6 +980,7 @@ def cluster_resources():
return state.cluster_resources()
@client_mode_hook
def available_resources():
"""Get the current available cluster resources.
+6 -2
View File
@@ -414,7 +414,7 @@ def init_error_pubsub():
return p
def get_error_message(pub_sub, num, error_type=None, timeout=5):
def get_error_message(pub_sub, num, error_type=None, timeout=20):
"""Get errors through pub/sub."""
start_time = time.time()
msgs = []
@@ -442,4 +442,8 @@ def format_web_url(url):
def new_scheduler_enabled():
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER") == "1"
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1"
def client_test_enabled() -> bool:
    """Check whether the RAY_CLIENT_MODE environment variable equals "1".

    Returns:
        True only when the variable is set to exactly ``"1"``; False when it
        is unset or holds any other value.
    """
    mode_flag = os.environ.get("RAY_CLIENT_MODE")
    return mode_flag == "1"
+21 -12
View File
@@ -10,6 +10,7 @@ SRCS = [] + select({
py_test_module_list(
files = [
# "test_dynres.py", # dyn res not implemented
"test_async.py",
"test_actor.py",
"test_actor_advanced.py",
@@ -40,16 +41,6 @@ py_test_module_list(
deps = ["//:ray_lib"],
)
py_test_module_list(
files = [
"test_dynres.py", # dyn res not implemented
],
size = "medium",
extra_srcs = SRCS,
tags = ["exclusive", "medium_size_python_tests_a_to_j", "new_scheduler_broken"],
deps = ["//:ray_lib"],
)
py_test_module_list(
files = [
"test_memory_limits.py",
@@ -96,6 +87,7 @@ py_test_module_list(
"test_debug_tools.py",
"test_experimental_client.py",
"test_experimental_client_metadata.py",
"test_experimental_client_references.py",
"test_experimental_client_terminate.py",
"test_job.py",
"test_memstat.py",
@@ -129,11 +121,10 @@ py_test_module_list(
py_test_module_list(
files = [
"test_placement_group.py", # placement groups not implemented
"test_placement_group.py",
],
size = "large",
extra_srcs = SRCS,
tags = ["exclusive", "new_scheduler_broken"],
deps = ["//:ray_lib"],
)
@@ -162,3 +153,21 @@ py_test(
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test_module_list(
files = [
"test_actor.py",
"test_advanced.py",
"test_basic.py",
"test_basic_2.py",
],
size = "medium",
extra_srcs = SRCS,
name_suffix = "_client_mode",
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
# Until then, we can use tags.
#env = {"RAY_CLIENT_MODE": "1"},
tags = ["exclusive", "client_tests"],
deps = ["//:ray_lib"],
)
+20
View File
@@ -0,0 +1,20 @@
import asyncio
def create_remote_signal_actor(ray):
    """Define a SignalActor class decorated with the given ``ray.remote``.

    Args:
        ray: Object exposing a ``remote`` decorator, which is applied to the
            actor class.

    Returns:
        The decorated ``SignalActor`` class (not an instance).
    """

    # TODO(barakmich): num_cpus=0
    @ray.remote
    class SignalActor:
        """Thin wrapper around an ``asyncio.Event``."""

        def __init__(self):
            self.ready_event = asyncio.Event()

        def send(self, clear=False):
            """Set the event; when ``clear`` is true, reset it right after."""
            event = self.ready_event
            event.set()
            if clear:
                event.clear()

        async def wait(self, should_wait=True):
            """Wait for the event unless ``should_wait`` is false."""
            if not should_wait:
                return
            await self.ready_event.wait()

    return SignalActor
+2 -1
View File
@@ -23,7 +23,7 @@ def get_default_fixure_system_config():
"object_timeout_milliseconds": 200,
"num_heartbeats_timeout": 10,
"object_store_full_max_retries": 3,
"object_store_full_initial_delay_ms": 100,
"object_store_full_delay_ms": 100,
}
return system_config
@@ -44,6 +44,7 @@ def _ray_start(**kwargs):
init_kwargs.update(kwargs)
# Start the Ray processes.
address_info = ray.init(**init_kwargs)
yield address_info
# The code after the yield will run as teardown code.
ray.shutdown()
Binary file not shown.
+16 -10
View File
@@ -12,9 +12,10 @@ import tempfile
import datetime
import setproctitle
import ray
import ray.test_utils
import ray.cluster_utils
from ray.test_utils import client_test_enabled
from ray.test_utils import wait_for_condition
from ray.test_utils import wait_for_pid_to_exit
from ray.tests.client_test_utils import create_remote_signal_actor
def test_caching_actors(shutdown_only):
@@ -235,6 +236,7 @@ def test_actor_import_counter(ray_start_10_cpus):
assert ray.get(g.remote()) == num_remote_functions - 1
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_method_metadata_cache(ray_start_regular):
class Actor(object):
pass
@@ -254,6 +256,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
assert [id(x) for x in list(cache.items())[0]] == cached_data_id
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_actor_class_name(ray_start_regular):
@ray.remote
class Foo:
@@ -615,6 +618,8 @@ def test_random_id_generation(ray_start_regular_shared):
assert f1._actor_id != f2._actor_id
@pytest.mark.skipif(
client_test_enabled(), reason="differing inheritence structure")
def test_actor_inheritance(ray_start_regular_shared):
class NonActorBase:
def __init__(self):
@@ -627,8 +632,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass
# Test that you can't instantiate an actor class directly.
with pytest.raises(
Exception, match="Actors cannot be instantiated directly."):
with pytest.raises(Exception, match="cannot be instantiated directly"):
ActorBase()
# Test that you can't inherit from an actor class.
@@ -642,6 +646,7 @@ def test_actor_inheritance(ray_start_regular_shared):
pass
@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented")
def test_multiple_return_values(ray_start_regular_shared):
@ray.remote
class Foo:
@@ -731,13 +736,13 @@ def test_actor_deletion(ray_start_regular_shared):
a = Actor.remote()
pid = ray.get(a.getpid.remote())
a = None
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
actors = [Actor.remote() for _ in range(10)]
pids = ray.get([a.getpid.remote() for a in actors])
a = None
actors = None
[ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids]
[wait_for_pid_to_exit(pid) for pid in pids]
def test_actor_method_deletion(ray_start_regular_shared):
@@ -766,7 +771,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
ray.get(signal.wait.remote())
return ray.get(actor.method.remote())
signal = ray.test_utils.SignalActor.remote()
SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
a = Actor.remote()
pid = ray.get(a.getpid.remote())
# Pass the handle to another task that cannot run yet.
@@ -777,7 +783,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
# Once the task finishes, the actor process should get killed.
ray.get(signal.send.remote())
assert ray.get(x_id) == 1
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
def test_multiple_actors(ray_start_regular_shared):
@@ -918,7 +924,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition):
if exit_condition == "ray.kill":
assert not check_file_written()
else:
ray.test_utils.wait_for_condition(check_file_written)
wait_for_condition(check_file_written)
if __name__ == "__main__":
+19 -7
View File
@@ -10,16 +10,22 @@ import time
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import client_test_enabled
from ray.test_utils import RayTestTimeoutException
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__)
# issue https://github.com/ray-project/ray/issues/7105
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_internal_free(shutdown_only):
ray.init(num_cpus=1)
@@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only):
return 1
@ray.remote
def g(l):
# The argument l should be a list containing one object ref.
ray.wait([l[0]])
def g(input_list):
# The argument input_list should be a list containing one object ref.
ray.wait([input_list[0]])
@ray.remote
def h(l):
# The argument l should be a list containing one object ref.
ray.get(l[0])
def h(input_list):
# The argument input_list should be a list containing one object ref.
ray.get(input_list[0])
# Make sure that multiple wait requests involving the same object ref
# all return.
@@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only):
ray.get([h.remote([x]), h.remote([x])])
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_caching_functions_to_run(shutdown_only):
# Test that we export functions to run on all workers before the driver
# is connected.
@@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only):
ray.worker.global_worker.run_function_on_all_workers(f)
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_running_function_on_all_workers(ray_start_regular):
def f(worker_info):
sys.path.append("fake_directory")
@@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular):
assert "fake_directory" not in ray.get(get_path2.remote())
@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline")
def test_profiling_api(ray_start_2_cpus):
@ray.remote
def f():
@@ -345,6 +354,8 @@ def test_illegal_api_calls(ray_start_regular):
ray.get(3)
@pytest.mark.skipif(
client_test_enabled(), reason="grpc interaction with releasing resources")
def test_multithreading(ray_start_2_cpus):
# This test requires at least 2 CPUs to finish since the worker does not
# release resources when joining the threads.
@@ -482,6 +493,7 @@ def test_multithreading(ray_start_2_cpus):
ray.get(actor.join.remote()) == "ok"
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_wait_makes_object_local(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
+3 -3
View File
@@ -284,14 +284,14 @@ def test_workers(shutdown_only):
def test_object_ref_properties():
id_bytes = b"00112233445566778899"
id_bytes = b"0011223344556677889900001111"
object_ref = ray.ObjectRef(id_bytes)
assert object_ref.binary() == id_bytes
object_ref = ray.ObjectRef.nil()
assert object_ref.is_nil()
with pytest.raises(ValueError, match=r".*needs to have length 20.*"):
with pytest.raises(ValueError, match=r".*needs to have length.*"):
ray.ObjectRef(id_bytes + b"1234")
with pytest.raises(ValueError, match=r".*needs to have length 20.*"):
with pytest.raises(ValueError, match=r".*needs to have length.*"):
ray.ObjectRef(b"0123456789")
object_ref = ray.ObjectRef.from_random()
assert not object_ref.is_nil()
+109 -48
View File
@@ -12,7 +12,6 @@ import sys
from jsonschema.exceptions import ValidationError
import ray
import ray._private.services as services
from ray.autoscaler._private.util import prepare_config, validate_config
from ray.autoscaler._private import commands
from ray.autoscaler.sdk import get_docker_host_mount_location
@@ -55,8 +54,11 @@ class MockProcessRunner:
self.calls = []
self.fail_cmds = fail_cmds or []
self.call_response = {}
self.ready_to_run = threading.Event()
self.ready_to_run.set()
def check_call(self, cmd, *args, **kwargs):
self.ready_to_run.wait()
for token in self.fail_cmds:
if token in str(cmd):
raise CalledProcessError(1, token,
@@ -166,22 +168,28 @@ class MockProvider(NodeProvider):
]
def is_running(self, node_id):
return self.mock_nodes[node_id].state == "running"
with self.lock:
return self.mock_nodes[node_id].state == "running"
def is_terminated(self, node_id):
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
with self.lock:
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
def node_tags(self, node_id):
return self.mock_nodes[node_id].tags
with self.lock:
return self.mock_nodes[node_id].tags
def internal_ip(self, node_id):
return self.mock_nodes[node_id].internal_ip
with self.lock:
return self.mock_nodes[node_id].internal_ip
def external_ip(self, node_id):
return self.mock_nodes[node_id].external_ip
with self.lock:
return self.mock_nodes[node_id].external_ip
def create_node(self, node_config, tags, count):
self.ready_to_create.wait()
def create_node(self, node_config, tags, count, _skip_wait=False):
if not _skip_wait:
self.ready_to_create.wait()
if self.fail_creates:
return
with self.lock:
@@ -201,7 +209,8 @@ class MockProvider(NodeProvider):
self.next_id += 1
def set_node_tags(self, node_id, tags):
self.mock_nodes[node_id].tags.update(tags)
with self.lock:
self.mock_nodes[node_id].tags.update(tags)
def terminate_node(self, node_id):
with self.lock:
@@ -535,7 +544,11 @@ class AutoscalingTest(unittest.TestCase):
config["max_workers"] = 5
config_path = self.write_config(config)
self.provider = MockProvider()
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "worker"}, 10)
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: "worker",
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_WORKER
}, 10)
runner = MockProcessRunner()
runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)])
autoscaler = StandardAutoscaler(
@@ -559,8 +572,14 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
runner = MockProcessRunner()
runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)])
runner.respond_to_call("json .Config.Env", ["[]" for i in range(12)])
lm = LoadMetrics()
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
}, 1)
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
autoscaler = StandardAutoscaler(
config_path,
lm,
@@ -569,16 +588,16 @@ class AutoscalingTest(unittest.TestCase):
max_failures=0,
process_runner=runner,
update_interval_s=0)
self.waitForNodes(0)
self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
autoscaler.update()
self.waitForNodes(2)
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
# Update the config to reduce the cluster size
new_config = SMALL_CLUSTER.copy()
new_config["max_workers"] = 1
self.write_config(new_config)
autoscaler.update()
self.waitForNodes(1)
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
# Update the config to increase the cluster size
new_config["min_workers"] = 10
@@ -587,12 +606,13 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
# Because one worker already started, the scheduler waits for its
# resources to be updated before it launches the remaining min_workers.
self.waitForNodes(1)
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
worker_ip = self.provider.non_terminated_node_ips(
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}, )[0]
lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {})
autoscaler.update()
self.waitForNodes(10)
self.waitForNodes(
10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
def testInitialWorkers(self):
"""initial_workers is deprecated, this tests that it is ignored."""
@@ -653,11 +673,15 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
# 1 head node.
self.waitForNodes(1)
autoscaler.request_resources([{"CPU": 1}])
autoscaler.load_metrics.set_resource_requests([{"CPU": 1}])
autoscaler.update()
# still 1 head node because request_resources fits in the headnode.
self.waitForNodes(1)
autoscaler.request_resources([{"CPU": 1}] + [{"CPU": 2}] * 9)
autoscaler.load_metrics.set_resource_requests([{
"CPU": 1
}] + [{
"CPU": 2
}] * 9)
autoscaler.update()
self.waitForNodes(2) # Adds a single worker to get its resources.
autoscaler.update()
@@ -760,7 +784,11 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(config)
self.provider = MockProvider()
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1)
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: "head",
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE
}, 1)
head_ip = self.provider.non_terminated_node_ips(
tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0]
@@ -809,7 +837,11 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(config)
self.provider = MockProvider()
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1)
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: "head",
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE
}, 1)
head_ip = self.provider.non_terminated_node_ips(
tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0]
@@ -964,8 +996,14 @@ class AutoscalingTest(unittest.TestCase):
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
runner = MockProcessRunner()
runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)])
runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)])
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
}, 1)
lm = LoadMetrics()
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
autoscaler = StandardAutoscaler(
config_path,
lm,
@@ -975,7 +1013,7 @@ class AutoscalingTest(unittest.TestCase):
max_failures=0,
update_interval_s=0)
autoscaler.update()
self.waitForNodes(2)
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
# Write a corrupted config
self.write_config("asdf", call_prepare_config=False)
@@ -983,7 +1021,10 @@ class AutoscalingTest(unittest.TestCase):
autoscaler.update()
time.sleep(0.1)
assert autoscaler.pending_launches.value == 0
assert len(self.provider.non_terminated_nodes({})) == 2
assert len(
self.provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
})) == 2
# Write a good config again
new_config = SMALL_CLUSTER.copy()
@@ -996,7 +1037,8 @@ class AutoscalingTest(unittest.TestCase):
# resources to be updated before it launches the remaining min_workers.
lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {})
autoscaler.update()
self.waitForNodes(10)
self.waitForNodes(
10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
def testMaxFailures(self):
config_path = self.write_config(SMALL_CLUSTER)
@@ -1076,9 +1118,17 @@ class AutoscalingTest(unittest.TestCase):
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
except AssertionError:
# The failed nodes might have been already terminated by autoscaler
assert len(self.provider.non_terminated_nodes({})) == 0
assert len(self.provider.non_terminated_nodes({})) < 2
def testConfiguresOutdatedNodes(self):
from ray.autoscaler._private.cli_logger import cli_logger
def do_nothing(*args, **kwargs):
pass
cli_logger._print = type(cli_logger._print)(do_nothing,
type(cli_logger))
config_path = self.write_config(SMALL_CLUSTER)
self.provider = MockProvider()
runner = MockProcessRunner()
@@ -1113,53 +1163,61 @@ class AutoscalingTest(unittest.TestCase):
self.provider = MockProvider()
lm = LoadMetrics()
runner = MockProcessRunner()
runner.respond_to_call("json .Config.Env", ["[]" for i in range(5)])
runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)])
self.provider.create_node({}, {
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
}, 1)
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
autoscaler = StandardAutoscaler(
config_path,
lm,
max_failures=0,
process_runner=runner,
update_interval_s=0)
assert len(self.provider.non_terminated_nodes({})) == 0
assert len(
self.provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
})) == 0
autoscaler.update()
self.waitForNodes(1)
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
autoscaler.update()
assert autoscaler.pending_launches.value == 0
assert len(self.provider.non_terminated_nodes({})) == 1
assert len(
self.provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
})) == 1
# Scales up as nodes are reported as used
local_ip = services.get_node_ip_address()
lm.update(
local_ip, {"CPU": 2}, {"CPU": 0}, {},
waiting_bundles=2 * [{
"CPU": 2
}]) # head
autoscaler.update()
lm.update(
"172.0.0.0", {"CPU": 2}, {"CPU": 0}, {},
"172.0.0.1", {"CPU": 2}, {"CPU": 0}, {},
waiting_bundles=2 * [{
"CPU": 2
}])
autoscaler.update()
self.waitForNodes(3)
self.waitForNodes(3, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
lm.update(
"172.0.0.1", {"CPU": 2}, {"CPU": 0}, {},
"172.0.0.2", {"CPU": 2}, {"CPU": 0}, {},
waiting_bundles=3 * [{
"CPU": 2
}])
autoscaler.update()
self.waitForNodes(5)
self.waitForNodes(5, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
# Holds steady when load is removed
lm.update("172.0.0.0", {"CPU": 2}, {"CPU": 2}, {})
lm.update("172.0.0.1", {"CPU": 2}, {"CPU": 2}, {})
lm.update("172.0.0.2", {"CPU": 2}, {"CPU": 2}, {})
autoscaler.update()
assert autoscaler.pending_launches.value == 0
assert len(self.provider.non_terminated_nodes({})) == 5
assert len(
self.provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
})) == 5
# Scales down as nodes become unused
lm.last_used_time_by_ip["172.0.0.0"] = 0
lm.last_used_time_by_ip["172.0.0.1"] = 0
lm.last_used_time_by_ip["172.0.0.2"] = 0
autoscaler.update()
assert autoscaler.pending_launches.value == 0
@@ -1167,18 +1225,21 @@ class AutoscalingTest(unittest.TestCase):
# are not connected and hence we rely more on connected nodes for
# min_workers. When the "pending" nodes show up as connected,
# then we can terminate the ones connected before.
assert len(self.provider.non_terminated_nodes({})) == 4
lm.last_used_time_by_ip["172.0.0.2"] = 0
assert len(
self.provider.non_terminated_nodes({
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
})) == 4
lm.last_used_time_by_ip["172.0.0.3"] = 0
lm.last_used_time_by_ip["172.0.0.4"] = 0
autoscaler.update()
assert autoscaler.pending_launches.value == 0
# 2 nodes and not 1 because 1 is needed for min_worker and the other 1
# is still not connected.
self.waitForNodes(2)
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
# when we connect it, we will see 1 node.
lm.last_used_time_by_ip["172.0.0.4"] = 0
lm.last_used_time_by_ip["172.0.0.5"] = 0
autoscaler.update()
self.waitForNodes(1)
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
def testTargetUtilizationFraction(self):
config = SMALL_CLUSTER.copy()
+56 -7
View File
@@ -8,14 +8,20 @@ import time
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import (
client_test_enabled,
dicts_equal,
wait_for_pid_to_exit,
)
import ray
logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
def test_ignore_http_proxy(shutdown_only):
ray.init(num_cpus=1)
os.environ["http_proxy"] = "http://example.com"
@@ -29,6 +35,7 @@ def test_ignore_http_proxy(shutdown_only):
# https://github.com/ray-project/ray/issues/7263
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_grpc_message_size(shutdown_only):
ray.init(num_cpus=1)
@@ -178,7 +185,7 @@ def test_many_fractional_resources(shutdown_only):
}
if block:
ray.get(g.remote())
return ray.test_utils.dicts_equal(true_resources, accepted_resources)
return dicts_equal(true_resources, accepted_resources)
# Check that the resource are assigned correctly.
result_ids = []
@@ -257,7 +264,7 @@ def test_background_tasks_with_max_calls(shutdown_only):
pid, g_id = nested.pop(0)
ray.get(g_id)
del g_id
ray.test_utils.wait_for_pid_to_exit(pid)
wait_for_pid_to_exit(pid)
def test_fair_queueing(shutdown_only):
@@ -327,6 +334,7 @@ def test_wait_timing(shutdown_only):
assert len(not_ready) == 1
@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
def test_function_descriptor():
python_descriptor = ray._raylet.PythonFunctionDescriptor(
"module_name", "function_name", "class_name", "function_hash")
@@ -345,6 +353,8 @@ def test_function_descriptor():
def test_ray_options(shutdown_only):
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
@ray.remote(
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
def foo():
@@ -353,8 +363,6 @@ def test_ray_options(shutdown_only):
time.sleep(0.1)
return ray.available_resources()
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
without_options = ray.get(foo.remote())
with_options = ray.get(
foo.options(
@@ -371,6 +379,43 @@ def test_ray_options(shutdown_only):
assert without_options != with_options
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize(
    "ray_start_cluster_head",
    [{
        "num_cpus": 0,
        "object_store_memory": 75 * 1024 * 1024,
    }],
    indirect=True)
def test_fetch_local(ray_start_cluster_head):
    """Exercise the ``fetch_local`` flag of ``ray.wait``.

    A task on a second node produces a large array once it is signalled.
    The test checks how many references ``ray.wait`` reports as ready and
    as remaining for each combination of ``fetch_local`` and ``timeout``.
    """
    # Number of elements in each random array (about 40 MB of data).
    array_len = 5 * 1024 * 1024

    head_cluster = ray_start_cluster_head
    head_cluster.add_node(num_cpus=2, object_store_memory=75 * 1024 * 1024)

    gate = ray.test_utils.SignalActor.remote()

    @ray.remote
    def put():
        # Block until the driver sends the signal, then produce the data.
        ray.wait([gate.wait.remote()])
        return np.random.rand(array_len)

    local_ref = ray.put(np.random.rand(array_len))
    remote_ref = put.remote()

    def wait_counts(**wait_kwargs):
        # Return (number ready, number remaining) for the remote object.
        done, pending = ray.wait([remote_ref], **wait_kwargs)
        return (len(done), len(pending))

    # Data is not ready in any node
    assert (0, 1) == wait_counts(timeout=2, fetch_local=False)
    ray.wait([gate.send.remote()])

    # Data is ready in some node, but not local node.
    assert (1, 0) == wait_counts(fetch_local=False)
    assert (0, 1) == wait_counts(timeout=2, fetch_local=True)

    # NOTE(review): presumably dropping the local object frees enough
    # object-store memory for the remote one to be fetched -- confirm.
    del local_ref
    assert (1, 0) == wait_counts(fetch_local=True)
def test_nested_functions(ray_start_shared_local_modes):
# Make sure that remote functions can use other values that are defined
# after the remote function but before the first function invocation.
@@ -402,8 +447,11 @@ def test_nested_functions(ray_start_shared_local_modes):
assert ray.get(factorial.remote(4)) == 24
assert ray.get(factorial.remote(5)) == 120
# Test remote functions that recursively call each other.
@pytest.mark.skipif(
client_test_enabled(), reason="mutual recursion is a known issue")
def test_mutually_recursive_functions(ray_start_shared_local_modes):
# Test remote functions that recursively call each other.
@ray.remote
def factorial_even(n):
assert n % 2 == 0
@@ -674,6 +722,7 @@ def test_args_stars_after(ray_start_shared_local_modes):
ray.get(remote_test_function.remote(local_method, actor_method))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_object_id_backward_compatibility(ray_start_shared_local_modes):
# We've renamed Python's `ObjectID` to `ObjectRef`, and added a type
# alias for backward compatibility.
+20 -6
View File
@@ -9,10 +9,16 @@ import pytest
from unittest.mock import MagicMock, patch
import ray
import ray.cluster_utils
import ray.test_utils
from ray.test_utils import client_test_enabled
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.exceptions import GetTimeoutError
from ray.exceptions import RayTaskError
if client_test_enabled():
from ray.experimental.client import ray
else:
import ray
logger = logging.getLogger(__name__)
@@ -25,6 +31,8 @@ logger = logging.getLogger(__name__)
}],
indirect=True)
def test_variable_number_of_args(shutdown_only):
ray.init(num_cpus=1)
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
@@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only):
def varargs_fct2(a, *b):
return " ".join(map(str, b))
ray.init(num_cpus=1)
x = varargs_fct1.remote(0, 1, 2)
assert ray.get(x) == "0 1 2"
x = varargs_fct2.remote(0, 1, 2)
@@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only):
def g():
return nonexistent()
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
with pytest.raises(RayTaskError, match="nonexistent"):
ray.get(g.remote())
def nonexistent():
@@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only):
assert ray.get(ray.get(h.remote(i))) == i
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024)
@@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only):
assert delta < 10, "did not skip slow value"
@pytest.mark.skipif(client_test_enabled(), reason="message size")
@pytest.mark.parametrize(
"ray_start_cluster", [{
"num_cpus": 1,
@@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster):
assert ray.get(x) == 100
@pytest.mark.skipif(client_test_enabled(), reason="message size")
def test_system_config_when_connecting(ray_start_cluster):
config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200}
cluster = ray.cluster_utils.Cluster()
@@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared):
def test_get_with_timeout(ray_start_regular_shared):
signal = ray.test_utils.SignalActor.remote()
SignalActor = create_remote_signal_actor(ray)
signal = SignalActor.remote()
# Check that get() returns early if object is ready.
start = time.time()
@@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared):
ray.get(a.add.remote(f.remote()))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_skip_plasma(ray_start_regular_shared):
@ray.remote
class Actor:
@@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared):
assert ray.get(obj_ref) == 2
@pytest.mark.skipif(
client_test_enabled(), reason="internal api and message size")
def test_actor_large_objects(ray_start_regular_shared):
@ray.remote
class Actor:
@@ -626,6 +639,7 @@ def test_duplicate_args(ray_start_regular_shared):
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
def test_get_correct_node_ip():
with patch("ray.worker") as worker_mock:
node_mock = MagicMock()
+102 -22
View File
@@ -1,21 +1,10 @@
import pytest
from contextlib import contextmanager
import time
import sys
import logging
import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray, reset_api
from ray.experimental.client.common import ClientObjectRef
@contextmanager
def ray_start_client_server():
server = ray_client_server.serve("localhost:50051", test_mode=True)
ray.connect("localhost:50051")
try:
yield ray
finally:
ray.disconnect()
server.stop(0)
reset_api()
from ray.experimental.client.ray_client_helpers import ray_start_client_server
def test_real_ray_fallback(ray_start_regular_shared):
@@ -81,11 +70,11 @@ def test_wait(ray_start_regular_shared):
with pytest.raises(Exception):
# Reference not in the object store.
ray.wait([ClientObjectRef("blabla")])
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait("blabla")
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait(ClientObjectRef("blabla"))
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
ray.wait(["blabla"])
@@ -142,7 +131,7 @@ def test_function_calling_function(ray_start_regular_shared):
@ray.remote
def f():
print(f, f._name, g._name, g)
print(f, g)
return ray.get(g.remote())
print(f, type(f))
@@ -171,8 +160,7 @@ def test_basic_actor(ray_start_regular_shared):
def test_pass_handles(ray_start_regular_shared):
"""
Test that passing client handles to actors and functions to remote actors
"""Test that passing client handles to actors and functions to remote actors
in functions (on the server or raylet side) works transparently to the
caller.
"""
@@ -234,6 +222,98 @@ def test_pass_handles(ray_start_regular_shared):
4)) == local_fact(4)
def test_basic_log_stream(ray_start_regular_shared):
    """Check that put/get operations show up in the client log stream."""
    with ray_start_client_server() as ray:
        captured = []

        def test_log(level, msg):
            # Stand-in for the log client's handler: keep only the text.
            captured.append(msg)

        log_client = ray.worker.log_client
        log_client.log = test_log
        log_client.set_logstream_level(logging.DEBUG)
        # Allow some time to propagate
        time.sleep(1)

        x = ray.put("Foo")
        assert ray.get(x) == "Foo"
        time.sleep(1)

        ref_hex = x.id.hex()
        logs_with_id = [entry for entry in captured if ref_hex in entry]
        assert len(logs_with_id) >= 2
        assert any("get" in entry for entry in logs_with_id)
        assert any("put" in entry for entry in logs_with_id)
def test_stdout_log_stream(ray_start_regular_shared):
    """Check that a remote task's stdout and stderr reach the client.

    The remote function prints the same string once to each stream, so
    exactly two messages are expected and each must contain that string.
    """
    with ray_start_client_server() as ray:
        log_msgs = []

        def test_log(level, msg):
            # Stand-in for the log client's std-stream handler.
            log_msgs.append(msg)

        ray.worker.log_client.stdstream = test_log

        @ray.remote
        def print_on_stderr_and_stdout(s):
            print(s)
            print(s, file=sys.stderr)

        time.sleep(1)
        print_on_stderr_and_stdout.remote("Hello world")
        time.sleep(1)
        assert len(log_msgs) == 2
        # Use a membership test: ``str.find`` returns -1 (truthy) when the
        # text is missing and 0 (falsy) when it is at the very start, so
        # using its result as a boolean checks the wrong thing.
        assert all("Hello world" in msg for msg in log_msgs)
def test_create_remote_before_start(ray_start_regular_shared):
    """Declare remote objects before the client connects, then use them.

    This mimics a library that decorates functions and classes at import
    time, before any client/server connection exists.
    """
    from ray.experimental.client import ray

    @ray.remote
    class Returner:
        def doit(self):
            return "foo"

    @ray.remote
    def f(x):
        return x + 20

    # Prints in verbose tests
    print("Created remote functions")

    with ray_start_client_server() as ray:
        task_result = ray.get(f.remote(3))
        assert task_result == 23

        a = Returner.remote()
        actor_result = ray.get(a.doit.remote())
        assert actor_result == "foo"
def test_basic_named_actor(ray_start_regular_shared):
    """Test that ray.get_actor() returns a handle to a named actor.

    The actor is created under a name, the creating handle is deleted,
    and the actor is then looked up by name. The counter must reflect
    the increments made through both handles.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.x = 0

            def inc(self):
                self.x += 1

            def get(self):
                return self.x

        # Create the actor and bump it twice through the original handle.
        actor = Accumulator.options(name="test_acc").remote()
        for _ in range(2):
            actor.inc.remote()
        del actor

        # Look it up by name and bump it once more.
        new_actor = ray.get_actor("test_acc")
        new_actor.inc.remote()
        total = ray.get(new_actor.get.remote())
        assert total == 3
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script and
    # forward pytest's result as the process exit status.
    import sys
    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
@@ -1,9 +1,8 @@
from ray.tests.test_experimental_client import ray_start_client_server
from ray.experimental.client.ray_client_helpers import ray_start_client_server
def test_get_ray_metadata(ray_start_regular_shared):
"""
Test the ClusterInfo client data pathway and API surface
"""Test the ClusterInfo client data pathway and API surface
"""
with ray_start_client_server() as ray:
ip_address = ray_start_regular_shared["node_ip_address"]
@@ -0,0 +1,152 @@
from ray.experimental.client.ray_client_helpers import ray_start_client_server
from ray.test_utils import wait_for_condition
import ray as real_ray
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.experimental.client.server.server import _get_current_servicer
def server_object_ref_count(n):
    """Build a predicate checking the server's object-ref bookkeeping.

    The returned zero-argument callable is true when the first client
    recorded in ``server.object_refs`` holds exactly ``n`` references.
    When no client is recorded at all, it is true only for ``n == 0``.
    """
    server = _get_current_servicer()
    assert server is not None

    def test_cond():
        refs_by_client = server.object_refs
        if refs_by_client:
            first_client = next(iter(refs_by_client.keys()))
            return len(server.object_refs[first_client]) == n
        # No open clients
        return n == 0

    return test_cond
def server_actor_ref_count(n):
    """Build a predicate checking how many actor refs the server tracks.

    The returned zero-argument callable is true when
    ``server.actor_refs`` holds exactly ``n`` entries.
    """
    server = _get_current_servicer()
    assert server is not None

    def test_cond():
        # With no running actors the length is 0, so this single
        # comparison also covers the "only true for n == 0" case.
        return len(server.actor_refs) == n

    return test_cond
def test_delete_refs_on_disconnect(ray_start_regular):
    """Closing the client must release every object ref the server holds."""
    with ray_start_client_server() as ray:

        @ray.remote
        def f(x):
            return x + 2

        # Both names are kept alive on purpose so the refs stay held.
        thing1 = f.remote(6)  # noqa
        thing2 = ray.put("Hello World")  # noqa

        # One put, one function -- the function result thing1 is
        # in a different category, according to the raylet.
        cluster_objects = real_ray.objects()
        assert len(cluster_objects) == 2
        # But we're maintaining the reference
        holds_three_refs = server_object_ref_count(3)
        assert holds_three_refs()
        # And can get the data
        assert ray.get(thing1) == 8

        # Close the client
        ray.close()

        # The server-side bookkeeping should drain first...
        wait_for_condition(server_object_ref_count(0), timeout=5)

        # ...and then the cluster itself should report no objects.
        wait_for_condition(lambda: len(real_ray.objects()) == 0, timeout=5)
def test_delete_ref_on_object_deletion(ray_start_regular):
    """Dropping one client ref must leave exactly one ref on the server."""
    with ray_start_client_server() as ray:
        vals = {}
        vals["ref"] = ray.put("Hello World")
        vals["ref2"] = ray.put("This value stays")

        # Remove the only client-side handle to the first object.
        del vals["ref"]

        one_ref_left = server_object_ref_count(1)
        wait_for_condition(one_ref_left, timeout=5)
def test_delete_actor_on_disconnect(ray_start_regular):
    """Closing the client must release, and eventually kill, its actors."""
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

            def get(self):
                return self.acc

        actor = Accumulator.remote()
        actor.inc.remote()

        tracks_one_actor = server_actor_ref_count(1)
        assert tracks_one_actor()

        current_value = ray.get(actor.get.remote())
        assert current_value == 1

        ray.close()

        # The server should forget the actor handle first.
        wait_for_condition(server_actor_ref_count(0), timeout=5)

        def test_cond():
            # True once no actor in the cluster is in a non-DEAD state.
            return not any(
                info["State"] != ActorTableData.DEAD
                for info in real_ray.actors().values())

        wait_for_condition(test_cond, timeout=10)
def test_delete_actor(ray_start_regular):
    """Deleting one of two actor handles leaves one ref on the server."""
    with ray_start_client_server() as ray:

        @ray.remote
        class Accumulator:
            def __init__(self):
                self.acc = 0

            def inc(self):
                self.acc += 1

        actor = Accumulator.remote()
        actor.inc.remote()
        actor2 = Accumulator.remote()
        actor2.inc.remote()

        tracks_two_actors = server_actor_ref_count(2)
        assert tracks_two_actors()

        # Drop the first handle; the second one must remain tracked.
        del actor

        wait_for_condition(server_actor_ref_count(1), timeout=5)
def test_simple_multiple_references(ray_start_regular):
    """Two client refs to one object must each stay usable on their own.

    The same object ref is fetched twice from an actor. The actor handle
    and the first ref are then deleted one after the other, and the
    object must still be readable through whichever ref remains.
    """
    with ray_start_client_server() as ray:

        @ray.remote
        class A:
            def __init__(self):
                self.x = ray.put("hi")

            def get(self):
                return [self.x]

        a = A.remote()
        first_reply = ray.get(a.get.remote())
        ref1 = first_reply[0]
        del first_reply
        second_reply = ray.get(a.get.remote())
        ref2 = second_reply[0]
        del second_reply

        # Release the actor handle before touching either ref.
        del a

        assert ray.get(ref1) == "hi"
        del ref1

        assert ray.get(ref2) == "hi"
        del ref2
@@ -1,6 +1,6 @@
import pytest
import asyncio
from ray.tests.test_experimental_client import ray_start_client_server
from ray.experimental.client.ray_client_helpers import ray_start_client_server
from ray.tests.client_test_utils import create_remote_signal_actor
from ray.test_utils import wait_for_condition
from ray.exceptions import TaskCancelledError
from ray.exceptions import RayTaskError
@@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular):
@pytest.mark.parametrize("use_force", [True, False])
def test_cancel_chain(ray_start_regular, use_force):
with ray_start_client_server() as ray:
@ray.remote
class SignalActor:
def __init__(self):
self.ready_event = asyncio.Event()
def send(self, clear=False):
self.ready_event.set()
if clear:
self.ready_event.clear()
async def wait(self, should_wait=True):
if should_wait:
await self.ready_event.wait()
SignalActor = create_remote_signal_actor(ray)
signaler = SignalActor.remote()
@ray.remote
+47 -17
View File
@@ -16,14 +16,8 @@ import ray.utils
import ray.ray_constants as ray_constants
from ray.exceptions import RayTaskError
from ray.cluster_utils import Cluster
from ray.test_utils import (
wait_for_condition,
SignalActor,
init_error_pubsub,
get_error_message,
Semaphore,
new_scheduler_enabled,
)
from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
get_error_message, Semaphore)
def test_failed_task(ray_start_regular, error_pubsub):
@@ -638,11 +632,10 @@ def test_export_large_objects(ray_start_regular, error_pubsub):
assert errors[0].type == ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR
@pytest.mark.skip(reason="TODO detect resource deadlock")
def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
p = error_pubsub
# Check that we get warning messages for infeasible tasks.
ray.init(num_cpus=1)
def test_warning_all_tasks_blocked(shutdown_only):
ray.init(
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
p = init_error_pubsub()
@ray.remote(num_cpus=1)
class Foo:
@@ -652,7 +645,7 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
@ray.remote
def f():
# Creating both actors is not possible.
actors = [Foo.remote() for _ in range(2)]
actors = [Foo.remote() for _ in range(3)]
for a in actors:
ray.get(a.f.remote())
@@ -663,7 +656,46 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
def test_warning_actor_waiting_on_actor(shutdown_only):
    """Expect a resource-deadlock error for two 1-CPU actors on 1 CPU."""
    system_config = {"debug_dump_period_milliseconds": 500}
    ray.init(num_cpus=1, _system_config=system_config)
    pubsub = init_error_pubsub()

    @ray.remote(num_cpus=1)
    class Actor:
        pass

    # Both handles are kept alive on purpose; each actor requests the
    # cluster's only CPU.
    a = Actor.remote()  # noqa
    b = Actor.remote()  # noqa

    deadlock_error = ray_constants.RESOURCE_DEADLOCK_ERROR
    errors = get_error_message(pubsub, 1, deadlock_error)
    assert len(errors) == 1
    assert errors[0].type == deadlock_error
def test_warning_task_waiting_on_actor(shutdown_only):
    """Expect a resource-deadlock error for a task queued behind an actor.

    A 1-CPU actor and a 1-CPU task are both submitted to a cluster
    started with a single CPU.
    """
    system_config = {"debug_dump_period_milliseconds": 500}
    ray.init(num_cpus=1, _system_config=system_config)
    pubsub = init_error_pubsub()

    @ray.remote(num_cpus=1)
    class Actor:
        pass

    # Keep the handle alive so the actor keeps its resource request.
    a = Actor.remote()  # noqa

    @ray.remote(num_cpus=1)
    def f():
        print("f running")
        time.sleep(999)

    # Keep the task ref alive as well.
    ids = [f.remote()]  # noqa

    deadlock_error = ray_constants.RESOURCE_DEADLOCK_ERROR
    errors = get_error_message(pubsub, 1, deadlock_error)
    assert len(errors) == 1
    assert errors[0].type == deadlock_error
def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub):
p = error_pubsub
# Check that we get warning messages for infeasible tasks.
@@ -689,7 +721,6 @@ def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub):
assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
def test_warning_for_infeasible_zero_cpu_actor(shutdown_only):
# Check that we cannot place an actor on a 0 CPU machine and that we get an
# infeasibility warning (even though the actor creation task itself
@@ -956,7 +987,6 @@ def test_raylet_crash_when_get(ray_start_regular):
thread.join()
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
def test_connect_with_disconnected_node(shutdown_only):
config = {
"num_heartbeats_timeout": 50,
@@ -6,7 +6,6 @@ from ray.test_utils import (
generate_system_config_map,
wait_for_condition,
wait_for_pid_to_exit,
new_scheduler_enabled,
)
@@ -21,7 +20,6 @@ def increase(x):
return x + 1
@pytest.mark.skipif(new_scheduler_enabled(), reason="notimpl")
@pytest.mark.parametrize(
"ray_start_regular", [
generate_system_config_map(
+4 -5
View File
@@ -9,7 +9,7 @@ import pytest
import ray
import ray.cluster_utils
from ray.test_utils import wait_for_condition, new_scheduler_enabled
from ray.test_utils import wait_for_condition
from ray.internal.internal_api import global_gc
logger = logging.getLogger(__name__)
@@ -166,9 +166,9 @@ def test_global_gc_when_full(shutdown_only):
gc.enable()
@pytest.mark.skipif(new_scheduler_enabled(), reason="hangs")
def test_global_gc_actors(shutdown_only):
ray.init(num_cpus=1)
ray.init(
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
try:
gc.disable()
@@ -179,8 +179,7 @@ def test_global_gc_actors(shutdown_only):
return "Ok"
# Try creating 3 actors. Unless python GC is triggered to break
# reference cycles, this won't be possible. Note this test takes 20s
# to run due to the 10s delay before checking of infeasible tasks.
# reference cycles, this won't be possible.
for i in range(3):
a = A.remote()
cycle = [a]
-3
View File
@@ -8,7 +8,6 @@ import time
import ray
import ray.ray_constants
import ray.test_utils
from ray.test_utils import new_scheduler_enabled
from ray._raylet import GlobalStateAccessor
@@ -217,8 +216,6 @@ def test_load_report(shutdown_only, max_shapes):
global_state_accessor.disconnect()
@pytest.mark.skipif(
new_scheduler_enabled(), reason="requires placement groups")
def test_placement_group_load_report(ray_start_cluster):
cluster = ray_start_cluster
# Add a head node that doesn't have gpu resource.
+7 -13
View File
@@ -1,12 +1,13 @@
import joblib
import sys
import time
import os
import pickle
import numpy as np
from sklearn.datasets import load_digits, load_iris
from sklearn.model_selection import RandomizedSearchCV
from sklearn.datasets import fetch_openml
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.kernel_approximation import Nystroem
@@ -14,7 +15,6 @@ from sklearn.kernel_approximation import RBFSampler
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC, SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_array
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score
@@ -112,20 +112,14 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
}
# Load dataset.
print("Loading dataset...")
data = fetch_openml("mnist_784")
X = check_array(data["data"], dtype=np.float32, order="C")
y = data["target"]
unnormalized_X_train, y_train = pickle.load(
open(
os.path.join(
os.path.dirname(__file__), "mnist_784_100_samples.pkl"), "rb"))
# Normalize features.
X = X / 255
X_train = unnormalized_X_train / 255
# Create train-test split.
print("Creating train-test split...")
n_train = 100
X_train = X[:n_train]
y_train = y[:n_train]
register_ray()
train_time = {}
random_seed = 0
# Use two workers per classifier.
+4 -4
View File
@@ -741,10 +741,10 @@ ray.get(main_wait.release.remote())
driver1_out_split = driver1_out.split("\n")
driver2_out_split = driver2_out.split("\n")
assert driver1_out_split[0][-1] == "1"
assert driver1_out_split[1][-1] == "2"
assert driver2_out_split[0][-1] == "3"
assert driver2_out_split[1][-1] == "4"
assert driver1_out_split[0][-1] == "1", driver1_out_split
assert driver1_out_split[1][-1] == "2", driver1_out_split
assert driver2_out_split[0][-1] == "3", driver2_out_split
assert driver2_out_split[1][-1] == "4", driver2_out_split
if __name__ == "__main__":

Some files were not shown because too many files have changed in this diff Show More