mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
Merge branch 'master' into py39
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
|
||||
|
||||
_client_hook_enabled = True
|
||||
|
||||
|
||||
def _enable_client_hook(val: bool):
    """Set the module-wide client-hook switch to ``val``.

    Args:
        val: New value for ``_client_hook_enabled``. Callers in this file
            pass the value previously returned by ``_disable_client_hook``
            so that the earlier state is restored.
    """
    # Writing through globals() is equivalent to a ``global`` declaration
    # followed by an assignment.
    globals()["_client_hook_enabled"] = val
|
||||
|
||||
|
||||
def _disable_client_hook():
    """Switch the client hook off and report its previous state.

    Returns:
        The value ``_client_hook_enabled`` held immediately before this
        call, so the caller can hand it back to ``_enable_client_hook``.
    """
    global _client_hook_enabled
    # The right-hand side is evaluated first, so ``previous`` captures the
    # old flag before it is overwritten with False.
    previous, _client_hook_enabled = _client_hook_enabled, False
    return previous
|
||||
|
||||
|
||||
def _explicitly_enable_client_mode():
    """Force ``client_mode_enabled`` to True.

    This overrides whatever value the flag was given at import time from
    the ``RAY_CLIENT_MODE`` environment variable.
    """
    globals()["client_mode_enabled"] = True
|
||||
|
||||
|
||||
@contextmanager
def disable_client_hook():
    """Context manager that turns the client hook off for the ``with`` body.

    The state the hook had on entry is restored on exit, including when
    the body raises.
    """
    previous_state = _disable_client_hook()
    try:
        yield
    finally:
        _enable_client_hook(previous_state)
|
||||
|
||||
|
||||
def client_mode_hook(func):
    """Decorator for ray module methods to delegate to ray client.

    When both ``client_mode_enabled`` and ``_client_hook_enabled`` are
    truthy at call time, the call is forwarded to the attribute of the ray
    client object that has the same name as ``func``. Otherwise ``func``
    itself is called.
    """
    # Imported at decoration time, inside the function body, exactly as
    # before; only the local alias differs.
    from ray.experimental.client import ray as client_ray

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Both flags are module globals that are only read here, so no
        # ``global`` declaration is needed.
        if not (client_mode_enabled and _client_hook_enabled):
            return func(*args, **kwargs)
        delegate = getattr(client_ray, func.__name__)
        return delegate(*args, **kwargs)

    return wrapper
|
||||
@@ -0,0 +1,83 @@
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
from ray._private.ray_microbenchmark_helpers import timeit
|
||||
from ray._private.ray_microbenchmark_helpers import ray_setup_and_teardown
|
||||
|
||||
|
||||
def benchmark_get_calls(ray):
    """Time repeated ``ray.get`` calls on a single small stored object.

    Args:
        ray: API object to benchmark; ``main`` passes the value yielded by
            ``ray_start_client_server()``.
    """
    # Store the object once so the timed callable measures only the get.
    stored_ref = ray.put(0)
    timeit("client: get calls", lambda: ray.get(stored_ref))
|
||||
|
||||
|
||||
def benchmark_put_calls(ray):
    """Time repeated ``ray.put`` calls of a small value.

    Args:
        ray: API object to benchmark; ``main`` passes the value yielded by
            ``ray_start_client_server()``.
    """
    timeit("client: put calls", lambda: ray.put(0))
|
||||
|
||||
|
||||
def benchmark_remote_put_calls(ray):
    """Time ``ray.put`` calls issued from inside remote tasks.

    Args:
        ray: API object to benchmark; ``main`` passes the value yielded by
            ``ray_start_client_server()``.
    """
    puts_per_task = 100
    tasks_per_round = 10

    @ray.remote
    def do_put_small():
        for _ in range(puts_per_task):
            ray.put(0)

    def run_round():
        pending = [do_put_small.remote() for _ in range(tasks_per_round)]
        ray.get(pending)

    # One round performs 10 tasks x 100 puts, hence the multiplier of 1000.
    timeit("client: remote put calls", run_round, 1000)
|
||||
|
||||
|
||||
def benchmark_simple_actor(ray):
    """Time calls from the driver to a single actor in three modes.

    The modes are: one call at a time, 1000 calls submitted together, and
    1000 calls submitted together to an actor created with
    ``max_concurrency=16``.

    Args:
        ray: API object to benchmark; ``main`` passes the value yielded by
            ``ray_start_client_server()``.
    """
    calls_per_batch = 1000

    @ray.remote(num_cpus=0)
    class Actor:
        def small_value(self):
            return b"ok"

        def small_value_arg(self, x):
            return b"ok"

        def small_value_batch(self, n):
            # NOTE(review): never invoked by this benchmark. Inside the
            # actor, ``self.small_value`` looks like a plain bound method,
            # so ``.remote()`` here is suspect -- confirm before using.
            ray.get([self.small_value.remote() for _ in range(n)])

    actor = Actor.remote()

    def actor_sync():
        ray.get(actor.small_value.remote())

    timeit("client: 1:1 actor calls sync", actor_sync)

    def actor_async():
        refs = [actor.small_value.remote() for _ in range(calls_per_batch)]
        ray.get(refs)

    timeit("client: 1:1 actor calls async", actor_async, 1000)

    # Rebind the same name, as the original does, so the handle to the
    # first actor is released before the concurrent run.
    actor = Actor.options(max_concurrency=16).remote()

    def actor_concurrent():
        refs = [actor.small_value.remote() for _ in range(calls_per_batch)]
        ray.get(refs)

    timeit("client: 1:1 actor calls concurrent", actor_concurrent, 1000)
|
||||
|
||||
|
||||
def main():
    """Run every ``benchmark_*`` callable defined in this module.

    Ray is initialised once for the whole run. Each benchmark gets its own
    client/server session from ``ray_start_client_server()`` and receives
    the yielded object as its only argument.
    """
    init_kwargs = {
        "logging_level": logging.WARNING,
        "_system_config": {"put_small_object_in_memory_store": True},
    }
    with ray_setup_and_teardown(**init_kwargs):
        for member_name, member in inspect.getmembers(sys.modules[__name__]):
            # Benchmarks are discovered purely by their name prefix.
            if member_name.startswith("benchmark_"):
                with ray_start_client_server() as client_ray:
                    member(client_ray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,39 @@
|
||||
import time
|
||||
import os
|
||||
import ray
|
||||
import numpy as np
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
# Only run tests matching this filter pattern.
filter_pattern = os.environ.get("TESTS_TO_RUN", "")


def timeit(name, fn, multiplier=1, *, warmup_s=1, run_s=2, repeats=4):
    """Benchmark ``fn`` and print its throughput in operations per second.

    ``fn`` is first called repeatedly for ``warmup_s`` seconds without
    being measured. It is then called repeatedly for ``run_s`` seconds,
    ``repeats`` times over. The mean and standard deviation of the
    per-round throughput are printed.

    Args:
        name: Label printed with the result. The benchmark is skipped
            unless the module-level ``filter_pattern`` is a substring of
            this label.
        fn: Zero-argument callable to time.
        multiplier: Factor applied to the raw call count, for callables
            that perform several operations per call.
        warmup_s: Warmup duration in seconds (defaults to 1, as before).
        run_s: Duration in seconds of each measured round (defaults to 2,
            as before). Must be positive.
        repeats: Number of measured rounds (defaults to 4, as before).
            Must be at least 1.

    Raises:
        ValueError: If ``run_s`` is not positive or ``repeats`` is below 1.
    """
    if filter_pattern not in name:
        return
    if run_s <= 0 or repeats < 1:
        raise ValueError("run_s must be positive and repeats must be >= 1")

    # perf_counter is monotonic and high-resolution; wall-clock time.time()
    # can jump when the system clock is adjusted and skew the measurement.
    clock = time.perf_counter

    # warmup
    start = clock()
    while clock() - start < warmup_s:
        fn()

    # real run
    stats = []
    for _ in range(repeats):
        start = clock()
        count = 0
        while clock() - start < run_s:
            fn()
            count += 1
        end = clock()
        stats.append(multiplier * count / (end - start))

    print(name, "per second", round(np.mean(stats), 2), "+-",
          round(np.std(stats), 2))
|
||||
|
||||
|
||||
@contextmanager
def ray_setup_and_teardown(**init_args):
    """Run the ``with`` body between ``ray.init`` and ``ray.shutdown``.

    Args:
        **init_args: Keyword arguments forwarded unchanged to ``ray.init``.

    ``ray.shutdown()`` runs on exit even if the body raises. If
    ``ray.init`` itself raises, shutdown is not attempted.
    """
    ray.init(**init_args)
    try:
        yield
    finally:
        ray.shutdown()
|
||||
@@ -279,8 +279,7 @@ def get_address_info_from_redis_helper(redis_address,
|
||||
def get_address_info_from_redis(redis_address,
|
||||
node_ip_address,
|
||||
num_retries=5,
|
||||
redis_password=None,
|
||||
no_warning=False):
|
||||
redis_password=None):
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -291,11 +290,10 @@ def get_address_info_from_redis(redis_address,
|
||||
raise
|
||||
# Some of the information may not be in Redis yet, so wait a little
|
||||
# bit.
|
||||
if not no_warning:
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
time.sleep(1)
|
||||
counter += 1
|
||||
|
||||
@@ -1618,12 +1616,13 @@ def determine_plasma_store_config(object_store_memory,
|
||||
logger.warning(
|
||||
"WARNING: The object store is using {} instead of "
|
||||
"/dev/shm because /dev/shm has only {} bytes available. "
|
||||
"This may slow down performance! You may be able to free "
|
||||
"up space by deleting files in /dev/shm or terminating "
|
||||
"any running plasma_store_server processes. If you are "
|
||||
"inside a Docker container, you may need to pass an "
|
||||
"argument with the flag '--shm-size' to 'docker run'.".
|
||||
format(ray.utils.get_user_temp_dir(), shm_avail))
|
||||
"This will harm performance! You may be able to free up "
|
||||
"space by deleting files in /dev/shm. If you are inside a "
|
||||
"Docker container, you can increase /dev/shm size by "
|
||||
"passing '--shm-size=Xgb' to 'docker run' (or add it to "
|
||||
"the run_options list in a Ray cluster config). Make sure "
|
||||
"to set this to more than 2gb.".format(
|
||||
ray.utils.get_user_temp_dir(), shm_avail))
|
||||
else:
|
||||
plasma_directory = ray.utils.get_user_temp_dir()
|
||||
|
||||
|
||||
+19
-19
@@ -107,6 +107,10 @@ from ray.exceptions import (
|
||||
TaskCancelledError
|
||||
)
|
||||
from ray.utils import decode
|
||||
from ray._private.client_mode_hook import (
|
||||
_enable_client_hook,
|
||||
_disable_client_hook,
|
||||
)
|
||||
import msgpack
|
||||
|
||||
cimport cpython
|
||||
@@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler(
|
||||
|
||||
with gil:
|
||||
try:
|
||||
client_was_enabled = _disable_client_hook()
|
||||
try:
|
||||
# The call to execute_task should never raise an exception. If
|
||||
# it does, that indicates that there was an internal error.
|
||||
@@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler(
|
||||
else:
|
||||
logger.exception("SystemExit was raised from the worker")
|
||||
return CRayStatus.UnexpectedSystemExit()
|
||||
finally:
|
||||
_enable_client_hook(client_was_enabled)
|
||||
|
||||
return CRayStatus.OK()
|
||||
|
||||
@@ -638,9 +645,11 @@ cdef c_vector[c_string] spill_objects_handler(
|
||||
return return_urls
|
||||
|
||||
|
||||
cdef void restore_spilled_objects_handler(
|
||||
cdef int64_t restore_spilled_objects_handler(
|
||||
const c_vector[CObjectID]& object_ids_to_restore,
|
||||
const c_vector[c_string]& object_urls) nogil:
|
||||
cdef:
|
||||
int64_t bytes_restored = 0
|
||||
with gil:
|
||||
urls = []
|
||||
size = object_urls.size()
|
||||
@@ -651,7 +660,8 @@ cdef void restore_spilled_objects_handler(
|
||||
with ray.worker._changeproctitle(
|
||||
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER,
|
||||
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE):
|
||||
external_storage.restore_spilled_objects(object_refs, urls)
|
||||
bytes_restored = external_storage.restore_spilled_objects(
|
||||
object_refs, urls)
|
||||
except Exception:
|
||||
exception_str = (
|
||||
"An unexpected internal error occurred while the IO worker "
|
||||
@@ -662,6 +672,7 @@ cdef void restore_spilled_objects_handler(
|
||||
"restore_spilled_objects_error",
|
||||
traceback.format_exc() + exception_str,
|
||||
job_id=None)
|
||||
return bytes_restored
|
||||
|
||||
|
||||
cdef void delete_spilled_objects_handler(
|
||||
@@ -873,7 +884,8 @@ cdef class CoreWorker:
|
||||
return self.plasma_event_handler
|
||||
|
||||
def get_objects(self, object_refs, TaskID current_task_id,
|
||||
int64_t timeout_ms=-1, plasma_objects_only=False):
|
||||
int64_t timeout_ms=-1,
|
||||
plasma_objects_only=False):
|
||||
cdef:
|
||||
c_vector[shared_ptr[CRayObject]] results
|
||||
CTaskID c_task_id = current_task_id.native()
|
||||
@@ -1004,7 +1016,7 @@ cdef class CoreWorker:
|
||||
return c_object_id.Binary()
|
||||
|
||||
def wait(self, object_refs, int num_returns, int64_t timeout_ms,
|
||||
TaskID current_task_id):
|
||||
TaskID current_task_id, c_bool fetch_local):
|
||||
cdef:
|
||||
c_vector[CObjectID] wait_ids
|
||||
c_vector[c_bool] results
|
||||
@@ -1013,7 +1025,7 @@ cdef class CoreWorker:
|
||||
wait_ids = ObjectRefsToVector(object_refs)
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
|
||||
wait_ids, num_returns, timeout_ms, &results))
|
||||
wait_ids, num_returns, timeout_ms, &results, fetch_local))
|
||||
|
||||
assert len(results) == len(object_refs)
|
||||
|
||||
@@ -1026,14 +1038,13 @@ cdef class CoreWorker:
|
||||
|
||||
return ready, not_ready
|
||||
|
||||
def free_objects(self, object_refs, c_bool local_only,
|
||||
c_bool delete_creating_tasks):
|
||||
def free_objects(self, object_refs, c_bool local_only):
|
||||
cdef:
|
||||
c_vector[CObjectID] free_ids = ObjectRefsToVector(object_refs)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
|
||||
free_ids, local_only, delete_creating_tasks))
|
||||
free_ids, local_only))
|
||||
|
||||
def global_gc(self):
|
||||
with nogil:
|
||||
@@ -1573,17 +1584,6 @@ cdef class CoreWorker:
|
||||
resource_name.encode("ascii"), capacity,
|
||||
CNodeID.FromBinary(client_id.binary()))
|
||||
|
||||
def force_spill_objects(self, object_refs):
|
||||
cdef c_vector[CObjectID] object_ids
|
||||
object_ids = ObjectRefsToVector(object_refs)
|
||||
assert not RayConfig.instance().automatic_object_deletion_enabled(), (
|
||||
"Automatic object deletion is not supported for"
|
||||
"force_spill_objects yet. Please set"
|
||||
"automatic_object_deletion_enabled: False in Ray's system config.")
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker()
|
||||
.SpillObjects(object_ids))
|
||||
|
||||
cdef void async_set_result(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_ref,
|
||||
void *future) with gil:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections import defaultdict, namedtuple
|
||||
from collections import defaultdict, namedtuple, Counter
|
||||
from typing import Any, Optional, Dict, List
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
import copy
|
||||
@@ -16,8 +16,10 @@ from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
from ray.autoscaler.tags import (
|
||||
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
|
||||
TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE, NODE_KIND_WORKER,
|
||||
NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
|
||||
TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
|
||||
STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE,
|
||||
NODE_KIND_WORKER, NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
|
||||
from ray.autoscaler._private.legacy_info_string import legacy_log_info_string
|
||||
from ray.autoscaler._private.providers import _get_node_provider
|
||||
from ray.autoscaler._private.updater import NodeUpdaterThread
|
||||
from ray.autoscaler._private.node_launcher import NodeLauncher
|
||||
@@ -25,8 +27,8 @@ from ray.autoscaler._private.resource_demand_scheduler import \
|
||||
get_bin_pack_residual, ResourceDemandScheduler, NodeType, NodeID, NodeIP, \
|
||||
ResourceDict
|
||||
from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \
|
||||
with_head_node_ip, hash_launch_conf, hash_runtime_conf, add_prefix, \
|
||||
DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR
|
||||
with_head_node_ip, hash_launch_conf, hash_runtime_conf, \
|
||||
DEBUG_AUTOSCALING_ERROR, format_info_string
|
||||
from ray.autoscaler._private.constants import \
|
||||
AUTOSCALER_MAX_NUM_FAILURES, AUTOSCALER_MAX_LAUNCH_BATCH, \
|
||||
AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \
|
||||
@@ -41,20 +43,23 @@ UpdateInstructions = namedtuple(
|
||||
"UpdateInstructions",
|
||||
["node_id", "init_commands", "start_ray_commands", "docker_config"])
|
||||
|
||||
AutoscalerSummary = namedtuple(
|
||||
"AutoscalerSummary",
|
||||
["active_nodes", "pending_nodes", "pending_launches", "failed_nodes"])
|
||||
|
||||
|
||||
class StandardAutoscaler:
|
||||
"""The autoscaling control loop for a Ray cluster.
|
||||
|
||||
There are two ways to start an autoscaling cluster: manually by running
|
||||
`ray start --head --autoscaling-config=/path/to/config.yaml` on a
|
||||
instance that has permission to launch other instances, or you can also use
|
||||
`ray up /path/to/config.yaml` from your laptop, which will
|
||||
configure the right AWS/Cloud roles automatically.
|
||||
|
||||
StandardAutoscaler's `update` method is periodically called by `monitor.py`
|
||||
to add and remove nodes as necessary. Currently, load-based autoscaling is
|
||||
not implemented, so all this class does is try to maintain a constant
|
||||
cluster size.
|
||||
`ray start --head --autoscaling-config=/path/to/config.yaml` on an instance
|
||||
that has permission to launch other instances, or you can also use `ray up
|
||||
/path/to/config.yaml` from your laptop, which will configure the right
|
||||
AWS/Cloud roles automatically. See the documentation for a full definition
|
||||
of autoscaling behavior:
|
||||
https://docs.ray.io/en/master/cluster/autoscaling.html
|
||||
StandardAutoscaler's `update` method is periodically called in
|
||||
`monitor.py`'s monitoring loop.
|
||||
|
||||
StandardAutoscaler is also used to bootstrap clusters (by adding workers
|
||||
until the cluster size that can handle the resource demand is met).
|
||||
@@ -120,9 +125,6 @@ class StandardAutoscaler:
|
||||
for local_path in self.config["file_mounts"].values():
|
||||
assert os.path.exists(local_path)
|
||||
|
||||
# List of resource bundles the user is requesting of the cluster.
|
||||
self.resource_demand_vector = []
|
||||
|
||||
logger.info("StandardAutoscaler: {}".format(self.config))
|
||||
|
||||
def update(self):
|
||||
@@ -149,7 +151,6 @@ class StandardAutoscaler:
|
||||
|
||||
def _update(self):
|
||||
now = time.time()
|
||||
|
||||
# Throttle autoscaling updates to this interval to avoid exceeding
|
||||
# rate limits on API calls.
|
||||
if now - self.last_update_time < self.update_interval_s:
|
||||
@@ -162,7 +163,6 @@ class StandardAutoscaler:
|
||||
self.provider.internal_ip(node_id)
|
||||
for node_id in self.all_workers()
|
||||
])
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Terminate any idle or out of date nodes
|
||||
last_used = self.load_metrics.last_used_time_by_ip
|
||||
@@ -176,7 +176,7 @@ class StandardAutoscaler:
|
||||
sorted_node_ids = self._sort_based_on_last_used(nodes, last_used)
|
||||
# Don't terminate nodes needed by request_resources()
|
||||
nodes_allowed_to_terminate: Dict[NodeID, bool] = {}
|
||||
if self.resource_demand_vector:
|
||||
if self.load_metrics.get_resource_requests():
|
||||
nodes_allowed_to_terminate = self._get_nodes_allowed_to_terminate(
|
||||
sorted_node_ids)
|
||||
|
||||
@@ -202,7 +202,6 @@ class StandardAutoscaler:
|
||||
if nodes_to_terminate:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Terminate nodes if there are too many
|
||||
nodes_to_terminate = []
|
||||
@@ -217,8 +216,6 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
|
||||
self.log_info_string(nodes)
|
||||
|
||||
to_launch = self.resource_demand_scheduler.get_nodes_to_launch(
|
||||
self.provider.non_terminated_nodes(tag_filters={}),
|
||||
self.pending_launches.breakdown(),
|
||||
@@ -226,7 +223,7 @@ class StandardAutoscaler:
|
||||
self.load_metrics.get_resource_utilization(),
|
||||
self.load_metrics.get_pending_placement_groups(),
|
||||
self.load_metrics.get_static_node_resources_by_ip(),
|
||||
ensure_min_cluster_size=self.resource_demand_vector)
|
||||
ensure_min_cluster_size=self.load_metrics.get_resource_requests())
|
||||
for node_type, count in to_launch.items():
|
||||
self.launch_new_node(count, node_type=node_type)
|
||||
|
||||
@@ -256,7 +253,6 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Update nodes with out-of-date files.
|
||||
# TODO(edoakes): Spawning these threads directly seems to cause
|
||||
@@ -282,6 +278,9 @@ class StandardAutoscaler:
|
||||
for node_id in nodes:
|
||||
self.recover_if_needed(node_id, now)
|
||||
|
||||
logger.info(self.info_string())
|
||||
legacy_log_info_string(self, nodes)
|
||||
|
||||
def _sort_based_on_last_used(self, nodes: List[NodeID],
|
||||
last_used: Dict[str, float]) -> List[NodeID]:
|
||||
"""Sort the nodes based on the last time they were used.
|
||||
@@ -333,7 +332,7 @@ class StandardAutoscaler:
|
||||
NodeIP,
|
||||
ResourceDict] = \
|
||||
self.load_metrics.get_static_node_resources_by_ip()
|
||||
head_node_resources = static_nodes[head_ip]
|
||||
head_node_resources = static_nodes.get(head_ip, {})
|
||||
else:
|
||||
head_node_resources = {}
|
||||
|
||||
@@ -362,7 +361,7 @@ class StandardAutoscaler:
|
||||
used_resource_requests: List[ResourceDict]
|
||||
_, used_resource_requests = \
|
||||
get_bin_pack_residual(max_node_resources,
|
||||
self.resource_demand_vector)
|
||||
self.load_metrics.get_resource_requests())
|
||||
# Remove the first entry (the head node).
|
||||
max_node_resources.pop(0)
|
||||
# Remove the first entry (the head node).
|
||||
@@ -482,11 +481,13 @@ class StandardAutoscaler:
|
||||
# for legacy yamls.
|
||||
self.resource_demand_scheduler.reset_config(
|
||||
self.provider, self.available_node_types,
|
||||
self.config["max_workers"], upscaling_speed)
|
||||
self.config["max_workers"], self.config["head_node_type"],
|
||||
upscaling_speed)
|
||||
else:
|
||||
self.resource_demand_scheduler = ResourceDemandScheduler(
|
||||
self.provider, self.available_node_types,
|
||||
self.config["max_workers"], upscaling_speed)
|
||||
self.config["max_workers"], self.config["head_node_type"],
|
||||
upscaling_speed)
|
||||
|
||||
except Exception as e:
|
||||
if errors_fatal:
|
||||
@@ -532,15 +533,17 @@ class StandardAutoscaler:
|
||||
if not self.can_update(node_id):
|
||||
return
|
||||
key = self.provider.internal_ip(node_id)
|
||||
if key not in self.load_metrics.last_heartbeat_time_by_ip:
|
||||
self.load_metrics.last_heartbeat_time_by_ip[key] = now
|
||||
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key]
|
||||
delta = now - last_heartbeat_time
|
||||
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
|
||||
return
|
||||
|
||||
if key in self.load_metrics.last_heartbeat_time_by_ip:
|
||||
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[
|
||||
key]
|
||||
delta = now - last_heartbeat_time
|
||||
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
|
||||
return
|
||||
|
||||
logger.warning("StandardAutoscaler: "
|
||||
"{}: No heartbeat in {}s, "
|
||||
"restarting Ray to recover...".format(node_id, delta))
|
||||
"{}: No recent heartbeat, "
|
||||
"restarting Ray to recover...".format(node_id))
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=node_id,
|
||||
provider_config=self.config["provider"],
|
||||
@@ -677,43 +680,6 @@ class StandardAutoscaler:
|
||||
return self.provider.non_terminated_nodes(
|
||||
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_UNMANAGED})
|
||||
|
||||
def log_info_string(self, nodes):
|
||||
tmp = "Cluster status: "
|
||||
tmp += self.info_string(nodes)
|
||||
tmp += "\n"
|
||||
tmp += self.load_metrics.info_string()
|
||||
tmp += "\n"
|
||||
tmp += self.resource_demand_scheduler.debug_string(
|
||||
nodes, self.pending_launches.breakdown(),
|
||||
self.load_metrics.get_resource_utilization())
|
||||
if _internal_kv_initialized():
|
||||
_internal_kv_put(DEBUG_AUTOSCALING_STATUS, tmp, overwrite=True)
|
||||
if self.prefix_cluster_info:
|
||||
tmp = add_prefix(tmp, self.config["cluster_name"])
|
||||
logger.debug(tmp)
|
||||
|
||||
def info_string(self, nodes):
|
||||
suffix = ""
|
||||
if self.updaters:
|
||||
suffix += " ({} updating)".format(len(self.updaters))
|
||||
if self.num_failed_updates:
|
||||
suffix += " ({} failed to update)".format(
|
||||
len(self.num_failed_updates))
|
||||
|
||||
return "{} nodes{}".format(len(nodes), suffix)
|
||||
|
||||
def request_resources(self, resources: List[dict]):
|
||||
"""Called by monitor to request resources.
|
||||
|
||||
Args:
|
||||
resources: A list of resource bundles.
|
||||
"""
|
||||
if resources:
|
||||
logger.info(
|
||||
"StandardAutoscaler: resource_requests={}".format(resources))
|
||||
assert isinstance(resources, list), resources
|
||||
self.resource_demand_vector = resources
|
||||
|
||||
def kill_workers(self):
|
||||
logger.error("StandardAutoscaler: kill_workers triggered")
|
||||
nodes = self.workers()
|
||||
@@ -721,3 +687,66 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes)
|
||||
logger.error("StandardAutoscaler: terminated {} node(s)".format(
|
||||
len(nodes)))
|
||||
|
||||
def summary(self):
    """Summarizes the active, pending, and failed node launches.

    An active node is a node whose raylet is actively reporting
    heartbeats. A pending node is a non-active node whose node tag is
    uninitialized, waiting for ssh, syncing files, or setting up. If a
    node is not pending or active, it is failed. Unmanaged nodes are
    left out of the summary entirely.

    Returns:
        AutoscalerSummary: The summary.
    """
    # Node statuses that mean "still being brought up".
    pending_states = (
        STATUS_UNINITIALIZED,
        STATUS_WAITING_FOR_SSH,
        STATUS_SYNCING_FILES,
        STATUS_SETTING_UP,
    )

    active_nodes = Counter()
    pending_nodes = []
    failed_nodes = []

    for node_id in self.provider.non_terminated_nodes(tag_filters={}):
        ip = self.provider.internal_ip(node_id)
        node_tags = self.provider.node_tags(node_id)
        if node_tags[TAG_RAY_NODE_KIND] == NODE_KIND_UNMANAGED:
            continue
        node_type = node_tags[TAG_RAY_USER_NODE_TYPE]

        # TODO (Alex): If a node's raylet has died, it shouldn't be marked
        # as active.
        if self.load_metrics.is_active(ip):
            active_nodes[node_type] += 1
            continue

        if node_tags[TAG_RAY_NODE_STATUS] in pending_states:
            pending_nodes.append((ip, node_type))
        else:
            # TODO (Alex): Failed nodes are now immediately killed, so
            # this list will almost always be empty. We should ideally
            # keep a cache of recently failed nodes and their startup
            # logs.
            failed_nodes.append((ip, node_type))

    # The concurrent counter leaves some 0 counts in, so we need to
    # manually filter those out.
    pending_launches = {
        node_type: count
        for node_type, count in self.pending_launches.breakdown().items()
        if count
    }

    return AutoscalerSummary(
        active_nodes=active_nodes,
        pending_nodes=pending_nodes,
        pending_launches=pending_launches,
        failed_nodes=failed_nodes)
|
||||
|
||||
def info_string(self):
    """Return the formatted cluster status report, prefixed by a newline.

    The report is built by ``format_info_string`` from the load-metrics
    summary and this autoscaler's own ``summary()``.
    """
    return "\n" + format_info_string(self.load_metrics.summary(),
                                     self.summary())
|
||||
|
||||
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
|
||||
HASH_MAX_LENGTH = 10
|
||||
KUBECTL_RSYNC = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
|
||||
MAX_HOME_RETRIES = 3
|
||||
HOME_RETRY_DELAY_S = 5
|
||||
|
||||
_config = {"use_login_shells": True, "silent_rsync": True}
|
||||
|
||||
@@ -248,16 +250,31 @@ class KubernetesCommandRunner(CommandRunnerInterface):
|
||||
|
||||
@property
def _home(self):
    """Home directory inside the container, looked up once and cached.

    A cached value is returned as-is. Otherwise ``_try_to_get_home`` is
    attempted up to ``MAX_HOME_RETRIES`` times. Failures of all but the
    final attempt are logged and followed by a sleep of
    ``HOME_RETRY_DELAY_S`` seconds; a failure of the final attempt
    propagates to the caller.
    """
    home = self._home_cached
    if home is not None:
        return home

    for _ in range(MAX_HOME_RETRIES - 1):
        try:
            home = self._try_to_get_home()
        except Exception:
            # TODO (Dmitri): Identify the exception we're trying to avoid.
            logger.info("Error reading container's home directory. "
                        f"Retrying in {HOME_RETRY_DELAY_S} seconds.")
            time.sleep(HOME_RETRY_DELAY_S)
        else:
            break
    else:
        # Last try: every earlier attempt failed, so let errors propagate.
        home = self._try_to_get_home()

    self._home_cached = home
    return home
|
||||
|
||||
def _try_to_get_home(self):
|
||||
# TODO (Dmitri): Think about how to use the node's HOME variable
|
||||
# without making an extra kubectl exec call.
|
||||
if self._home_cached is None:
|
||||
cmd = self.kubectl + [
|
||||
"exec", "-it", self.node_id, "--", "printenv", "HOME"
|
||||
]
|
||||
joined_cmd = " ".join(cmd)
|
||||
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
|
||||
self._home_cached = raw_out.decode().strip("\n\r")
|
||||
return self._home_cached
|
||||
cmd = self.kubectl + [
|
||||
"exec", "-it", self.node_id, "--", "printenv", "HOME"
|
||||
]
|
||||
joined_cmd = " ".join(cmd)
|
||||
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
|
||||
home = raw_out.decode().strip("\n\r")
|
||||
return home
|
||||
|
||||
|
||||
class SSHOptions:
|
||||
|
||||
@@ -43,6 +43,10 @@ from ray.worker import global_worker # type: ignore
|
||||
from ray.util.debug import log_once
|
||||
|
||||
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
|
||||
from ray.autoscaler._private.load_metrics import LoadMetricsSummary
|
||||
from ray.autoscaler._private.autoscaler import AutoscalerSummary
|
||||
from ray.autoscaler._private.util import format_info_string, \
|
||||
format_info_string_no_node_types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,6 +98,14 @@ def debug_status() -> str:
|
||||
status = "No cluster status."
|
||||
else:
|
||||
status = status.decode("utf-8")
|
||||
as_dict = json.loads(status)
|
||||
lm_summary = LoadMetricsSummary(**as_dict["load_metrics_report"])
|
||||
if "autoscaler_report" in as_dict:
|
||||
autoscaler_summary = AutoscalerSummary(
|
||||
**as_dict["autoscaler_report"])
|
||||
status = format_info_string(lm_summary, autoscaler_summary)
|
||||
else:
|
||||
status = format_info_string_no_node_types(lm_summary)
|
||||
if error:
|
||||
status += "\n"
|
||||
status += error.decode("utf-8")
|
||||
@@ -280,9 +292,10 @@ def _bootstrap_config(config: Dict[str, Any],
|
||||
f"Failed to autodetect node resources: {str(exc)}. "
|
||||
"You can see full stack trace with higher verbosity.")
|
||||
|
||||
# NOTE: if `resources` field is missing, validate_config for non-AWS will
|
||||
# fail (the schema error will ask the user to manually fill the resources)
|
||||
# as we currently support autofilling resources for AWS instances only.
|
||||
# NOTE: if `resources` field is missing, validate_config for providers
|
||||
# other than AWS and Kubernetes will fail (the schema error will ask the
|
||||
# user to manually fill the resources) as we currently support autofilling
|
||||
# resources for AWS and Kubernetes only.
|
||||
validate_config(config)
|
||||
resolved_config = provider_cls.bootstrap_config(config)
|
||||
|
||||
|
||||
@@ -60,6 +60,13 @@ def bootstrap_kubernetes(config):
|
||||
|
||||
|
||||
def fillout_resources_kubernetes(config):
|
||||
"""Fills CPU and GPU resources by reading pod spec of each available node
|
||||
type.
|
||||
|
||||
For each node type and each of CPU/GPU, looks at container's resources
|
||||
and limits, takes min of the two. The result is rounded up, as Ray does
|
||||
not currently support fractional CPU.
|
||||
"""
|
||||
if "available_node_types" not in config:
|
||||
return config["available_node_types"]
|
||||
node_types = copy.deepcopy(config["available_node_types"])
|
||||
@@ -96,20 +103,47 @@ def get_resource(container_resources, resource_name):
|
||||
limit = _get_resource(
|
||||
container_resources, resource_name, field_name="limits")
|
||||
resource = min(request, limit)
|
||||
# float("inf") value means the resource wasn't detected in either
|
||||
# requests or limits
|
||||
return 0 if resource == float("inf") else int(resource)
|
||||
|
||||
|
||||
def _get_resource(container_resources, resource_name, field_name):
|
||||
if (field_name in container_resources
|
||||
and resource_name in container_resources[field_name]):
|
||||
return _parse_resource(container_resources[field_name][resource_name])
|
||||
else:
|
||||
"""Returns the resource quantity.
|
||||
|
||||
The amount of resource is rounded up to nearest integer.
|
||||
Returns float("inf") if the resource is not present.
|
||||
|
||||
Args:
|
||||
container_resources (dict): Container's resource field.
|
||||
resource_name (str): One of 'cpu' or 'gpu'.
|
||||
field_name (str): One of 'requests' or 'limits'.
|
||||
|
||||
Returns:
|
||||
Union[int, float]: Detected resource quantity.
|
||||
"""
|
||||
if field_name not in container_resources:
|
||||
# No limit/resource field.
|
||||
return float("inf")
|
||||
resources = container_resources[field_name]
|
||||
# Look for keys containing the resource_name. For example,
|
||||
# the key 'nvidia.com/gpu' contains the key 'gpu'.
|
||||
matching_keys = [key for key in resources if resource_name in key.lower()]
|
||||
if len(matching_keys) == 0:
|
||||
return float("inf")
|
||||
if len(matching_keys) > 1:
|
||||
# Should have only one match -- mostly relevant for gpu.
|
||||
raise ValueError(f"Multiple {resource_name} types not supported.")
|
||||
# E.g. 'nvidia.com/gpu' or 'cpu'.
|
||||
resource_key = matching_keys.pop()
|
||||
resource_quantity = resources[resource_key]
|
||||
return _parse_resource(resource_quantity)
|
||||
|
||||
|
||||
def _parse_resource(resource):
    """Convert a resource quantity into a whole number, rounding up.

    Args:
        resource: Quantity as an int, float, or string. A trailing ``m``
            marks thousandths (for example ``'500m'``). Plain decimal
            quantities such as ``'0.5'`` or ``1.5`` are also accepted.

    Returns:
        int: The quantity rounded up to the nearest integer.

    Raises:
        ValueError: If the quantity cannot be parsed as a number.
    """
    resource_str = str(resource)
    # endswith() is safe on an empty string, unlike indexing [-1].
    if resource_str.endswith("m"):
        # For example, '500m' rounds up to 1.
        return math.ceil(int(resource_str[:-1]) / 1000)
    try:
        return int(resource_str)
    except ValueError:
        # Fractional quantities such as '0.5' are not valid int literals;
        # round them up, the same way millis are rounded up.
        return math.ceil(float(resource_str))
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS_LEGACY
|
||||
from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
_internal_kv_initialized
|
||||
"""This file provides legacy support for the old info string in order to
|
||||
ensure the dashboard's `api/cluster_status` does not break backwards
|
||||
compatibility.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def legacy_log_info_string(autoscaler, nodes):
    """Build the legacy cluster status string, then store and log it.

    The string is written to the internal KV store under
    DEBUG_AUTOSCALING_STATUS_LEGACY when the KV store is initialized, and is
    always emitted at debug level.

    Args:
        autoscaler: Object exposing `load_metrics`,
            `resource_demand_scheduler` and `pending_launches`.
        nodes: Nodes passed on to the summary helpers.
    """
    # The sections are evaluated in the order they appear in the output.
    sections = [
        info_string(autoscaler, nodes),
        autoscaler.load_metrics.info_string(),
        autoscaler.resource_demand_scheduler.debug_string(
            nodes, autoscaler.pending_launches.breakdown(),
            autoscaler.load_metrics.get_resource_utilization()),
    ]
    msg = "Cluster status: " + "\n".join(sections)
    if _internal_kv_initialized():
        _internal_kv_put(DEBUG_AUTOSCALING_STATUS_LEGACY, msg, overwrite=True)
    logger.debug(msg)
|
||||
|
||||
|
||||
def info_string(autoscaler, nodes):
    """Return a one-line node count summary, e.g. '3 nodes (1 updating)'.

    Args:
        autoscaler: Object exposing sized `updaters` and
            `num_failed_updates` collections.
        nodes: Sized collection of nodes.

    Returns:
        str: The node count followed by optional update/failure notes.
    """
    notes = []
    if autoscaler.updaters:
        notes.append(" ({} updating)".format(len(autoscaler.updaters)))
    if autoscaler.num_failed_updates:
        notes.append(" ({} failed to update)".format(
            len(autoscaler.num_failed_updates)))
    return "{} nodes{}".format(len(nodes), "".join(notes))
|
||||
@@ -1,16 +1,26 @@
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES
|
||||
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES,\
|
||||
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
|
||||
from ray.autoscaler._private.util import add_resources, freq_of_dicts
|
||||
from ray.gcs_utils import PlacementGroupTableData
|
||||
from ray.autoscaler._private.resource_demand_scheduler import \
|
||||
NodeIP, ResourceDict
|
||||
from ray.core.generated.common_pb2 import PlacementStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LoadMetricsSummary = namedtuple("LoadMetricsSummary", [
|
||||
"head_ip", "usage", "resource_demand", "pg_demand", "request_demand",
|
||||
"node_types"
|
||||
])
|
||||
|
||||
|
||||
class LoadMetrics:
|
||||
"""Container for cluster load metrics.
|
||||
@@ -31,6 +41,7 @@ class LoadMetrics:
|
||||
self.waiting_bundles = []
|
||||
self.infeasible_bundles = []
|
||||
self.pending_placement_groups = []
|
||||
self.resource_requests = []
|
||||
|
||||
def update(self,
|
||||
ip: str,
|
||||
@@ -72,34 +83,37 @@ class LoadMetrics:
|
||||
|
||||
def mark_active(self, ip):
|
||||
assert ip is not None, "IP should be known at this time"
|
||||
logger.info("Node {} is newly setup, treating as active".format(ip))
|
||||
logger.debug("Node {} is newly setup, treating as active".format(ip))
|
||||
self.last_heartbeat_time_by_ip[ip] = time.time()
|
||||
|
||||
def is_active(self, ip):
|
||||
return ip in self.last_heartbeat_time_by_ip
|
||||
|
||||
def prune_active_ips(self, active_ips):
|
||||
active_ips = set(active_ips)
|
||||
active_ips.add(self.local_ip)
|
||||
|
||||
def prune(mapping):
|
||||
def prune(mapping, should_log):
|
||||
unwanted = set(mapping) - active_ips
|
||||
for unwanted_key in unwanted:
|
||||
# TODO (Alex): Change this back to info after #12138.
|
||||
logger.debug("LoadMetrics: "
|
||||
"Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
if should_log:
|
||||
logger.info("LoadMetrics: "
|
||||
"Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
del mapping[unwanted_key]
|
||||
if unwanted:
|
||||
if unwanted and should_log:
|
||||
# TODO (Alex): Change this back to info after #12138.
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"LoadMetrics: "
|
||||
"Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
assert not (unwanted & set(mapping))
|
||||
|
||||
prune(self.last_used_time_by_ip)
|
||||
prune(self.static_resources_by_ip)
|
||||
prune(self.dynamic_resources_by_ip)
|
||||
prune(self.resource_load_by_ip)
|
||||
prune(self.last_heartbeat_time_by_ip)
|
||||
prune(self.last_used_time_by_ip, should_log=True)
|
||||
prune(self.static_resources_by_ip, should_log=False)
|
||||
prune(self.dynamic_resources_by_ip, should_log=False)
|
||||
prune(self.resource_load_by_ip, should_log=False)
|
||||
prune(self.last_heartbeat_time_by_ip, should_log=False)
|
||||
|
||||
def get_node_resources(self):
|
||||
"""Return a list of node resources (static resource sizes).
|
||||
@@ -155,12 +169,82 @@ class LoadMetrics:
|
||||
|
||||
return resources_used, resources_total
|
||||
|
||||
def get_resource_demand_vector(self):
|
||||
return self.waiting_bundles + self.infeasible_bundles
|
||||
def get_resource_demand_vector(self, clip=True):
|
||||
if clip:
|
||||
# Bound the total number of bundles to
|
||||
# 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. This guarantees the resource
|
||||
# demand scheduler bin packing algorithm takes a reasonable amount
|
||||
# of time to run.
|
||||
return (
|
||||
self.
|
||||
waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE] +
|
||||
self.
|
||||
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
|
||||
)
|
||||
else:
|
||||
return self.waiting_bundles + self.infeasible_bundles
|
||||
|
||||
def get_resource_requests(self):
|
||||
return self.resource_requests
|
||||
|
||||
def get_pending_placement_groups(self):
|
||||
return self.pending_placement_groups
|
||||
|
||||
def summary(self):
|
||||
available_resources = reduce(add_resources,
|
||||
self.dynamic_resources_by_ip.values()
|
||||
) if self.dynamic_resources_by_ip else {}
|
||||
total_resources = reduce(add_resources,
|
||||
self.static_resources_by_ip.values()
|
||||
) if self.static_resources_by_ip else {}
|
||||
usage_dict = {}
|
||||
for key in total_resources:
|
||||
total = total_resources[key]
|
||||
usage_dict[key] = (total - available_resources[key], total)
|
||||
|
||||
summarized_demand_vector = freq_of_dicts(
|
||||
self.get_resource_demand_vector(clip=False))
|
||||
summarized_resource_requests = freq_of_dicts(
|
||||
self.get_resource_requests())
|
||||
|
||||
def placement_group_serializer(pg):
|
||||
bundles = tuple(
|
||||
frozenset(bundle.unit_resources.items())
|
||||
for bundle in pg.bundles)
|
||||
return (bundles, pg.strategy)
|
||||
|
||||
def placement_group_deserializer(pg_tuple):
|
||||
# We marshal this as a dictionary so that we can easily json.dumps
|
||||
# it later.
|
||||
# TODO (Alex): Would there be a benefit to properly
|
||||
# marshalling this (into a protobuf)?
|
||||
bundles = list(map(dict, pg_tuple[0]))
|
||||
return {
|
||||
"bundles": freq_of_dicts(bundles),
|
||||
"strategy": PlacementStrategy.Name(pg_tuple[1])
|
||||
}
|
||||
|
||||
summarized_placement_groups = freq_of_dicts(
|
||||
self.get_pending_placement_groups(),
|
||||
serializer=placement_group_serializer,
|
||||
deserializer=placement_group_deserializer)
|
||||
nodes_summary = freq_of_dicts(self.static_resources_by_ip.values())
|
||||
|
||||
return LoadMetricsSummary(
|
||||
head_ip=self.local_ip,
|
||||
usage=usage_dict,
|
||||
resource_demand=summarized_demand_vector,
|
||||
pg_demand=summarized_placement_groups,
|
||||
request_demand=summarized_resource_requests,
|
||||
node_types=nodes_summary)
|
||||
|
||||
def set_resource_requests(self, requested_resources):
|
||||
if requested_resources is not None:
|
||||
assert isinstance(requested_resources, list), requested_resources
|
||||
self.resource_requests = [
|
||||
request for request in requested_resources if len(request) > 0
|
||||
]
|
||||
|
||||
def info_string(self):
|
||||
return " - " + "\n - ".join(
|
||||
["{}: {}".format(k, v) for k, v in sorted(self._info().items())])
|
||||
|
||||
@@ -47,16 +47,19 @@ class ResourceDemandScheduler:
|
||||
provider: NodeProvider,
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
max_workers: int,
|
||||
head_node_type: NodeType,
|
||||
upscaling_speed: float = 1) -> None:
|
||||
self.provider = provider
|
||||
self.node_types = copy.deepcopy(node_types)
|
||||
self.max_workers = max_workers
|
||||
self.head_node_type = head_node_type
|
||||
self.upscaling_speed = upscaling_speed
|
||||
|
||||
def reset_config(self,
|
||||
provider: NodeProvider,
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
max_workers: int,
|
||||
head_node_type: NodeType,
|
||||
upscaling_speed: float = 1) -> None:
|
||||
"""Updates the class state variables.
|
||||
|
||||
@@ -89,6 +92,7 @@ class ResourceDemandScheduler:
|
||||
self.provider = provider
|
||||
self.node_types = copy.deepcopy(final_node_types)
|
||||
self.max_workers = max_workers
|
||||
self.head_node_type = head_node_type
|
||||
self.upscaling_speed = upscaling_speed
|
||||
|
||||
def is_legacy_yaml(self,
|
||||
@@ -145,18 +149,18 @@ class ResourceDemandScheduler:
|
||||
node_resources, node_type_counts = self.calculate_node_resources(
|
||||
nodes, launching_nodes, unused_resources_by_ip)
|
||||
|
||||
logger.info("Cluster resources: {}".format(node_resources))
|
||||
logger.info("Node counts: {}".format(node_type_counts))
|
||||
logger.debug("Cluster resources: {}".format(node_resources))
|
||||
logger.debug("Node counts: {}".format(node_type_counts))
|
||||
# Step 2: add nodes to add to satisfy min_workers for each type
|
||||
(node_resources,
|
||||
node_type_counts,
|
||||
adjusted_min_workers) = \
|
||||
_add_min_workers_nodes(
|
||||
node_resources, node_type_counts, self.node_types,
|
||||
self.max_workers, ensure_min_cluster_size)
|
||||
self.max_workers, self.head_node_type, ensure_min_cluster_size)
|
||||
|
||||
# Step 3: add nodes for strict spread groups
|
||||
logger.info(f"Placement group demands: {pending_placement_groups}")
|
||||
logger.debug(f"Placement group demands: {pending_placement_groups}")
|
||||
placement_group_demand_vector, strict_spreads = \
|
||||
placement_groups_to_resource_demands(pending_placement_groups)
|
||||
resource_demands.extend(placement_group_demand_vector)
|
||||
@@ -183,12 +187,13 @@ class ResourceDemandScheduler:
|
||||
# groups
|
||||
unfulfilled, _ = get_bin_pack_residual(node_resources,
|
||||
resource_demands)
|
||||
logger.info("Resource demands: {}".format(resource_demands))
|
||||
logger.info("Unfulfilled demands: {}".format(unfulfilled))
|
||||
logger.debug("Resource demands: {}".format(resource_demands))
|
||||
logger.debug("Unfulfilled demands: {}".format(unfulfilled))
|
||||
# Add 1 to account for the head node.
|
||||
max_to_add = self.max_workers + 1 - sum(node_type_counts.values())
|
||||
nodes_to_add_based_on_demand = get_nodes_for(
|
||||
self.node_types, node_type_counts, max_to_add, unfulfilled)
|
||||
self.node_types, node_type_counts, self.head_node_type, max_to_add,
|
||||
unfulfilled)
|
||||
# Merge nodes to add based on demand and nodes to add based on
|
||||
# min_workers constraint. We add them because nodes to add based on
|
||||
# demand was calculated after the min_workers constraint was respected.
|
||||
@@ -206,7 +211,7 @@ class ResourceDemandScheduler:
|
||||
total_nodes_to_add, unused_resources_by_ip.keys(), nodes,
|
||||
launching_nodes, adjusted_min_workers)
|
||||
|
||||
logger.info("Node requests: {}".format(total_nodes_to_add))
|
||||
logger.debug("Node requests: {}".format(total_nodes_to_add))
|
||||
return total_nodes_to_add
|
||||
|
||||
def _legacy_worker_node_to_launch(
|
||||
@@ -443,6 +448,7 @@ class ResourceDemandScheduler:
|
||||
to_launch = get_nodes_for(
|
||||
self.node_types,
|
||||
node_type_counts,
|
||||
self.head_node_type,
|
||||
max_to_add,
|
||||
unfulfilled,
|
||||
strict_spread=True)
|
||||
@@ -490,7 +496,7 @@ def _add_min_workers_nodes(
|
||||
node_resources: List[ResourceDict],
|
||||
node_type_counts: Dict[NodeType, int],
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int,
|
||||
ensure_min_cluster_size: List[ResourceDict]
|
||||
head_node_type: NodeType, ensure_min_cluster_size: List[ResourceDict]
|
||||
) -> (List[ResourceDict], Dict[NodeType, int], Dict[NodeType, int]):
|
||||
"""Updates resource demands to respect the min_workers and
|
||||
request_resources() constraints.
|
||||
@@ -515,6 +521,9 @@ def _add_min_workers_nodes(
|
||||
existing = node_type_counts.get(node_type, 0)
|
||||
target = min(
|
||||
config.get("min_workers", 0), config.get("max_workers", 0))
|
||||
if node_type == head_node_type:
|
||||
# Add 1 to account for head node.
|
||||
target = target + 1
|
||||
if existing < target:
|
||||
total_nodes_to_add_dict[node_type] = target - existing
|
||||
node_type_counts[node_type] = target
|
||||
@@ -537,7 +546,7 @@ def _add_min_workers_nodes(
|
||||
max_node_resources, ensure_min_cluster_size)
|
||||
# Get the nodes to meet the unfulfilled.
|
||||
nodes_to_add_request_resources = get_nodes_for(
|
||||
node_types, node_type_counts, max_to_add,
|
||||
node_types, node_type_counts, head_node_type, max_to_add,
|
||||
resource_requests_unfulfilled)
|
||||
# Update the resources, counts and total nodes to add.
|
||||
for node_type in nodes_to_add_request_resources:
|
||||
@@ -558,6 +567,7 @@ def _add_min_workers_nodes(
|
||||
|
||||
def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
existing_nodes: Dict[NodeType, int],
|
||||
head_node_type: NodeType,
|
||||
max_to_add: int,
|
||||
resources: List[ResourceDict],
|
||||
strict_spread: bool = False) -> Dict[NodeType, int]:
|
||||
@@ -581,9 +591,13 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
while resources and sum(nodes_to_add.values()) < max_to_add:
|
||||
utilization_scores = []
|
||||
for node_type in node_types:
|
||||
max_workers_of_node_type = node_types[node_type].get(
|
||||
"max_workers", 0)
|
||||
if head_node_type == node_type:
|
||||
# Add 1 to account for head node.
|
||||
max_workers_of_node_type = max_workers_of_node_type + 1
|
||||
if (existing_nodes.get(node_type, 0) + nodes_to_add.get(
|
||||
node_type, 0) >= node_types[node_type].get(
|
||||
"max_workers", 0)):
|
||||
node_type, 0) >= max_workers_of_node_type):
|
||||
continue
|
||||
node_resources = node_types[node_type]["resources"]
|
||||
if strict_spread:
|
||||
@@ -601,8 +615,14 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
# starts up because placement groups are scheduled via custom
|
||||
# resources. This will behave properly with the current utilization
|
||||
# score heuristic, but it's a little dangerous and misleading.
|
||||
logger.info(
|
||||
"No feasible node type to add for {}".format(resources))
|
||||
logger.warning(
|
||||
f"The autoscaler could not find a node type to satisfy the"
|
||||
f"request: {resources}. If this request is related to "
|
||||
f"placement groups the resource request will resolve itself, "
|
||||
f"otherwise please specify a node type with the necessary "
|
||||
f"resource "
|
||||
f"https://docs.ray.io/en/master/cluster/autoscaling.html#multiple-node-type-autoscaling." # noqa: E501
|
||||
)
|
||||
break
|
||||
|
||||
utilization_scores = sorted(utilization_scores, reverse=True)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import collections
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import hashlib
|
||||
import json
|
||||
import jsonschema
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import ray
|
||||
import ray.ray_constants
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler._private.providers import _get_default_config
|
||||
from ray.autoscaler._private.docker import validate_docker_config
|
||||
@@ -20,6 +22,7 @@ RAY_SCHEMA_PATH = os.path.join(
|
||||
# Internal kv keys for storing debug status.
|
||||
DEBUG_AUTOSCALING_ERROR = "__autoscaling_error"
|
||||
DEBUG_AUTOSCALING_STATUS = "__autoscaling_status"
|
||||
DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -246,6 +249,47 @@ def hash_runtime_conf(file_mounts,
|
||||
return (_hash_cache[conf_str], file_mounts_contents_hash)
|
||||
|
||||
|
||||
def add_resources(dict1: Dict[str, float],
                  dict2: Dict[str, float]) -> Dict[str, float]:
    """Sum two resource dictionaries key by key.

    Keys present in only one input keep their value unchanged.

    Returns:
        dict: A new dictionary (inputs remain unmodified).
    """
    combined = dict1.copy()
    combined.update(
        (name, amount + dict1.get(name, 0)) for name, amount in dict2.items())
    return combined
|
||||
|
||||
|
||||
def freq_of_dicts(dicts: List[Dict],
                  serializer=lambda d: frozenset(d.items()),
                  deserializer=dict):
    """Count how often each entry occurs in a list of dictionaries.

    Dictionaries (and other mutable types) are not hashable, so each entry
    is first serialized into a hashable form, counted, and then deserialized
    again for the result.

    Args:
        dicts (List[D]): A list of dictionaries to be counted.
        serializer (D -> S): A custom serialization function. The output
            type S must be hashable. The default serializer converts a
            dictionary into a frozenset of KV pairs.
        deserializer (S -> U): A custom deserialization function. See the
            serializer for information about type S. For dictionaries
            U := D.

    Returns:
        List[Tuple[U, int]]: One tuple per unique entry in `dicts`, pairing
            the entry with its frequency count.
    """
    counts = collections.Counter(serializer(entry) for entry in dicts)
    return [(deserializer(key), count) for key, count in counts.items()]
|
||||
|
||||
|
||||
def add_prefix(info_string, prefix):
|
||||
"""Prefixes each line of info_string, except the first, by prefix."""
|
||||
lines = info_string.split("\n")
|
||||
@@ -255,3 +299,132 @@ def add_prefix(info_string, prefix):
|
||||
prefixed_lines.append(prefixed_line)
|
||||
prefixed_info_string = "\n".join(prefixed_lines)
|
||||
return prefixed_info_string
|
||||
|
||||
|
||||
def format_pg(pg):
    """Format a summarized placement group as a single line.

    Args:
        pg: Mapping with a "strategy" entry and a "bundles" entry, the
            latter being an iterable of (bundle, count) pairs.

    Returns:
        str: "<bundle> * <count>, ... (<strategy>)".
    """
    strategy = pg["strategy"]
    shapes = ", ".join(
        f"{bundle} * {count}" for bundle, count in pg["bundles"])
    return f"{shapes} ({strategy})"
|
||||
|
||||
|
||||
def get_usage_report(lm_summary):
    """Render one "used/total" line per resource in `lm_summary.usage`.

    "memory" and "object_store_memory" are converted to GiB using
    MEMORY_RESOURCE_UNIT_BYTES; every other resource is printed as is.

    Args:
        lm_summary: Object whose `usage` maps a resource name to a
            (used, total) pair.

    Returns:
        str: The newline-joined usage lines (empty if there is no usage).
    """
    memory_resources = ("memory", "object_store_memory")
    report_lines = []
    for resource, (used, total) in lm_summary.usage.items():
        if resource in memory_resources:
            to_GiB = ray.ray_constants.MEMORY_RESOURCE_UNIT_BYTES / 2**30
            used_gib = used * to_GiB
            total_gib = total * to_GiB
            report_lines.append(
                f" {used_gib:.2f}/{total_gib:.3f} GiB {resource}")
        else:
            report_lines.append(f" {used}/{total} {resource}")
    return "\n".join(report_lines)
|
||||
|
||||
|
||||
def get_demand_report(lm_summary):
    """Render the pending resource demands of a load metrics summary.

    Lines are emitted for task/actor demands first, then placement groups,
    then explicit request_resources() demands.

    Args:
        lm_summary: Object exposing `resource_demand`, `pg_demand` and
            `request_demand`, each an iterable of (item, count) pairs.

    Returns:
        str: The newline-joined demand lines, or a placeholder line when
            there are no demands at all.
    """
    demand_lines = [
        f" {bundle}: {count}+ pending tasks/actors"
        for bundle, count in lm_summary.resource_demand
    ]
    demand_lines.extend(
        f" {format_pg(pg)}: {count}+ pending placement groups"
        for pg, count in lm_summary.pg_demand)
    demand_lines.extend(
        f" {bundle}: {count}+ from request_resources()"
        for bundle, count in lm_summary.request_demand)
    if not demand_lines:
        return " (no resource demands)"
    return "\n".join(demand_lines)
|
||||
|
||||
|
||||
def format_info_string(lm_summary, autoscaler_summary, time=None):
    """Render the autoscaler status report as a multi-line string.

    Args:
        lm_summary: Load metrics summary, passed to get_usage_report() and
            get_demand_report().
        autoscaler_summary: Object exposing `active_nodes` and
            `pending_launches` (mappings of node type to count) as well as
            `pending_nodes` and `failed_nodes` (iterables of
            (ip, node_type) pairs).
        time: Timestamp shown in the header. Defaults to datetime.now().

    Returns:
        str: The formatted report.
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Autoscaler status: {time} " + "=" * 8
    separator = "-" * len(header)

    available_node_report_lines = []
    for node_type, count in autoscaler_summary.active_nodes.items():
        line = f" {count} {node_type}"
        available_node_report_lines.append(line)
    available_node_report = "\n".join(available_node_report_lines)

    pending_lines = []
    for node_type, count in autoscaler_summary.pending_launches.items():
        line = f" {node_type}, {count} launching"
        pending_lines.append(line)
    for ip, node_type in autoscaler_summary.pending_nodes:
        line = f" {ip}: {node_type}, setting up"
        pending_lines.append(line)
    if pending_lines:
        pending_report = "\n".join(pending_lines)
    else:
        pending_report = " (no pending nodes)"

    failure_lines = []
    for ip, node_type in autoscaler_summary.failed_nodes:
        line = f" {ip}: {node_type}"
        # Each failed node must be collected here; otherwise the report
        # always claims "(no failures)".
        failure_lines.append(line)
    failure_report = "Recent failures:\n"
    if failure_lines:
        failure_report += "\n".join(failure_lines)
    else:
        failure_report += " (no failures)"

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    formatted_output = f"""{header}
Node status
{separator}
Healthy:
{available_node_report}
Pending:
{pending_report}
{failure_report}

Resources
{separator}

Usage:
{usage_report}

Demands:
{demand_report}"""
    return formatted_output
|
||||
|
||||
|
||||
def format_info_string_no_node_types(lm_summary, time=None):
    """Render a cluster status report that groups nodes by their resources.

    Args:
        lm_summary: Load metrics summary. Its `node_types` is iterated as
            (resources, count) pairs; the summary is also passed to
            get_usage_report() and get_demand_report().
        time: Timestamp shown in the header. Defaults to datetime.now().

    Returns:
        str: The formatted report.
    """
    if time is None:
        time = datetime.now()
    banner = "=" * 8
    header = f"{banner} Cluster status: {time} {banner}"
    separator = "-" * len(header)

    node_report = "\n".join(
        f" {count} node(s) with resources: {node_type}"
        for node_type, count in lm_summary.node_types)

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    return f"""{header}
Node status
{separator}
{node_report}

Resources
{separator}
Usage:
{usage_report}

Demands:
{demand_report}"""
|
||||
|
||||
@@ -250,9 +250,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# Files or directories to copy from the head node to the worker nodes. The format is a
|
||||
# list of paths. The same path on the head node will be copied to the worker node.
|
||||
|
||||
@@ -250,9 +250,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# Files or directories to copy from the head node to the worker nodes. The format is a
|
||||
# list of paths. The same path on the head node will be copied to the worker node.
|
||||
|
||||
@@ -286,9 +286,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# List of commands that will be run before `setup_commands`. If docker is
|
||||
# enabled, these commands will run outside the container and before docker
|
||||
|
||||
@@ -142,7 +142,8 @@ class WorkerCrashedError(RayError):
|
||||
"""Indicates that the worker died unexpectedly while executing a task."""
|
||||
|
||||
def __str__(self):
|
||||
return "The worker died unexpectedly while executing this task."
|
||||
return ("The worker died unexpectedly while executing this task. "
|
||||
"Check python-core-worker-*.log files for more information.")
|
||||
|
||||
|
||||
class RayActorError(RayError):
|
||||
@@ -153,7 +154,8 @@ class RayActorError(RayError):
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
return "The actor died unexpectedly before finishing this task."
|
||||
return ("The actor died unexpectedly before finishing this task. "
|
||||
"Check python-core-worker-*.log files for more information.")
|
||||
|
||||
|
||||
class RaySystemError(RayError):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from .dynamic_resources import set_resource
|
||||
from .object_spilling import force_spill_objects
|
||||
__all__ = [
|
||||
"set_resource",
|
||||
"force_spill_objects",
|
||||
]
|
||||
|
||||
@@ -1,117 +1,100 @@
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional, List, Tuple
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# About these global variables: Ray 1.0 uses exported module functions to
|
||||
# provide its API, and we need to match that. However, we want different
|
||||
# behaviors depending on where, exactly, in the client stack this is running.
|
||||
#
|
||||
# The reason for these differences depends on what's being pickled and passed
|
||||
# to functions, or functions inside functions. So there are three cases to care
|
||||
# about
|
||||
#
|
||||
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
|
||||
#
|
||||
# * _client_api should be set if we're inside the client
|
||||
# * _server_api should be set if we're inside the clientserver
|
||||
# * Both will be set if we're running both (as in a test)
|
||||
# * Neither should be set if we're inside the raylet (but we still need to shim
|
||||
# from the client API surface to the Ray API)
|
||||
#
|
||||
# The job of RayAPIStub (below) delegates to the appropriate one of these
|
||||
# depending on what's set or not. Then, all users importing the ray object
|
||||
# from this package get the stub which routes them to the appropriate APIImpl.
|
||||
_client_api: Optional[APIImpl] = None
|
||||
_server_api: Optional[APIImpl] = None
|
||||
|
||||
# The reason for _is_server is a hack around the above comment while running
|
||||
# tests. If we have both a client and a server trying to control these static
|
||||
# variables then we need a way to decide which to use. In this case, both
|
||||
# _client_api and _server_api are set.
|
||||
# This boolean flips between the two
|
||||
_is_server: bool = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_api_for_tests(in_test: bool):
|
||||
global _is_server
|
||||
is_server = _is_server
|
||||
if in_test:
|
||||
_is_server = True
|
||||
yield _server_api
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
|
||||
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
global _client_api
|
||||
global _is_server
|
||||
if _client_api is not None:
|
||||
raise Exception("Trying to set more than one client API")
|
||||
_client_api = val
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _set_server_api(val: Optional[APIImpl]):
|
||||
global _server_api
|
||||
global _is_server
|
||||
if _server_api is not None:
|
||||
raise Exception("Trying to set more than one server API")
|
||||
_server_api = val
|
||||
_is_server = True
|
||||
|
||||
|
||||
def reset_api():
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
_client_api = None
|
||||
_server_api = None
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
api = None
|
||||
if _is_server:
|
||||
api = _server_api
|
||||
else:
|
||||
api = _client_api
|
||||
if api is None:
|
||||
# We're inside a raylet worker
|
||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||
return CoreRayAPI()
|
||||
return api
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
"""This class stands in as the replacement API for the `import ray` module.
|
||||
|
||||
Much like the ray module, this mostly delegates the work to the
|
||||
_client_worker. As parts of the ray API are covered, they are piped through
|
||||
here or on the client worker API.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
self.api = ClientAPI()
|
||||
self.client_worker = None
|
||||
self._server = None
|
||||
self._connected_with_init = False
|
||||
self._inside_client_test = False
|
||||
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
metadata: List[Tuple[str, str]] = None) -> None:
|
||||
"""Connect the Ray Client to a server.
|
||||
|
||||
Args:
|
||||
conn_str: Connection string, in the form "[host]:port"
|
||||
secure: Whether to use a TLS secured gRPC channel
|
||||
metadata: gRPC metadata to send on connect
|
||||
"""
|
||||
# Delay imports until connect to avoid circular imports.
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
import ray._private.client_mode_hook
|
||||
if self.client_worker is not None:
|
||||
if self._connected_with_init:
|
||||
return
|
||||
raise Exception(
|
||||
"ray.connect() called, but ray client is already connected")
|
||||
if not self._inside_client_test:
|
||||
# If we're calling a client connect specifically and we're not
|
||||
# currently in client mode, ensure we are.
|
||||
ray._private.client_mode_hook._explicitly_enable_client_mode()
|
||||
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
self.api.worker = self.client_worker
|
||||
|
||||
def disconnect(self):
|
||||
global _client_api
|
||||
if _client_api is not None:
|
||||
_client_api.close()
|
||||
_client_api = None
|
||||
"""Disconnect the Ray Client.
|
||||
"""
|
||||
if self.client_worker is not None:
|
||||
self.client_worker.close()
|
||||
self.client_worker = None
|
||||
|
||||
# remote can be called outside of a connection, which is why it
|
||||
# exists on the same API layer as connect() itself.
|
||||
def remote(self, *args, **kwargs):
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
return self.api.remote(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
global _get_client_api
|
||||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
if not self.is_connected():
|
||||
raise Exception("Ray Client is not connected. "
|
||||
"Please connect by calling `ray.connect`.")
|
||||
return getattr(self.api, key)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self.api is not None
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
if self._server is not None:
|
||||
raise Exception("Trying to start two instances of ray via client")
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
self._server, address_info = ray_client_server.init_and_serve(
|
||||
"localhost:50051", *args, **kwargs)
|
||||
self.connect("localhost:50051")
|
||||
self._connected_with_init = True
|
||||
return address_info
|
||||
|
||||
def shutdown(self, _exiting_interpreter=False):
|
||||
self.disconnect()
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
if self._server is None:
|
||||
return
|
||||
ray_client_server.shutdown_with_server(self._server,
|
||||
_exiting_interpreter)
|
||||
self._server = None
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
@@ -1,74 +1,51 @@
|
||||
# This file defines an interface and client-side API stub
|
||||
# for referring either to the core Ray API or the same interface
|
||||
# from the Ray client.
|
||||
#
|
||||
# In tandem with __init__.py, we want to expose an API that's
|
||||
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||
# The stubs in __init__ should call into a well-defined interface.
|
||||
# Only the core Ray API implementation should actually `import ray`
|
||||
# (and thus import all the raylet worker C bindings and such).
|
||||
# But to make sure that we're matching these calls, we define this API.
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Union, Optional
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
"""This file defines the interface between the ray client worker
|
||||
and the overall ray module API.
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
# Use the imports for type checking. This is a python 3.6 limitation.
|
||||
# See https://www.python.org/dev/peps/pep-0563/
|
||||
PutType = Union[ClientObjectRef, ObjectRef]
|
||||
|
||||
|
||||
class APIImpl(ABC):
|
||||
"""
|
||||
APIImpl is the interface to implement for whichever version of the core
|
||||
Ray API that needs abstracting when run in client mode.
|
||||
class ClientAPI:
|
||||
"""The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
"""
|
||||
get is the hook stub passed on to replace `ray.get`
|
||||
def __init__(self, worker=None):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
"""get is the hook stub passed on to replace `ray.get`
|
||||
|
||||
Args:
|
||||
vals: [Client]ObjectRef or list of these refs to retrieve.
|
||||
timeout: Optional timeout in milliseconds
|
||||
"""
|
||||
pass
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
@abstractmethod
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
|
||||
"""
|
||||
put is the hook stub passed on to replace `ray.put`
|
||||
def put(self, *args, **kwargs):
|
||||
"""put is the hook stub passed on to replace `ray.put`
|
||||
|
||||
Args:
|
||||
vals: The value or list of values to `put`.
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
"""
|
||||
wait is the hook stub passed on to replace `ray.wait`
|
||||
"""wait is the hook stub passed on to replace `ray.wait`
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
"""
|
||||
remote is the hook stub passed on to replace `ray.remote`.
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
@@ -77,12 +54,24 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
# Delayed import to avoid a cyclic import
|
||||
from ray.experimental.client.common import remote_decorator
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return remote_decorator(options=None)(args[0])
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
|
||||
"'memory', 'object_store_memory', 'resources', "
|
||||
"'max_calls', or 'max_restarts', like "
|
||||
"'@ray.remote(num_returns=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
return remote_decorator(options=kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
"""
|
||||
call_remote is called by stub objects to execute them remotely.
|
||||
"""call_remote is called by stub objects to execute them remotely.
|
||||
|
||||
This is used by stub objects in situations where they're called
|
||||
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
||||
@@ -95,31 +84,57 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
def call_release(self, id: bytes) -> None:
|
||||
"""Attempts to release an object reference.
|
||||
|
||||
When client references are destructed, they release their reference,
|
||||
which can opportunistically send a notification through the datachannel
|
||||
to release the reference being held for that object on the server.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to release on the server side.
|
||||
"""
|
||||
close cleans up an API connection by closing any channels or
|
||||
return self.worker.call_release(id)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
"""Attempts to retain a client object reference.
|
||||
|
||||
Increments the reference count on the client side, to prevent
|
||||
the client worker from attempting to release the server reference.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to retain on the client side.
|
||||
"""
|
||||
return self.worker.call_retain(id)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close cleans up an API connection by closing any channels or
|
||||
shutting down any servers gracefully.
|
||||
"""
|
||||
pass
|
||||
return self.worker.close()
|
||||
|
||||
@abstractmethod
|
||||
def kill(self, actor, *, no_restart=True):
|
||||
def get_actor(self, name: str) -> "ClientActorHandle":
|
||||
"""Returns a handle to an actor by name.
|
||||
|
||||
Args:
|
||||
name: The name passed to this actor by
|
||||
Actor.options(name="name").remote()
|
||||
"""
|
||||
kill forcibly stops an actor running in the cluster
|
||||
return self.worker.get_actor(name)
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
"""kill forcibly stops an actor running in the cluster
|
||||
|
||||
Args:
|
||||
no_restart: Whether this actor should be restarted if it's a
|
||||
restartable actor.
|
||||
"""
|
||||
pass
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, obj, *, force=False, recursive=True):
|
||||
"""
|
||||
Cancels a task on the cluster.
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
"""Cancels a task on the cluster.
|
||||
|
||||
If the specified task is pending execution, it will not be executed. If
|
||||
the task is currently executing, the behavior depends on the ``force``
|
||||
@@ -136,46 +151,11 @@ class APIImpl(ABC):
|
||||
recursive (boolean): Whether to try to cancel tasks submitted by
|
||||
the task specified.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
return self.worker.terminate_task(obj, force, recursive)
|
||||
|
||||
# Various metadata methods for the client that are defined in the protocol.
|
||||
def is_initialized(self) -> bool:
|
||||
""" True if our client is connected, and if the server is initialized.
|
||||
|
||||
"""True if our client is connected, and if the server is initialized.
|
||||
Returns:
|
||||
A boolean determining if the client is connected and
|
||||
server initialized.
|
||||
@@ -188,6 +168,8 @@ class ClientAPI(APIImpl):
|
||||
Returns:
|
||||
Information about the Ray clients in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.NODES)
|
||||
|
||||
@@ -201,6 +183,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
|
||||
|
||||
@@ -216,6 +200,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
|
||||
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
"""Implements the client side of the client/server pickling protocol.
|
||||
|
||||
All ray client client/server data transfer happens through this pickling
|
||||
protocol. The model is as follows:
|
||||
|
||||
* All Client objects (eg ClientObjectRef) always live on the client and
|
||||
are never represented in the server
|
||||
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
|
||||
never returned to the client
|
||||
* In order to translate between these two references, PickleStub tuples
|
||||
are generated as persistent ids in the data blobs during the pickling
|
||||
and unpickling of these objects.
|
||||
|
||||
The PickleStubs have just enough information to find or generate their
|
||||
associated partner object on either side.
|
||||
|
||||
This also has the advantage of avoiding predefined pickle behavior for ray
|
||||
objects, which may include ray internal reference counting.
|
||||
|
||||
ClientPickler dumps things from the client into the appropriate stubs
|
||||
ServerUnpickler loads stubs from the server into their client counterparts.
|
||||
"""
|
||||
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from ray.experimental.client import RayAPIStub
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import OptionWrapper
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
# NOTE(barakmich): These PickleStubs are really close to
|
||||
# the data for an execution, with no arguments. Combine the two?
|
||||
class PickleStub(NamedTuple):
    """Persistent-id record that stands in for a client or Ray object.

    ClientPickler emits one of these in place of each client-side stub it
    encounters; ServerUnpickler turns the ones it receives back into
    client objects.
    """

    # Tag naming what the stub stands for, e.g. "Ray", "Object", "Actor",
    # "RemoteFunc", "RemoteActor" or "RemoteMethod".
    type: str
    # Id of the client whose pickler produced the stub.
    client_id: str
    # Id of the referenced object/actor/definition; b"" when there is none.
    ref_id: bytes
    # Method name for "RemoteMethod" stubs, otherwise None.
    name: Optional[str]
    # Options attached to a remote function/actor definition, otherwise None.
    baseline_options: Optional[Dict]
|
||||
|
||||
|
||||
class ClientPickler(cloudpickle.CloudPickler):
    """Pickler that replaces client-side Ray stubs with PickleStub records.

    Every client object that must not be serialized by value (the API stub,
    object refs, actor handles, remote functions/classes/methods) is
    reported to pickle as a persistent id, so only a small PickleStub
    describing it ends up in the data blob.

    Args:
        client_id: id of the client, stamped on every stub produced.
        *args, **kwargs: forwarded unchanged to ``cloudpickle.CloudPickler``.
    """

    def __init__(self, client_id, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client_id = client_id

    def _make_pickle_stub(self,
                          kind,
                          ref_id=b"",
                          name=None,
                          baseline_options=None):
        """Build a PickleStub of type `kind` carrying this client's id."""
        return PickleStub(
            type=kind,
            client_id=self.client_id,
            ref_id=ref_id,
            name=name,
            baseline_options=baseline_options,
        )

    def _definition_pickle_stub(self, obj, kind, self_reference_kind):
        """Build the stub for a remote function or actor class definition.

        Makes sure the definition has a reference first. If the reference
        is still the in-progress sentinel, the definition is being pickled
        from inside its own upload (it refers to itself), so a
        self-reference stub without a ref id is produced instead.
        """
        # TODO(barakmich): This is going to have trouble with mutually
        # recursive functions that haven't, as yet, been executed. It's
        # relatively doable (keep track of intermediate refs in progress
        # with ensure_ref and return appropriately) But punting for now.
        if obj._ref is None:
            obj._ensure_ref()
        if isinstance(obj._ref, SelfReferenceSentinel):
            return self._make_pickle_stub(self_reference_kind)
        return self._make_pickle_stub(
            kind, ref_id=obj._ref.id, baseline_options=obj._options)

    def persistent_id(self, obj):
        """Return the PickleStub standing in for `obj`, if it needs one.

        Returns:
            A PickleStub for recognized client objects, or None so that
            pickle serializes `obj` in the ordinary way.

        Raises:
            NotImplementedError: if `obj` is an OptionWrapper.
        """
        if isinstance(obj, RayAPIStub):
            return self._make_pickle_stub("Ray")
        if isinstance(obj, ClientObjectRef):
            return self._make_pickle_stub("Object", ref_id=obj.id)
        if isinstance(obj, ClientActorHandle):
            return self._make_pickle_stub("Actor", ref_id=obj._actor_id)
        if isinstance(obj, ClientRemoteFunc):
            return self._definition_pickle_stub(obj, "RemoteFunc",
                                                "RemoteFuncSelfReference")
        if isinstance(obj, ClientActorClass):
            return self._definition_pickle_stub(obj, "RemoteActor",
                                                "RemoteActorSelfReference")
        if isinstance(obj, ClientRemoteMethod):
            return self._make_pickle_stub(
                "RemoteMethod",
                ref_id=obj.actor_handle.actor_ref.id,
                name=obj.method_name,
            )
        if isinstance(obj, OptionWrapper):
            raise NotImplementedError(
                "Sending a partial option is unimplemented")
        return None
|
||||
|
||||
|
||||
class ServerUnpickler(pickle.Unpickler):
    """Unpickler that rebuilds client objects from stubs sent by the server."""

    def persistent_load(self, pid):
        """Turn a PickleStub persistent id into its client-side object.

        Returns:
            A ClientObjectRef for "Object" stubs, or a ClientActorHandle
            wrapping a ClientActorRef for "Actor" stubs.

        Raises:
            NotImplementedError: for any other stub type.
        """
        assert isinstance(pid, PickleStub)
        stub_type = pid.type
        if stub_type == "Actor":
            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
        if stub_type == "Object":
            return ClientObjectRef(id=pid.ref_id)
        raise NotImplementedError("Being passed back an unknown stub")
|
||||
|
||||
|
||||
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
    """Serialize `obj` with ClientPickler and return the resulting bytes.

    The client hook is disabled for the duration of the dump.

    Args:
        obj: the value to serialize.
        client_id: id handed to the ClientPickler.
        protocol: pickle protocol forwarded to the pickler.
    """
    with disable_client_hook(), io.BytesIO() as sink:
        ClientPickler(client_id, sink, protocol=protocol).dump(obj)
        return sink.getvalue()
|
||||
|
||||
|
||||
def loads_from_server(data: bytes,
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Deserialize `data` with ServerUnpickler.

    Args:
        data: the pickled payload; must not be a `str`.
        fix_imports, encoding, errors: forwarded to the unpickler.

    Raises:
        TypeError: if `data` is a `str`.
    """
    if isinstance(data, str):
        raise TypeError("Can't load pickle from unicode string")
    unpickler = ServerUnpickler(
        io.BytesIO(data),
        fix_imports=fix_imports,
        encoding=encoding,
        errors=errors,
    )
    return unpickler.load()
|
||||
|
||||
|
||||
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
    """Wrap `val` in an Arg message carrying its client-pickled bytes.

    The message's locality is always set to INTERNED and its data field
    holds the output of `dumps_from_client(val, client_id)`.
    """
    payload = dumps_from_client(val, client_id)
    return ray_client_pb2.Arg(
        local=ray_client_pb2.Arg.Locality.INTERNED,
        data=payload,
    )
|
||||
@@ -1,16 +1,31 @@
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from ray import cloudpickle
|
||||
from ray.experimental.client.options import validate_options
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
from ray.util.inspect import is_cython
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id, handle=None):
|
||||
self.id = id
|
||||
self.handle = handle
|
||||
def __init__(self, id: bytes):
|
||||
self.id = None
|
||||
if not isinstance(id, bytes):
|
||||
raise TypeError("ClientRefs must be created with bytes IDs")
|
||||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
@@ -18,20 +33,16 @@ class ClientBaseRef:
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
@classmethod
|
||||
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
|
||||
return cls(id=ref.id, handle=ref.handle)
|
||||
def __del__(self):
|
||||
if ray.is_connected() and self.id is not None:
|
||||
ray.call_release(self.id)
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
def _unpack_ref(self):
|
||||
return cloudpickle.loads(self.handle)
|
||||
pass
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
@@ -43,8 +54,7 @@ class ClientStub:
|
||||
|
||||
|
||||
class ClientRemoteFunc(ClientStub):
|
||||
"""
|
||||
A stub created on the Ray Client to represent a remote
|
||||
"""A stub created on the Ray Client to represent a remote
|
||||
function that can be executed on the cluster.
|
||||
|
||||
This class is allowed to be passed around between remote functions.
|
||||
@@ -53,55 +63,57 @@ class ClientRemoteFunc(ClientStub):
|
||||
_func: The actual function to execute remotely
|
||||
_name: The original name of the function
|
||||
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
|
||||
for this object
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
def __init__(self, f, options=None):
|
||||
self._lock = threading.Lock()
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
|
||||
# self._ref can be lazily instantiated. Rather than eagerly creating
|
||||
# function data objects in the server we can put them just before we
|
||||
# execute the function, especially in cases where many @ray.remote
|
||||
# functions exist in a library and only a handful are ever executed by
|
||||
# a user of the library.
|
||||
#
|
||||
# TODO(barakmich): This ref might actually be better as a serialized
|
||||
# ObjectRef. This requires being able to serialize the ref without
|
||||
# pinning it (as the lifetime of the ref is tied with the server, not
|
||||
# the client)
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject a direct call of the remote function stub.

    Raises:
        TypeError: always; the message names the stub and points the
            caller at its `.remote` method.
    """
    # The f prefix belongs on the fragment that holds the placeholder;
    # on a plain literal the braces would be emitted verbatim.
    raise TypeError("Remote function cannot be called directly. "
                    f"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._raylet_remote is None:
|
||||
self._raylet_remote = ray.remote(self._func)
|
||||
return self._raylet_remote
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def _ensure_ref(self):
|
||||
with self._lock:
|
||||
if self._ref is None:
|
||||
# While calling ray.put() on our function, if
|
||||
# our function is recursive, it will attempt to
|
||||
# encode the ClientRemoteFunc -- itself -- and
|
||||
# infinitely recurse on _ensure_ref.
|
||||
#
|
||||
# So we set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self._func)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self._func)
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
class ClientActorClass(ClientStub):
|
||||
""" A stub created on the Ray Client to represent an actor class.
|
||||
"""A stub created on the Ray Client to represent an actor class.
|
||||
|
||||
It is wrapped by ray.remote and can be executed on the cluster.
|
||||
|
||||
@@ -109,39 +121,40 @@ class ClientActorClass(ClientStub):
|
||||
actor_cls: The actual class to execute remotely
|
||||
_name: The original name of the class
|
||||
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||
_raylet_remote: The Raylet-side ray.ActorClass for this object
|
||||
"""
|
||||
|
||||
def __init__(self, actor_cls):
|
||||
def __init__(self, actor_cls, options=None):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct instantiation of the remote actor class stub.

    Raises:
        TypeError: always; the message names the stub and points the
            caller at its `.remote()` method.
    """
    # The f prefix belongs on the fragment that holds the placeholder;
    # on a plain literal the braces would be emitted verbatim.
    raise TypeError("Remote actor cannot be instantiated directly. "
                    f"Use {self._name}.remote() instead")
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_cls": self.actor_cls,
|
||||
"_name": self._name,
|
||||
"_ref": self._ref,
|
||||
}
|
||||
return state
|
||||
def _ensure_ref(self):
|
||||
if self._ref is None:
|
||||
# As before, set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_cls = state["actor_cls"]
|
||||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]))
|
||||
|
||||
def options(self, **kwargs):
|
||||
return ActorOptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self.__dict__:
|
||||
@@ -149,12 +162,12 @@ class ClientActorClass(ClientStub):
|
||||
raise NotImplementedError("static methods")
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
@@ -174,29 +187,12 @@ class ClientActorHandle(ClientStub):
|
||||
ray.actor.ActorHandle contained in the actor_id ref.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_ref: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
def __init__(self, actor_ref: ClientActorRef):
|
||||
self.actor_ref = actor_ref
|
||||
self.actor_class = actor_class
|
||||
self._real_actor_handle = None
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._real_actor_handle is None:
|
||||
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
|
||||
return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_ref": self.actor_ref,
|
||||
"actor_class": self.actor_class,
|
||||
"_real_actor_handle": self._real_actor_handle,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_ref = state["actor_ref"]
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
def __del__(self) -> None:
|
||||
if ray.is_connected():
|
||||
ray.call_release(self.actor_ref.id)
|
||||
|
||||
@property
|
||||
def _actor_id(self):
|
||||
@@ -226,65 +222,90 @@ class ClientRemoteMethod(ClientStub):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote method cannot be called directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||
self.method_name)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_handle": self.actor_handle,
|
||||
"method_name": self.method_name,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_handle = state["actor_handle"]
|
||||
self.method_name = state["method_name"]
|
||||
f"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __repr__(self):
|
||||
name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
return "ClientRemoteMethod(%s, %s)" % (name,
|
||||
self.actor_handle.actor_id)
|
||||
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
|
||||
self.actor_handle)
|
||||
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_ref.handle
|
||||
task.payload_id = self.actor_handle.actor_ref.id
|
||||
return task
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return ClientObjectRef(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
class OptionWrapper:
|
||||
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
|
||||
self.remote_stub = stub
|
||||
self.options = validate_options(options)
|
||||
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
def remote(self, *args, **kwargs):
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.remote_stub, key)
|
||||
|
||||
def _prepare_client_task(self):
|
||||
task = self.remote_stub._prepare_client_task()
|
||||
set_task_options(task, self.options)
|
||||
return task
|
||||
|
||||
|
||||
def convert_to_arg(val):
|
||||
out = ray_client_pb2.Arg()
|
||||
if isinstance(val, ClientObjectRef):
|
||||
out.local = ray_client_pb2.Arg.Locality.REFERENCE
|
||||
out.reference_id = val.id
|
||||
else:
|
||||
out.local = ray_client_pb2.Arg.Locality.INTERNED
|
||||
out.data = cloudpickle.dumps(val)
|
||||
return out
|
||||
class ActorOptionWrapper(OptionWrapper):
|
||||
def remote(self, *args, **kwargs):
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]))
|
||||
|
||||
|
||||
def encode_exception(exception) -> str:
|
||||
data = cloudpickle.dumps(exception)
|
||||
return base64.standard_b64encode(data).decode()
|
||||
def set_task_options(task: ray_client_pb2.ClientTask,
|
||||
options: Optional[Dict[str, Any]],
|
||||
field: str = "options") -> None:
|
||||
if options is None:
|
||||
task.ClearField(field)
|
||||
return
|
||||
options_str = json.dumps(options)
|
||||
getattr(task, field).json_options = options_str
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
|
||||
data = base64.standard_b64decode(data)
|
||||
return cloudpickle.loads(data)
|
||||
def return_refs(ids: List[bytes]
|
||||
) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
|
||||
if len(ids) == 1:
|
||||
return ClientObjectRef(ids[0])
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
return [ClientObjectRef(id) for id in ids]
|
||||
|
||||
|
||||
class DataEncodingSentinel:
|
||||
def __repr__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
|
||||
class SelfReferenceSentinel(DataEncodingSentinel):
|
||||
pass
|
||||
|
||||
|
||||
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Build a decorator that wraps a function or class as a client stub.

    Args:
        options: passed as ``options=`` to the stub that gets created.

    Returns:
        A decorator producing a ``ClientRemoteFunc`` for plain or cython
        functions and a ``ClientActorClass`` for classes. It raises
        ``TypeError`` for anything else.
    """

    def decorator(function_or_class) -> ClientStub:
        function_like = (inspect.isfunction(function_or_class)
                         or is_cython(function_or_class))
        if function_like:
            return ClientRemoteFunc(function_or_class, options=options)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    return decorator
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""This file implements a threaded stream controller to abstract a data stream
|
||||
back to the ray clientserver.
|
||||
"""
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)

# The maximum field value for request_id -- which is also the maximum
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1


class DataClient:
    """Client end of the bidirectional Ray Client data stream.

    Requests are pushed onto ``request_queue`` and sent by a background
    thread; responses are matched back to their callers by ``req_id``
    through ``ready_data`` guarded by the condition variable ``cv``.
    """

    def __init__(self, channel: "grpc._channel.Channel", client_id: str):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
            client_id: the generated ID representing this client
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.data_thread = self._start_datathread()
        self.ready_data: Dict[int, Any] = {}
        self.cv = threading.Condition()
        self._req_id = 0
        # Serializes request-id allocation so concurrent callers can never
        # be handed the same id.
        self._id_lock = threading.Lock()
        self._client_id = client_id
        # Set (under self.cv) once the response stream has ended, so that
        # blocked senders wake up instead of waiting forever.
        self._stream_closed = False
        # Started last so every attribute the thread reads already exists.
        self.data_thread.start()

    def _next_id(self) -> int:
        """Return the next request id, wrapping from INT32_MAX back to 1."""
        with self._id_lock:
            self._req_id += 1
            if self._req_id > INT32_MAX:
                self._req_id = 1
            # Responses that aren't tracked (like opportunistic releases)
            # have req_id=0, so make sure we never mint such an id.
            assert self._req_id != 0
            return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        return threading.Thread(target=self._data_main, args=(), daemon=True)

    def _data_main(self) -> None:
        """Thread body: send queued requests and publish their responses.

        Runs until the response stream ends. A CANCELLED RpcError is
        treated as a graceful shutdown; any other RpcError is logged and
        re-raised. In every case waiters are notified that the stream is
        closed.
        """
        try:
            stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
            resp_stream = stub.Datapath(
                iter(self.request_queue.get, None),
                metadata=(("client_id", self._client_id), ))
            for response in resp_stream:
                if response.req_id == 0:
                    # This is not being waited for.
                    logger.debug(f"Got unawaited response {response}")
                    continue
                with self.cv:
                    self.ready_data[response.req_id] = response
                    self.cv.notify_all()
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Gracefully shutting down
                logger.info("Cancelling data channel")
            else:
                logger.error(
                    f"Got Error from data channel -- shutting down: {e}")
                raise e
        finally:
            # No further responses can arrive; release anyone still
            # blocked in _blocking_send.
            with self.cv:
                self._stream_closed = True
                self.cv.notify_all()

    def close(self) -> None:
        """Signal end-of-requests and wait for the data thread to exit."""
        if self.request_queue is not None:
            # None is the sentinel that terminates the request iterator.
            self.request_queue.put(None)
        if self.data_thread is not None:
            self.data_thread.join()

    def _blocking_send(self, req: ray_client_pb2.DataRequest
                       ) -> ray_client_pb2.DataResponse:
        """Send ``req`` and block until its response arrives.

        Raises:
            ConnectionError: if the data stream ends before a response
                carrying this request's id is received.
        """
        req_id = self._next_id()
        req.req_id = req_id
        self.request_queue.put(req)
        with self.cv:
            self.cv.wait_for(
                lambda: req_id in self.ready_data or self._stream_closed)
            if req_id not in self.ready_data:
                raise ConnectionError(
                    "Ray Client data channel closed before a response to "
                    f"request {req_id} was received")
            return self.ready_data.pop(req_id)

    def GetObject(self, request: ray_client_pb2.GetRequest,
                  context=None) -> ray_client_pb2.GetResponse:
        """Send a get request over the stream and return its ``get`` reply."""
        datareq = ray_client_pb2.DataRequest(get=request, )
        resp = self._blocking_send(datareq)
        return resp.get

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        """Send a put request over the stream and return its ``put`` reply."""
        datareq = ray_client_pb2.DataRequest(put=request, )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(self,
                      request: ray_client_pb2.ReleaseRequest,
                      context=None) -> None:
        """Queue a release request without waiting for any response."""
        datareq = ray_client_pb2.DataRequest(release=request, )
        self.request_queue.put(datareq)
|
||||
@@ -0,0 +1,7 @@
|
||||
from ray.experimental.client import ray
|
||||
|
||||
from ray.tune import tune
|
||||
|
||||
# Connect the client-mode `ray` module to a server at localhost:50051.
ray.connect("localhost:50051")

# Start a Tune run named "PG" with the CartPole-v0 environment config.
tune.run("PG", config={"env": "CartPole-v0"})
|
||||
@@ -0,0 +1,86 @@
|
||||
"""This file implements a threaded stream controller to return logs back from
|
||||
the ray clientserver.
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# TODO(barakmich): Running a logger in a logger causes loopback.
# The client logger needs its own root -- possibly this one.
# For the moment, let's just not propagate beyond this point.
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
class LogstreamClient:
    """Receives log records from the Ray Client server on a daemon thread.

    Settings requests are pushed onto ``request_queue``; records coming
    back on the stream are dispatched to ``stdstream`` and ``log``.
    """

    def __init__(self, channel: "grpc._channel.Channel"):
        """Initializes a thread-safe log stream over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.log_thread = self._start_logthread()
        self.log_thread.start()

    def _start_logthread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        worker = threading.Thread(
            target=self._log_main, args=(), daemon=True)
        return worker

    def _log_main(self) -> None:
        """Thread body: open the log stream and dispatch every record.

        A CANCELLED RpcError ends the thread quietly; any other RpcError
        is logged and re-raised.
        """
        stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
        outgoing = iter(self.request_queue.get, None)
        log_stream = stub.Logstream(outgoing)
        try:
            for record in log_stream:
                # Negative levels mark stdout/stderr entries; they are
                # printed and then also handed to log(), as for any record.
                if record.level < 0:
                    self.stdstream(level=record.level, msg=record.msg)
                self.log(level=record.level, msg=record.msg)
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Just shutting down normally
                return
            logger.error(
                f"Got Error from logger channel -- shutting down: {e}")
            raise e

    def log(self, level: int, msg: str):
        """Log the message from the log stream.
        By default, calls logger.log but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        logger.log(level=level, msg=msg)

    def stdstream(self, level: int, msg: str):
        """Log the stdout/stderr entry from the log stream.
        By default, calls print but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        # Level -2 goes to stderr; every other level goes to stdout.
        if level == -2:
            target = sys.stderr
        else:
            target = sys.stdout
        print(msg, file=target)

    def set_logstream_level(self, level: int):
        """Set the local logger level and request logs at ``level``."""
        logger.setLevel(level)
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = True
        settings.loglevel = level
        self.request_queue.put(settings)

    def close(self) -> None:
        """Signal end-of-requests and wait for the log thread to exit."""
        # None is the sentinel that terminates the request iterator.
        self.request_queue.put(None)
        if self.log_thread is not None:
            self.log_thread.join()

    def disable_logs(self) -> None:
        """Queue a settings request with ``enabled`` set to False."""
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = False
        self.request_queue.put(settings)
|
||||
@@ -0,0 +1,54 @@
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
# Accepted option names for remote(). An empty tuple means "accepted
# without validation"; otherwise the tuple holds
# (required type, predicate on the value, error message).
options = {
    "num_returns": (int, lambda x: x >= 0,
                    "The keyword 'num_returns' only accepts 0 "
                    "or a positive integer"),
    "num_cpus": (),
    "num_gpus": (),
    "resources": (),
    "accelerator_type": (),
    "max_calls": (int, lambda x: x >= 0,
                  "The keyword 'max_calls' only accepts 0 "
                  "or a positive integer"),
    "max_restarts": (int, lambda x: x >= -1,
                     "The keyword 'max_restarts' only accepts -1, 0 "
                     "or a positive integer"),
    "max_task_retries": (int, lambda x: x >= -1,
                         "The keyword 'max_task_retries' only accepts -1, 0 "
                         "or a positive integer"),
    "max_retries": (int, lambda x: x >= -1,
                    "The keyword 'max_retries' only accepts 0, -1 "
                    "or a positive integer"),
    "max_concurrency": (),
    "name": (),
    "lifetime": (),
    "memory": (),
    "object_store_memory": (),
    "placement_group": (),
    "placement_group_bundle_index": (),
    "placement_group_capture_child_tasks": (),
    "override_environment_variables": (),
}


def validate_options(
        kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Check remote() options against the ``options`` table.

    Args:
        kwargs_dict: option names mapped to values, or ``None``.

    Returns:
        ``None`` when ``kwargs_dict`` is ``None`` or empty, otherwise a new
        dict holding the same items.

    Raises:
        TypeError: if an option name is not in ``options``.
        ValueError: if a value has the wrong type or fails its predicate.
    """
    if kwargs_dict is None or len(kwargs_dict) == 0:
        return None
    validated = {}
    for key, value in kwargs_dict.items():
        spec = options.get(key)
        if spec is None:
            raise TypeError(f"Invalid option passed to remote(): {key}")
        if spec:
            expected_type, predicate, message = spec
            # The type check runs first so the predicate only ever sees
            # values of the expected type.
            if not (isinstance(value, expected_type) and predicate(value)):
                raise ValueError(message)
        validated[key] = value
    return validated
|
||||
@@ -0,0 +1,17 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray
|
||||
|
||||
|
||||
@contextmanager
def ray_start_client_server():
    """Run a client server on localhost:50051 and yield a connected ``ray``.

    Sets ``ray._inside_client_test`` for the duration of the block. On
    exit the flag is reset, the client is disconnected and the server is
    stopped. This cleanup also runs when starting the server or
    connecting fails, so a failed setup does not leave a running server
    or a stuck flag behind.
    """
    ray._inside_client_test = True
    server = None
    connected = False
    try:
        server = ray_client_server.serve("localhost:50051")
        ray.connect("localhost:50051")
        connected = True
        yield ray
    finally:
        ray._inside_client_test = False
        try:
            # Only undo the steps that actually completed.
            if connected:
                ray.disconnect()
        finally:
            if server is not None:
                server.stop(0)
|
||||
@@ -1,101 +0,0 @@
|
||||
# Along with `api.py` this is the stub that interfaces with
|
||||
# the real (C-binding, raylet) ray core.
|
||||
#
|
||||
# Ideally, the first import line is the only time we actually
|
||||
# import ray in this library (excluding the main function for the server)
|
||||
#
|
||||
# While the stub is trivial, it allows us to check that the calls we're
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
    """
    Implements the equivalent client-side Ray API by simply passing along to
    the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
    to core ray when passed client stubs.
    """

    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Call ``ray.get``, unpacking ``ClientObjectRef`` inputs first.

        A list is unpacked element-wise when its first element is a
        ``ClientObjectRef``; a single ``ClientObjectRef`` is unpacked
        directly; anything else (including an empty list) is passed to
        ``ray.get`` unchanged.
        """
        if isinstance(vals, list):
            # Guard against an empty list: there is no first element to
            # inspect, so fall through to the plain ray.get call below.
            if vals and isinstance(vals[0], ClientObjectRef):
                return ray.get(
                    [val._unpack_ref() for val in vals], timeout=timeout)
        elif isinstance(vals, ClientObjectRef):
            return ray.get(vals._unpack_ref(), timeout=timeout)
        return ray.get(vals, timeout=timeout)

    def put(self, vals: Any, *args,
            **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
        """Forward to ``ray.put``."""
        return ray.put(vals, *args, **kwargs)

    def wait(self, *args, **kwargs):
        """Forward to ``ray.wait``."""
        return ray.wait(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Forward to ``ray.remote``."""
        return ray.remote(*args, **kwargs)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Invoke ``.remote()`` on the stub's core-ray implementation."""
        return instance._get_ray_remote_impl().remote(*args, **kwargs)

    def close(self) -> None:
        """Do nothing; always returns ``None``."""
        return None

    def kill(self, actor, *, no_restart=True):
        """Forward to ``ray.kill``."""
        return ray.kill(actor, no_restart=no_restart)

    def cancel(self, obj, *, force=False, recursive=True):
        """Forward to ``ray.cancel``."""
        return ray.cancel(obj, force=force, recursive=recursive)

    def is_initialized(self) -> bool:
        """Forward to ``ray.is_initialized``."""
        return ray.is_initialized()

    # Allow for generic fallback to ray.* in remote methods. This allows calls
    # like ray.nodes() to be run in remote functions even though the client
    # doesn't currently support them.
    def __getattr__(self, key: str):
        return getattr(ray, key)
|
||||
|
||||
|
||||
class RayServerAPI(CoreRayAPI):
    """
    Server-side shim of the Ray Client API. It inherits the core-ray
    pass-through behaviour and routes ``put`` and ``call_remote`` through
    the server instance, so functions running inside other remote functions
    can create more work.
    """

    def __init__(self, server_instance):
        self.server = server_instance

    def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
        """Store ``vals`` on the server.

        A list is stored element by element and a list of refs is returned;
        any other value is stored as a single object and one ref is returned.
        """
        if isinstance(vals, list):
            return [self._put(item) for item in vals]
        return self._put(vals)

    def _put(self, val: Any):
        """Store one value on the server and wrap the response's ``id``."""
        response = self.server._put_and_retain_obj(val)
        return ClientObjectRef(response.id)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Schedule the stub's prepared task on the server.

        Positional arguments are handed over as ``prepared_args``; the
        ticket's ``return_id`` is wrapped in a ``ClientObjectRef``.
        """
        prepared_task = instance._prepare_client_task()
        ticket = self.server.Schedule(prepared_task, prepared_args=args)
        return ClientObjectRef(ticket.return_id)
|
||||
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import grpc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """Serves the streaming data path by delegating to a RayletServicer."""

    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service

    def Datapath(self, request_iterator, context):
        """Answer get/put/release requests arriving on one client stream.

        The client is identified by the ``client_id`` invocation metadata.
        A stream with a missing or empty ``client_id`` is rejected with an
        error log and produces no responses. Each response carries the
        ``req_id`` of the request it answers. When the stream ends for any
        reason, everything held for the client is released.
        """
        metadata = dict(context.invocation_metadata())
        # A missing key is treated like an empty id so the rejection below
        # applies instead of a KeyError.
        client_id = metadata.get("client_id", "")
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = []
                    for rel_id in req.release.ids:
                        rel = self.basic_service.release(client_id, rel_id)
                        released.append(rel)
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""This file responds to log stream requests and forwards logs
|
||||
with its handler.
|
||||
"""
|
||||
import io
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
import grpc
|
||||
import uuid
|
||||
|
||||
from ray.worker import print_worker_logs
|
||||
from ray.ray_logging import global_worker_stdstream_dispatcher
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogstreamHandler(logging.Handler):
    """Logging handler that forwards records to a queue as LogData."""

    def __init__(self, queue, level):
        """Remember the target queue and set this handler's level."""
        super().__init__()
        self.queue = queue
        self.level = level

    def emit(self, record: logging.LogRecord):
        """Copy message, level and logger name into a LogData and queue it."""
        outgoing = ray_client_pb2.LogData()
        outgoing.msg = record.getMessage()
        outgoing.level = record.levelno
        outgoing.name = record.name
        self.queue.put(outgoing)
|
||||
|
||||
class StdStreamHandler:
    """Forwards worker stdout/stderr data to a queue as LogData messages."""

    def __init__(self, queue):
        """Remember the target queue and generate a unique handler id."""
        self.queue = queue
        self.id = str(uuid.uuid4())

    def handle(self, data):
        """Render ``data`` with ``print_worker_logs`` and queue the text.

        Entries with a truthy ``data["is_err"]`` are tagged level -2 and
        named "stderr"; all others are level -1 and named "stdout".
        """
        outgoing = ray_client_pb2.LogData()
        if data["is_err"]:
            outgoing.level = -2
            outgoing.name = "stderr"
        else:
            outgoing.level = -1
            outgoing.name = "stdout"
        with io.StringIO() as buffer:
            print_worker_logs(data, buffer)
            outgoing.msg = buffer.getvalue()
        self.queue.put(outgoing)

    def register_global(self):
        """Add ``handle`` to the global stdstream dispatcher under our id."""
        global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)

    def unregister_global(self):
        """Remove this handler's id from the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.remove_handler(self.id)
|
||||
|
||||
|
||||
def log_status_change_thread(log_queue, request_iterator):
    """Apply log-settings requests until the request stream ends.

    Each request first detaches the handler installed by the previous one.
    If the request enables logs, a new ``LogstreamHandler`` at the requested
    level is attached to the "ray" logger and the std-stream handler is
    registered. On exit any active handler is detached and ``None`` is put
    on ``log_queue``.
    """
    std_handler = StdStreamHandler(log_queue)
    root_logger = logging.getLogger("ray")
    default_level = root_logger.getEffectiveLevel()
    current_handler = None

    def detach(handler):
        # Restore the level seen at start-up and unhook both handlers.
        root_logger.setLevel(default_level)
        root_logger.removeHandler(handler)
        std_handler.unregister_global()

    try:
        for req in request_iterator:
            if current_handler is not None:
                detach(current_handler)
            if req.enabled:
                current_handler = LogstreamHandler(log_queue, req.loglevel)
                std_handler.register_global()
                root_logger.addHandler(current_handler)
                root_logger.setLevel(req.loglevel)
            else:
                current_handler = None
    except grpc.RpcError as e:
        logger.debug(f"closing log thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        if current_handler is not None:
            detach(current_handler)
        log_queue.put(None)
|
||||
|
||||
|
||||
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """Streams queued log records back to a connected client."""

    def Logstream(self, request_iterator, context):
        """Yield log records until the settings thread signals the end.

        A daemon thread running ``log_status_change_thread`` consumes
        ``request_iterator`` and fills a queue; records are yielded from
        that queue until a ``None`` sentinel arrives. The thread is joined
        before returning.
        """
        logger.info("New logs connection")
        log_queue = queue.Queue()
        worker = threading.Thread(
            target=log_status_change_thread,
            args=(log_queue, request_iterator),
            daemon=True)
        worker.start()
        try:
            # iter() with a sentinel stops as soon as None is dequeued.
            yield from iter(log_queue.get, None)
        except grpc.RpcError as e:
            logger.debug(f"Closing log channel: {e}")
        finally:
            worker.join()
|
||||
@@ -1,6 +1,14 @@
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import base64
|
||||
from collections import defaultdict
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
from typing import Optional
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
import ray.state
|
||||
@@ -9,29 +17,34 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import inspect
|
||||
import json
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import encode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.server_pickler import convert_from_arg
|
||||
from ray.experimental.client.server.server_pickler import dumps_from_server
|
||||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.logservicer import LogstreamServicer
|
||||
from ray.experimental.client.server.server_stubs import current_remote
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self, test_mode=False):
|
||||
self.object_refs = {}
|
||||
def __init__(self):
|
||||
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
|
||||
dict)
|
||||
self.function_refs = {}
|
||||
self.actor_refs = {}
|
||||
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
|
||||
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
self._current_function_stub = None
|
||||
|
||||
def ClusterInfo(self, request,
|
||||
context=None) -> ray_client_pb2.ClusterInfoResponse:
|
||||
resp = ray_client_pb2.ClusterInfoResponse()
|
||||
resp.type = request.type
|
||||
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
|
||||
resources = ray.cluster_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.cluster_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -40,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
table=float_resources))
|
||||
elif request.type == \
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
|
||||
resources = ray.available_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.available_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -48,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
ray_client_pb2.ClusterInfoResponse.ResourceTable(
|
||||
table=float_resources))
|
||||
else:
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
with disable_client_hook():
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
return resp
|
||||
|
||||
def _return_debug_cluster_info(self, request, context=None) -> str:
|
||||
@@ -61,20 +76,61 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
raise TypeError("Unsupported cluster info type")
|
||||
return json.dumps(data)
|
||||
|
||||
def Terminate(self, request, context=None):
|
||||
if request.WhichOneof("terminate_type") == "task_object":
|
||||
def release(self, client_id: str, id: bytes) -> bool:
    """Drop one object or actor reference held for ``client_id``.

    Object references are checked first, then actors owned by the client.

    Returns:
        True if a reference with this id was found and removed, otherwise
        False.
    """
    holds_object = (client_id in self.object_refs
                    and id in self.object_refs[client_id])
    if holds_object:
        logger.debug(f"Releasing object {id.hex()} for {client_id}")
        del self.object_refs[client_id][id]
        return True

    owns_actor = (client_id in self.actor_owners
                  and id in self.actor_owners[client_id])
    if owns_actor:
        logger.debug(f"Releasing actor {id.hex()} for {client_id}")
        del self.actor_refs[id]
        self.actor_owners[client_id].remove(id)
        return True

    return False
|
||||
|
||||
def release_all(self, client_id):
    """Release the client's object references, then its actors."""
    for releaser in (self._release_objects, self._release_actors):
        releaser(client_id)
|
||||
|
||||
def _release_objects(self, client_id):
    """Forget every object reference held for ``client_id``.

    Logs and returns without changes when the client has no entry in
    ``self.object_refs``.
    """
    released = self.object_refs.pop(client_id, None)
    if released is None:
        logger.debug(f"Releasing client with no references: {client_id}")
        return
    count = len(released)
    logger.debug(f"Released all {count} objects for client {client_id}")
|
||||
|
||||
def _release_actors(self, client_id):
    """Forget every actor handle owned by ``client_id``.

    Logs and returns without changes when the client owns no actors,
    matching ``_release_objects``. Otherwise each owned id is removed from
    ``self.actor_refs`` and the ownership entry itself is deleted.
    """
    if client_id not in self.actor_owners:
        logger.debug(f"Releasing client with no actors: {client_id}")
        # Return here: indexing actor_owners below would insert an empty
        # entry for this client and log a misleading "released" message.
        return
    count = 0
    for id_bytes in self.actor_owners[client_id]:
        count += 1
        del self.actor_refs[id_bytes]
    del self.actor_owners[client_id]
    logger.debug(f"Released all {count} actors for client: {client_id}")
|
||||
|
||||
def Terminate(self, req, context=None):
|
||||
if req.WhichOneof("terminate_type") == "task_object":
|
||||
try:
|
||||
object_ref = cloudpickle.loads(request.task_object.handle)
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=request.task_object.force,
|
||||
recursive=request.task_object.recursive)
|
||||
object_ref = \
|
||||
self.object_refs[req.client_id][req.task_object.id]
|
||||
with disable_client_hook():
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=req.task_object.force,
|
||||
recursive=req.task_object.recursive)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
elif request.WhichOneof("terminate_type") == "actor":
|
||||
elif req.WhichOneof("terminate_type") == "actor":
|
||||
try:
|
||||
actor_ref = cloudpickle.loads(request.actor.handle)
|
||||
ray.kill(actor_ref, no_restart=request.actor.no_restart)
|
||||
actor_ref = self.actor_refs[req.actor.id]
|
||||
with disable_client_hook():
|
||||
ray.kill(actor_ref, no_restart=req.actor.no_restart)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
else:
|
||||
@@ -84,166 +140,221 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
return ray_client_pb2.TerminateResponse(ok=True)
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
request_ref = cloudpickle.loads(request.handle)
|
||||
if request_ref.binary() not in self.object_refs:
|
||||
return self._get_object(request, "", context)
|
||||
|
||||
def _get_object(self, request, client_id: str, context=None):
|
||||
if request.id not in self.object_refs[client_id]:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
objectref = self.object_refs[request_ref.binary()]
|
||||
logger.info("get: %s" % objectref)
|
||||
objectref = self.object_refs[client_id][request.id]
|
||||
logger.debug("get: %s" % objectref)
|
||||
try:
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
with disable_client_hook():
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
item_ser = cloudpickle.dumps(item)
|
||||
return ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
item_ser = dumps_from_server(item, client_id, self)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
|
||||
obj = cloudpickle.loads(request.data)
|
||||
objectref = self._put_and_retain_obj(obj)
|
||||
pickled_ref = cloudpickle.dumps(objectref)
|
||||
return ray_client_pb2.PutResponse(
|
||||
ref=make_remote_ref(objectref.binary(), pickled_ref))
|
||||
def PutObject(self, request: ray_client_pb2.PutRequest,
|
||||
context=None) -> ray_client_pb2.PutResponse:
|
||||
"""gRPC entrypoint for unary PutObject
|
||||
"""
|
||||
return self._put_object(request, "", context)
|
||||
|
||||
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[objectref.binary()] = objectref
|
||||
logger.info("put: %s" % objectref)
|
||||
return objectref
|
||||
def _put_object(self,
|
||||
request: ray_client_pb2.PutRequest,
|
||||
client_id: str,
|
||||
context=None):
|
||||
"""Put an object in the cluster with ray.put() via gRPC.
|
||||
|
||||
Args:
|
||||
request: PutRequest with pickled data.
|
||||
client_id: The client who owns this data, for tracking when to
|
||||
delete this reference.
|
||||
context: gRPC context.
|
||||
"""
|
||||
obj = loads_from_client(request.data, self)
|
||||
with disable_client_hook():
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[client_id][objectref.binary()] = objectref
|
||||
logger.debug("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
|
||||
object_refs = []
|
||||
for id in request.object_ids:
|
||||
if id not in self.object_refs[request.client_id]:
|
||||
raise Exception(
|
||||
"Asking for a ref not associated with this client: %s" %
|
||||
str(id))
|
||||
object_refs.append(self.object_refs[request.client_id][id])
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
object_refs_ids = []
|
||||
for object_ref in object_refs:
|
||||
if object_ref.binary() not in self.object_refs:
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
object_refs_ids.append(self.object_refs[object_ref.binary()])
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs_ids,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None)
|
||||
except Exception:
|
||||
with disable_client_hook():
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(ameer): improve exception messages.
|
||||
logger.error(f"Exception {e}")
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
logger.debug("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
ready_object_ids = [
|
||||
make_remote_ref(
|
||||
id=ready_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(ready_object_ref),
|
||||
) for ready_object_ref in ready_object_refs
|
||||
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||
]
|
||||
remaining_object_ids = [
|
||||
make_remote_ref(
|
||||
id=remaining_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(remaining_object_ref),
|
||||
) for remaining_object_ref in remaining_object_refs
|
||||
remaining_object_ref.binary()
|
||||
for remaining_object_ref in remaining_object_refs
|
||||
]
|
||||
return ray_client_pb2.WaitResponse(
|
||||
valid=True,
|
||||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.debug(
|
||||
"schedule: %s %s" % (task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type)))
|
||||
try:
|
||||
with disable_client_hook():
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
result = self._schedule_function(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
result = self._schedule_actor(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
result = self._schedule_method(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
|
||||
result = self._schedule_named_actor(task, context)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type))
|
||||
result.valid = True
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
raise e
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
def _schedule_method(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
actor_handle = self.actor_refs.get(task.payload_id)
|
||||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_ref = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_ref))
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
method = getattr(actor_handle, task.name)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
method = method.options(**opts)
|
||||
output = method.remote(*arglist, **kwargs)
|
||||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _schedule_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[payload_ref.binary()]
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_class = self.lookup_or_register_actor(
|
||||
task.payload_id, task.client_id,
|
||||
decode_options(task.baseline_options))
|
||||
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
remote_class = remote_class.options(**opts)
|
||||
with current_remote(remote_class):
|
||||
actor = remote_class.remote(*arglist, **kwargs)
|
||||
self.actor_refs[actor._actor_id.binary()] = actor
|
||||
self.actor_owners[task.client_id].add(actor._actor_id.binary())
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ids=[actor._actor_id.binary()])
|
||||
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
remote_func = self.lookup_or_register_func(
|
||||
task.payload_id, task.client_id,
|
||||
decode_options(task.baseline_options))
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
remote_func = remote_func.options(**opts)
|
||||
with current_remote(remote_func):
|
||||
output = remote_func.remote(*arglist, **kwargs)
|
||||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _schedule_named_actor(self,
                          task: ray_client_pb2.ClientTask,
                          context=None) -> ray_client_pb2.ClientTaskTicket:
    """Look up an existing named actor and return its id in a ticket.

    The handle returned by ``ray.get_actor(task.name)`` is stored in
    ``self.actor_refs`` and its id is added to the requesting client's
    entry in ``self.actor_owners``.
    """
    # Named-actor lookups carry no payload; only the name is used.
    assert len(task.payload_id) == 0
    handle = ray.get_actor(task.name)
    actor_id = handle._actor_id.binary()
    self.actor_refs[actor_id] = handle
    self.actor_owners[task.client_id].add(actor_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=[actor_id])
|
||||
|
||||
def _convert_args(self, arg_list, kwarg_map):
|
||||
argout = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg, self)
|
||||
argout.append(t)
|
||||
kwargout = {}
|
||||
for k in kwarg_map:
|
||||
kwargout[k] = convert_from_arg(kwarg_map[k], self)
|
||||
return argout, kwargout
|
||||
|
||||
def lookup_or_register_func(
        self, id: bytes, client_id: str,
        options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
    """Return the remote function cached under ``id``, registering it if new.

    On a cache miss the function object is fetched with ``ray.get`` from
    the client's entry in ``self.object_refs``, wrapped with ``ray.remote``
    (passing ``options`` when they are non-empty) and stored in
    ``self.function_refs``.

    Raises:
        Exception: if the fetched object is not a plain function.
    """
    with disable_client_hook():
        if id in self.function_refs:
            return self.function_refs[id]
        func = ray.get(self.object_refs[client_id][id])
        if not inspect.isfunction(func):
            raise Exception("Attempting to register function that "
                            "isn't a function.")
        if options:
            wrapped = ray.remote(**options)(func)
        else:
            # No (or empty) options: use the bare decorator form.
            wrapped = ray.remote(func)
        self.function_refs[id] = wrapped
        return wrapped
|
||||
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]):
|
||||
with disable_client_hook():
|
||||
if id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[client_id][id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[payload_ref.binary()] = reg_class
|
||||
remote_class = self.registered_actor_classes[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
actorhandle = cloudpickle.dumps(actor)
|
||||
self.actor_refs[actorhandle] = actor
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
|
||||
if options is None or len(options) == 0:
|
||||
reg_class = ray.remote(actor_class)
|
||||
else:
|
||||
reg_class = ray.remote(**options)(actor_class)
|
||||
self.registered_actor_classes[id] = reg_class
|
||||
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.function_refs:
|
||||
funcref = self.object_refs[payload_ref.binary()]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to schedule function that "
|
||||
"isn't a function.")
|
||||
self.function_refs[payload_ref.binary()] = ray.remote(func)
|
||||
remote_func = self.function_refs[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = remote_func.remote(*arglist)
|
||||
if output.binary() in self.object_refs:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_output = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_output))
|
||||
return self.registered_actor_classes[id]
|
||||
|
||||
|
||||
def _convert_args(arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
if isinstance(t, ClientObjectRef):
|
||||
out.append(t._unpack_ref())
|
||||
def unify_and_track_outputs(self, output, client_id):
|
||||
if output is None:
|
||||
outputs = []
|
||||
elif isinstance(output, list):
|
||||
outputs = output
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
|
||||
return ray_client_pb2.RemoteRef(
|
||||
id=id,
|
||||
handle=handle,
|
||||
)
|
||||
outputs = [output]
|
||||
for out in outputs:
|
||||
if out.binary() in self.object_refs[client_id]:
|
||||
logger.warning(f"Already saw object_ref {out}")
|
||||
self.object_refs[client_id][out.binary()] = out
|
||||
return [out.binary() for out in outputs]
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
|
||||
@@ -252,17 +363,61 @@ def return_exception_in_context(err, context):
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
|
||||
def encode_exception(exception) -> str:
    """Pickle ``exception`` with cloudpickle and return it as base64 text."""
    pickled = cloudpickle.dumps(exception)
    encoded = base64.standard_b64encode(pickled)
    return encoded.decode()
|
||||
|
||||
|
||||
def decode_options(
        options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
    """Parse the JSON options carried by a task message.

    Returns:
        ``None`` when ``options.json_options`` is the empty string,
        otherwise the decoded JSON value, which must be a dict.
    """
    raw = options.json_options
    if raw == "":
        return None
    decoded = json.loads(raw)
    assert isinstance(decoded, dict)
    return decoded
|
||||
|
||||
|
||||
_current_servicer: Optional[RayletServicer] = None
|
||||
|
||||
|
||||
# Used by tests to peek inside the servicer
|
||||
def _get_current_servicer():
|
||||
global _current_servicer
|
||||
return _current_servicer
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
_set_server_api(RayServerAPI(task_servicer))
|
||||
task_servicer = RayletServicer()
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
logs_servicer = LogstreamServicer()
|
||||
global _current_servicer
|
||||
_current_servicer = task_servicer
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
data_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
|
||||
logs_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
server.start()
|
||||
return server
|
||||
|
||||
|
||||
def init_and_serve(connection_str, *args, **kwargs):
|
||||
with disable_client_hook():
|
||||
# Disable client mode inside the worker's environment
|
||||
info = ray.init(*args, **kwargs)
|
||||
server = serve(connection_str)
|
||||
return (server, info)
|
||||
|
||||
|
||||
def shutdown_with_server(server, _exiting_interpreter=False):
|
||||
server.stop(1)
|
||||
with disable_client_hook():
|
||||
ray.shutdown(_exiting_interpreter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level="INFO")
|
||||
# TODO(barakmich): Perhaps wrap ray init
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Implements the client side of the client/server pickling protocol.
|
||||
|
||||
These picklers are aware of the server internals and can find the
|
||||
references held for the client within the server.
|
||||
|
||||
More discussion about the client/server pickling protocol can be found in:
|
||||
|
||||
ray/experimental/client/client_pickler.py
|
||||
|
||||
ServerPickler dumps ray objects from the server into the appropriate stubs.
|
||||
ClientUnpickler loads stubs from the client and finds their associated handle
|
||||
in the server instance.
|
||||
"""
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
import ray
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
from ray.experimental.client.client_pickler import PickleStub
|
||||
from ray.experimental.client.server.server_stubs import (
|
||||
ServerSelfReferenceSentinel)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
|
||||
class ServerPickler(cloudpickle.CloudPickler):
    """Pickler the server uses to send objects back to one client.

    Ray object refs and actor handles are not pickled by value. They are
    replaced by ``PickleStub`` persistent ids, and the real object is
    recorded on the server so the client can refer to it later.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for refs and actor handles, else ``None``.

        Returning ``None`` tells pickle to serialize ``obj`` normally.
        """
        if isinstance(obj, ray.ObjectRef):
            obj_id = obj.binary()
            if obj_id not in self.server.object_refs[self.client_id]:
                # We're passing back a reference, probably inside a reference.
                # Let's hold onto it.
                self.server.object_refs[self.client_id][obj_id] = obj
            return PickleStub(
                type="Object",
                client_id=self.client_id,
                ref_id=obj_id,
                name=None,
                baseline_options=None,
            )
        elif isinstance(obj, ray.actor.ActorHandle):
            actor_id = obj._actor_id.binary()
            # The bookkeeping tables live on the server, not on the pickler;
            # the pickler only has ``client_id`` and ``server`` attributes.
            if actor_id not in self.server.actor_refs:
                # We're passing back a handle, probably inside a reference.
                self.server.actor_refs[actor_id] = obj
            if actor_id not in self.server.actor_owners[self.client_id]:
                self.server.actor_owners[self.client_id].add(actor_id)
            return PickleStub(
                type="Actor",
                client_id=self.client_id,
                ref_id=actor_id,
                name=None,
                baseline_options=None,
            )
        return None
|
||||
|
||||
|
||||
class ClientUnpickler(pickle.Unpickler):
    """Unpickler that resolves client ``PickleStub`` ids on the server."""

    def __init__(self, server, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Return the server-side object that the stub ``pid`` stands for.

        Raises:
            NotImplementedError: if ``pid.type`` is not a known stub type.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Ray":
            return ray
        if kind == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        if kind == "Actor":
            return self.server.actor_refs[pid.ref_id]
        if kind in ("RemoteFuncSelfReference", "RemoteActorSelfReference"):
            # Both self-reference stubs map to the same sentinel type.
            return ServerSelfReferenceSentinel()
        if kind == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteActor":
            return self.server.lookup_or_register_actor(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteMethod":
            owner = self.server.actor_refs[pid.ref_id]
            return getattr(owner, pid.name)
        raise NotImplementedError("Uncovered client data type")
|
||||
|
||||
|
||||
def dumps_from_server(obj: Any,
                      client_id: str,
                      server_instance: "RayletServicer",
                      protocol=None) -> bytes:
    """Pickle ``obj`` with a ``ServerPickler`` and return the bytes."""
    with io.BytesIO() as buffer:
        pickler = ServerPickler(
            client_id, server_instance, buffer, protocol=protocol)
        pickler.dump(obj)
        # Read the value before the buffer is closed on exit.
        return buffer.getvalue()
|
||||
|
||||
|
||||
def loads_from_client(data: bytes,
                      server_instance: "RayletServicer",
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Unpickle ``data`` sent by a client, resolving stubs on the server.

    Loading runs with the client hook disabled. ``fix_imports``,
    ``encoding`` and ``errors`` are forwarded to the unpickler and have the
    same meaning as for ``pickle.Unpickler``.

    Raises:
        TypeError: if ``data`` is a ``str`` instead of bytes.
    """
    # NOTE(review): unpickling can execute arbitrary code; this presumably
    # relies on the client being trusted -- confirm.
    with disable_client_hook():
        if isinstance(data, str):
            raise TypeError("Can't load pickle from unicode string")
        file = io.BytesIO(data)
        # ``errors`` used to be accepted but silently dropped here.
        return ClientUnpickler(
            server_instance,
            file,
            fix_imports=fix_imports,
            encoding=encoding,
            errors=errors).load()
|
||||
|
||||
|
||||
def convert_from_arg(pb: "ray_client_pb2.Arg",
                     server: "RayletServicer") -> Any:
    """Decode the pickled ``data`` field of ``pb`` against ``server``."""
    payload = pb.data
    return loads_from_client(payload, server)
|
||||
@@ -0,0 +1,29 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
_current_remote_obj = None
|
||||
|
||||
|
||||
@contextmanager
def current_remote(r):
    """Temporarily make ``r`` the module-wide current remote object.

    The previous value is restored on exit, even if the body raises.
    """
    global _current_remote_obj
    previous, _current_remote_obj = _current_remote_obj, r
    try:
        yield
    finally:
        _current_remote_obj = previous
|
||||
|
||||
|
||||
class ServerSelfReferenceSentinel:
    """Placeholder for a remote object that refers to itself.

    When pickled while a current remote object is set (the module-level
    ``_current_remote_obj``), it reduces to that object; otherwise it
    reduces to a fresh sentinel.
    """

    def __init__(self):
        pass

    def __reduce__(self):
        target = _current_remote_obj
        if target is not None:
            return (identity, (target, ))
        return (ServerSelfReferenceSentinel, tuple())
|
||||
|
||||
|
||||
def identity(x):
    """Return ``x`` unchanged."""
    return x
|
||||
@@ -2,27 +2,30 @@
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import inspect
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
from ray.exceptions import TaskCancelledError
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import decode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.client_pickler import convert_to_arg
|
||||
from ray.experimental.client.client_pickler import dumps_from_client
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
from ray.experimental.client.logsclient import LogstreamClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,34 +34,37 @@ class Worker:
|
||||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
metadata: List[Tuple[str, str]] = None):
|
||||
"""Initializes the worker side grpc client.
|
||||
|
||||
Args:
|
||||
stub: custom grpc stub.
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.channel = None
|
||||
if stub is None:
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
self._client_id = make_client_id()
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.server = stub
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
self.data_client = DataClient(self.channel, self._client_id)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
self.log_client = LogstreamClient(self.channel)
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_get = [x.handle for x in vals]
|
||||
to_get = vals
|
||||
elif isinstance(vals, ClientObjectRef):
|
||||
to_get = [vals.handle]
|
||||
to_get = [vals]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a "
|
||||
@@ -70,15 +76,17 @@ class Worker:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, handle: bytes, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
|
||||
def _get(self, ref: ClientObjectRef, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
|
||||
try:
|
||||
data = self.server.GetObject(req, metadata=self.metadata)
|
||||
data = self.data_client.GetObject(req)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
raise e.details()
|
||||
if not data.valid:
|
||||
raise TaskCancelledError(handle)
|
||||
return cloudpickle.loads(data.data)
|
||||
err = cloudpickle.loads(data.error)
|
||||
logger.error(err)
|
||||
raise err
|
||||
return loads_from_server(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
@@ -95,26 +103,37 @@ class Worker:
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
if isinstance(val, ClientObjectRef):
|
||||
raise TypeError(
|
||||
"Calling 'put' on an ObjectRef is not allowed "
|
||||
"(similarly, returning an ObjectRef from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ObjectRef in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
data = dumps_from_client(val, self._client_id)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(resp.ref)
|
||||
resp = self.data_client.PutObject(req)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
timeout: float = None,
|
||||
fetch_local: bool = True
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
assert isinstance(object_refs, list)
|
||||
if not isinstance(object_refs, list):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got {type(object_refs)}")
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
if not isinstance(ref, ClientObjectRef):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got list containing {type(ref)}")
|
||||
data = {
|
||||
"object_handles": [
|
||||
object_ref.handle for object_ref in object_refs
|
||||
],
|
||||
"object_ids": [object_ref.id for object_ref in object_refs],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
"timeout": timeout if timeout else -1,
|
||||
"client_id": self._client_id,
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req, metadata=self.metadata)
|
||||
@@ -122,41 +141,69 @@ class Worker:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.ready_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.remaining_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, function_or_class, *args, **kwargs):
|
||||
# TODO(barakmich): Arguments to ray.remote
|
||||
# get captured here.
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, instance, *args, **kwargs):
|
||||
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
pb_arg = convert_to_arg(arg, self._client_id)
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(ticket.return_ref)
|
||||
for k, v in kwargs.items():
|
||||
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
|
||||
return self._call_schedule_for_task(task)
|
||||
|
||||
def _call_schedule_for_task(
        self, task: ray_client_pb2.ClientTask) -> List[bytes]:
    """Send ``task`` to the server's ``Schedule`` RPC and return its ids.

    The task is stamped with this worker's client id before sending.

    Returns:
        The ``return_ids`` of the ticket issued by the server.

    Raises:
        Exception: the decoded server-side exception when the RPC fails,
            or the unpickled ``ticket.error`` when the ticket is invalid.
    """
    logger.debug("Scheduling %s" % task)
    task.client_id = self._client_id
    try:
        ticket = self.server.Schedule(task, metadata=self.metadata)
    except grpc.RpcError as e:
        # ``details`` is a method; decode its result, not the bound method
        # (this matches the other RpcError handlers in this class).
        raise decode_exception(e.details())
    if not ticket.valid:
        raise cloudpickle.loads(ticket.error)
    return ticket.return_ids
|
||||
|
||||
def call_release(self, id: bytes) -> None:
    """Drop one local reference to ``id``.

    When the count reaches zero the id is released on the server and its
    counter entry is removed. Does nothing once the worker is closed.
    """
    if self.closed:
        return
    remaining = self.reference_count[id] - 1
    self.reference_count[id] = remaining
    if remaining == 0:
        self._release_server(id)
        del self.reference_count[id]
|
||||
|
||||
def _release_server(self, id: bytes) -> None:
|
||||
if self.data_client is not None:
|
||||
logger.debug(f"Releasing {id}")
|
||||
self.data_client.ReleaseObject(
|
||||
ray_client_pb2.ReleaseRequest(ids=[id]))
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
    """Record one more local reference to ``id``."""
    logger.debug(f"Retaining {id.hex()}")
    counts = self.reference_count
    counts[id] += 1
|
||||
|
||||
def close(self):
|
||||
self.server = None
|
||||
self.log_client.close()
|
||||
self.data_client.close()
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
self.server = None
|
||||
self.closed = True
|
||||
|
||||
def get_actor(self, name: str) -> ClientActorHandle:
    """Return a client handle for the actor registered under ``name``.

    Sends a ``NAMED_ACTOR`` task through the schedule call and wraps the
    single returned id in a ``ClientActorHandle``.
    """
    request = ray_client_pb2.ClientTask()
    request.type = ray_client_pb2.ClientTask.NAMED_ACTOR
    request.name = name
    returned_ids = self._call_schedule_for_task(request)
    assert len(returned_ids) == 1
    actor_ref = ClientActorRef(returned_ids[0])
    return ClientActorHandle(actor_ref)
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
|
||||
no_restart: bool) -> None:
|
||||
@@ -164,10 +211,11 @@ class Worker:
|
||||
raise ValueError("ray.kill() only supported for actors. "
|
||||
"Got: {}.".format(type(actor)))
|
||||
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
||||
term_actor.handle = actor.actor_ref.handle
|
||||
term_actor.id = actor.actor_ref.id
|
||||
term_actor.no_restart = no_restart
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
@@ -179,11 +227,12 @@ class Worker:
|
||||
"ray.cancel() only supported for non-actor object refs. "
|
||||
f"Got: {type(obj)}.")
|
||||
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
|
||||
term_object.handle = obj.handle
|
||||
term_object.id = obj.id
|
||||
term_object.force = force
|
||||
term_object.recursive = recursive
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
@@ -193,7 +242,9 @@ class Worker:
|
||||
req.type = type
|
||||
resp = self.server.ClusterInfo(req)
|
||||
if resp.WhichOneof("response_type") == "resource_table":
|
||||
return resp.resource_table.table
|
||||
# translate from a proto map to a python dict
|
||||
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
||||
return output_dict
|
||||
return json.loads(resp.json)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -201,3 +252,13 @@ class Worker:
|
||||
return self.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return False
|
||||
|
||||
|
||||
def make_client_id() -> str:
    """Return a fresh random client id: the hex form of a UUID4."""
    return uuid.uuid4().hex
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
    """Decode a base64-encoded exception pickled by the server."""
    raw = base64.standard_b64decode(data)
    return loads_from_server(raw)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import ray
|
||||
|
||||
|
||||
def force_spill_objects(object_refs):
|
||||
"""Force spilling objects to external storage.
|
||||
|
||||
Args:
|
||||
object_refs: Object refs of the objects to be
|
||||
spilled.
|
||||
"""
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
# Make sure that the values are object refs.
|
||||
for object_ref in object_refs:
|
||||
if not isinstance(object_ref, ray.ObjectRef):
|
||||
raise TypeError(
|
||||
f"Attempting to call `force_spill_objects` on the "
|
||||
f"value {object_ref}, which is not an ray.ObjectRef.")
|
||||
return core_worker.force_spill_objects(object_refs)
|
||||
@@ -157,12 +157,15 @@ class ExternalStorage(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
url_with_offset_list: List[str]) -> int:
|
||||
"""Restore objects from the external storage.
|
||||
|
||||
Args:
|
||||
object_refs: List of object IDs (note that it is not ref).
|
||||
url_with_offset_list: List of url_with_offset.
|
||||
|
||||
Returns:
|
||||
The total number of bytes restored.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -215,6 +218,7 @@ class FileSystemStorage(ExternalStorage):
|
||||
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
total = 0
|
||||
for i in range(len(object_refs)):
|
||||
object_ref = object_refs[i]
|
||||
url_with_offset = url_with_offset_list[i].decode()
|
||||
@@ -228,9 +232,11 @@ class FileSystemStorage(ExternalStorage):
|
||||
metadata_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
buf_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
self._size_check(metadata_len, buf_len, parsed_result.size)
|
||||
total += buf_len
|
||||
metadata = f.read(metadata_len)
|
||||
# read remaining data to our buffer
|
||||
self._put_object_to_store(metadata, buf_len, f, object_ref)
|
||||
return total
|
||||
|
||||
def delete_spilled_objects(self, urls: List[str]):
|
||||
for url in urls:
|
||||
@@ -297,6 +303,7 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
from smart_open import open
|
||||
total = 0
|
||||
for i in range(len(object_refs)):
|
||||
object_ref = object_refs[i]
|
||||
url_with_offset = url_with_offset_list[i].decode()
|
||||
@@ -315,9 +322,11 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
|
||||
metadata_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
buf_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
self._size_check(metadata_len, buf_len, parsed_result.size)
|
||||
total += buf_len
|
||||
metadata = f.read(metadata_len)
|
||||
# read remaining data to our buffer
|
||||
self._put_object_to_store(metadata, buf_len, f, object_ref)
|
||||
return total
|
||||
|
||||
def delete_spilled_objects(self, urls: List[str]):
|
||||
pass
|
||||
@@ -367,8 +376,8 @@ def restore_spilled_objects(object_refs: List[ObjectRef],
|
||||
object_refs: List of object IDs (note that it is not ref).
|
||||
url_with_offset_list: List of url_with_offset.
|
||||
"""
|
||||
_external_storage.restore_spilled_objects(object_refs,
|
||||
url_with_offset_list)
|
||||
return _external_storage.restore_spilled_objects(object_refs,
|
||||
url_with_offset_list)
|
||||
|
||||
|
||||
def delete_spilled_objects(urls: List[str]):
|
||||
|
||||
@@ -12,6 +12,7 @@ import hashlib
|
||||
import cython
|
||||
import inspect
|
||||
import uuid
|
||||
import ray.ray_constants as ray_constants
|
||||
|
||||
|
||||
ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &)
|
||||
@@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
function_name = function.__name__
|
||||
class_name = ""
|
||||
|
||||
pickled_function_hash = hashlib.sha1(pickled_function).hexdigest()
|
||||
pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest(
|
||||
ray_constants.ID_SIZE)
|
||||
|
||||
return cls(module_name, function_name, class_name,
|
||||
pickled_function_hash)
|
||||
@@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
module_name = target_class.__module__
|
||||
class_name = target_class.__name__
|
||||
# Use a random uuid as function hash to solve actor name conflict.
|
||||
return cls(module_name, "__init__", class_name, str(uuid.uuid4()))
|
||||
return cls(
|
||||
module_name, "__init__", class_name,
|
||||
hashlib.shake_128(
|
||||
uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE))
|
||||
|
||||
@property
|
||||
def module_name(self):
|
||||
@@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
Returns:
|
||||
ray.ObjectRef to represent the function descriptor.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash = hashlib.shake_128()
|
||||
# Include the function module and name in the hash.
|
||||
function_id_hash.update(self.typed_descriptor.ModuleName())
|
||||
function_id_hash.update(self.typed_descriptor.FunctionName())
|
||||
function_id_hash.update(self.typed_descriptor.ClassName())
|
||||
function_id_hash.update(self.typed_descriptor.FunctionHash())
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
function_id = function_id_hash.digest(ray_constants.ID_SIZE)
|
||||
return ray.FunctionID(function_id)
|
||||
|
||||
def is_actor_method(self):
|
||||
|
||||
@@ -179,9 +179,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool plasma_objects_only)
|
||||
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
|
||||
CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
|
||||
int64_t timeout_ms, c_vector[c_bool] *results)
|
||||
int64_t timeout_ms, c_vector[c_bool] *results,
|
||||
c_bool fetch_local)
|
||||
CRayStatus Delete(const c_vector[CObjectID] &object_ids,
|
||||
c_bool local_only, c_bool delete_creating_tasks)
|
||||
c_bool local_only)
|
||||
CRayStatus TriggerGlobalGC()
|
||||
c_string MemoryUsageString()
|
||||
|
||||
@@ -232,7 +233,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
(CRayStatus() nogil) check_signals
|
||||
(void() nogil) gc_collect
|
||||
(c_vector[c_string](const c_vector[CObjectID] &) nogil) spill_objects
|
||||
(void(
|
||||
(int64_t(
|
||||
const c_vector[CObjectID] &,
|
||||
const c_vector[c_string] &) nogil) restore_spilled_objects
|
||||
(void(
|
||||
|
||||
@@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize):
|
||||
raise TypeError("Unsupported type: " + str(type(b)))
|
||||
if len(b) != size:
|
||||
raise ValueError("ID string needs to have length " +
|
||||
str(size))
|
||||
str(size) + ", got " + str(len(b)))
|
||||
|
||||
|
||||
cdef extern from "ray/common/constants.h" nogil:
|
||||
|
||||
@@ -37,7 +37,7 @@ def memory_summary():
|
||||
return reply.memory_summary
|
||||
|
||||
|
||||
def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
def free(object_refs, local_only=False):
|
||||
"""Free a list of IDs from the in-process and plasma object stores.
|
||||
|
||||
This function is a low-level API which should be used in restricted
|
||||
@@ -59,8 +59,6 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
object_refs (List[ObjectRef]): List of object refs to delete.
|
||||
local_only (bool): Whether only deleting the list of objects in local
|
||||
object store or all object stores.
|
||||
delete_creating_tasks (bool): Whether also delete the object creating
|
||||
tasks.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
|
||||
@@ -83,5 +81,4 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
if len(object_refs) == 0:
|
||||
return
|
||||
|
||||
worker.core_worker.free_objects(object_refs, local_only,
|
||||
delete_creating_tasks)
|
||||
worker.core_worker.free_objects(object_refs, local_only)
|
||||
|
||||
@@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The groups are worker id, job id, and pid.
|
||||
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)")
|
||||
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)")
|
||||
|
||||
|
||||
class LogFileInfo:
|
||||
|
||||
+19
-15
@@ -15,11 +15,14 @@ from ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S
|
||||
from ray.autoscaler._private.load_metrics import LoadMetrics
|
||||
from ray.autoscaler._private.constants import \
|
||||
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
|
||||
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS
|
||||
import ray.gcs_utils
|
||||
import ray.utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.ray_logging import setup_component_logger
|
||||
from ray._raylet import GlobalStateAccessor
|
||||
from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
_internal_kv_initialized
|
||||
|
||||
import redis
|
||||
|
||||
@@ -65,11 +68,7 @@ def parse_resource_demands(resource_load_by_shape):
|
||||
except Exception:
|
||||
logger.exception("Failed to parse resource demands.")
|
||||
|
||||
# Bound the total number of bundles to 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE.
|
||||
# This guarantees the resource demand scheduler bin packing algorithm takes
|
||||
# a reasonable amount of time to run.
|
||||
return waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE], \
|
||||
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
|
||||
return waiting_bundles, infeasible_bundles
|
||||
|
||||
|
||||
class Monitor:
|
||||
@@ -184,14 +183,8 @@ class Monitor:
|
||||
data: a resource request as JSON, e.g. {"CPU": 1}
|
||||
"""
|
||||
|
||||
if not self.autoscaler:
|
||||
return
|
||||
|
||||
try:
|
||||
self.autoscaler.request_resources(json.loads(data))
|
||||
except Exception:
|
||||
# We don't want this to kill the monitor.
|
||||
traceback.print_exc()
|
||||
resource_request = json.loads(data)
|
||||
self.load_metrics.set_resource_requests(resource_request)
|
||||
|
||||
def process_messages(self, max_messages=10000):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
@@ -257,12 +250,23 @@ class Monitor:
|
||||
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
self.update_raylet_map()
|
||||
self.update_load_metrics()
|
||||
status = {
|
||||
"load_metrics_report": self.load_metrics.summary()._asdict()
|
||||
}
|
||||
|
||||
# Process autoscaling actions
|
||||
if self.autoscaler:
|
||||
# Only used to update the load metrics for the autoscaler.
|
||||
self.update_raylet_map()
|
||||
self.update_load_metrics()
|
||||
self.autoscaler.update()
|
||||
status[
|
||||
"autoscaler_report"] = self.autoscaler.summary()._asdict()
|
||||
|
||||
as_json = json.dumps(status)
|
||||
if _internal_kv_initialized():
|
||||
_internal_kv_put(
|
||||
DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True)
|
||||
|
||||
# Process a round of messages.
|
||||
self.process_messages()
|
||||
|
||||
@@ -19,7 +19,7 @@ def env_bool(key, default):
|
||||
return default
|
||||
|
||||
|
||||
ID_SIZE = 20
|
||||
ID_SIZE = 28
|
||||
|
||||
# The default maximum number of bytes to allocate to the object store unless
|
||||
# overridden by the user.
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import ray
|
||||
from ray.utils import binary_to_hex
|
||||
|
||||
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
|
||||
# logger to add a newline at the end of string.
|
||||
handler.terminator = ""
|
||||
return logger
|
||||
|
||||
|
||||
class WorkerStandardStreamDispatcher:
    """Fan out worker standard-stream data to a set of named handlers.

    Handlers are kept as ``(name, callable)`` pairs in registration order.
    Registration, removal and dispatch all take the same lock, so they are
    serialized with respect to each other.
    """

    def __init__(self):
        self.handlers = []
        self._lock = threading.Lock()

    def add_handler(self, name: str, handler: Callable) -> None:
        """Register ``handler`` under ``name``; duplicate names are kept."""
        entry = (name, handler)
        with self._lock:
            self.handlers.append(entry)

    def remove_handler(self, name: str) -> None:
        """Drop every handler that was registered under ``name``."""
        with self._lock:
            # Rebind to a fresh list rather than mutating in place.
            self.handlers = list(
                filter(lambda entry: entry[0] != name, self.handlers))

    def emit(self, data):
        """Call each registered handler with ``data``, in registration order.

        The lock stays held while the handlers run. NOTE(review):
        ``threading.Lock`` is not reentrant, so a handler that calls back
        into ``add_handler``/``remove_handler``/``emit`` on this same
        dispatcher would block.
        """
        with self._lock:
            for _, callback in self.handlers:
                callback(data)
|
||||
|
||||
|
||||
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
|
||||
|
||||
+6
-26
@@ -2,17 +2,15 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from ray._private.ray_microbenchmark_helpers import timeit
|
||||
from ray._private.ray_client_microbenchmark import (main as
|
||||
client_microbenchmark_main)
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only run tests matching this filter pattern.
|
||||
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Actor:
|
||||
@@ -71,27 +69,6 @@ def small_value_batch(n):
|
||||
return 0
|
||||
|
||||
|
||||
def timeit(name, fn, multiplier=1):
    """Benchmark ``fn`` and print its throughput as mean +- std.

    Does nothing when the module-level ``filter_pattern`` is not a substring
    of ``name``. Otherwise ``fn`` is called repeatedly for about one second
    as warmup, then for four rounds of about two seconds each; every round
    contributes ``multiplier * calls / elapsed`` to the statistics.

    Args:
        name: Label printed with the result and matched against the filter.
        fn: Zero-argument callable to benchmark.
        multiplier: Scale factor applied to each round's call count.
    """
    if filter_pattern not in name:
        return

    # Warmup: results are discarded.
    warmup_began = time.time()
    while time.time() - warmup_began < 1:
        fn()

    # Measured rounds.
    rates = []
    for _ in range(4):
        round_began = time.time()
        calls = 0
        while time.time() - round_began < 2:
            fn()
            calls += 1
        round_ended = time.time()
        rates.append(multiplier * calls / (round_ended - round_began))

    mean_rate = round(np.mean(rates), 2)
    spread = round(np.std(rates), 2)
    print(name, "per second", mean_rate, "+-", spread)
|
||||
|
||||
|
||||
def check_optimized_build():
|
||||
if not ray._raylet.OPTIMIZED:
|
||||
msg = ("WARNING: Unoptimized build! "
|
||||
@@ -277,6 +254,9 @@ def main():
|
||||
ray.get([async_actor_work.remote(a) for _ in range(m)])
|
||||
|
||||
timeit("n:n async-actor calls async", async_actor_multi, m * n)
|
||||
ray.shutdown()
|
||||
|
||||
client_microbenchmark_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from telnetlib import Telnet
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
@@ -172,8 +171,7 @@ def continue_debug_session():
|
||||
ray.experimental.internal_kv._internal_kv_del(key)
|
||||
return
|
||||
host, port = session["pdb_address"].split(":")
|
||||
with Telnet(host, int(port)) as tn:
|
||||
tn.interact()
|
||||
ray.util.rpdb.connect_pdb_client(host, int(port))
|
||||
ray.experimental.internal_kv._internal_kv_del(key)
|
||||
continue_debug_session()
|
||||
return
|
||||
@@ -215,8 +213,7 @@ def debug(address):
|
||||
ray.experimental.internal_kv._internal_kv_get(
|
||||
active_sessions[index]))
|
||||
host, port = session["pdb_address"].split(":")
|
||||
with Telnet(host, int(port)) as tn:
|
||||
tn.interact()
|
||||
ray.util.rpdb.connect_pdb_client(host, int(port))
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
new_class_id = pickle.dumps(pickle.loads(class_id))
|
||||
if new_class_id == class_id:
|
||||
# We appear to have reached a fix point, so use this as the ID.
|
||||
return hashlib.sha1(new_class_id).digest()
|
||||
return hashlib.shake_128(new_class_id).digest(
|
||||
ray_constants.ID_SIZE)
|
||||
class_id = new_class_id
|
||||
|
||||
# We have not reached a fixed point, so we may end up with a different
|
||||
@@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
# same class definition being exported many many times.
|
||||
logger.warning(
|
||||
f"WARNING: Could not produce a deterministic class ID for class {cls}")
|
||||
return hashlib.sha1(new_class_id).digest()
|
||||
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
|
||||
|
||||
|
||||
def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
|
||||
@@ -119,6 +119,14 @@ py_test(
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_imported_backend",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
|
||||
# Runs test_api and test_failure with injected failures in the controller.
|
||||
# TODO(simon): Tests are disabled until #11683 is fixed.
|
||||
|
||||
+93
-28
@@ -4,22 +4,41 @@ import time
|
||||
from functools import wraps
|
||||
import os
|
||||
from uuid import UUID
|
||||
import threading
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||
from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters, logger, get_conda_env_dir)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
|
||||
HTTPConfig)
|
||||
from ray.serve.env import CondaEnv
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
_INTERNAL_CONTROLLER_NAME = None
|
||||
|
||||
# Event loop shared by sync callers; created lazily by
# create_or_get_async_loop_in_thread().
global_async_loop = None


def create_or_get_async_loop_in_thread():
    """Return the module-wide background event loop, creating it on first use.

    The first call creates a new asyncio event loop, stores it in
    ``global_async_loop`` and starts a daemon thread running
    ``loop.run_forever``. Later calls return that same loop.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    runner = threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    )
    runner.start()
    return global_async_loop
|
||||
|
||||
|
||||
def _set_internal_controller_name(name):
|
||||
global _INTERNAL_CONTROLLER_NAME
|
||||
@@ -36,6 +55,36 @@ def _ensure_connected(f: Callable) -> Callable:
|
||||
return check
|
||||
|
||||
|
||||
class ThreadProxiedRouter:
    """Wraps a ``Router`` and schedules its async setup on an event loop.

    Args:
        controller_handle: Passed straight to ``Router``.
        sync: If true, setup is submitted to the shared background-thread
            loop from ``create_or_get_async_loop_in_thread``; otherwise it is
            scheduled as a task on ``asyncio.get_event_loop()``.
    """

    def __init__(self, controller_handle, sync: bool):
        self.router = Router(controller_handle)

        if not sync:
            self.async_loop = asyncio.get_event_loop()
            self.async_loop.create_task(self.router.setup_in_async_loop())
            return

        self.async_loop = create_or_get_async_loop_in_thread()
        # Fire-and-forget: the returned future is not waited on here.
        asyncio.run_coroutine_threadsafe(
            self.router.setup_in_async_loop(),
            self.async_loop,
        )

    def _remote(self, endpoint_name, handle_options, request_data,
                kwargs) -> Coroutine:
        """Build request metadata and return the router's assign coroutine.

        The coroutine is returned as-is; nothing is awaited here.
        """
        metadata = RequestMetadata(
            get_random_letters(10),  # Used for debugging.
            endpoint_name,
            TaskContext.Python,
            call_method=handle_options.method_name,
            shard_key=handle_options.shard_key,
            http_method=handle_options.http_method,
            http_headers=handle_options.http_headers,
        )
        return self.router.assign_request(metadata, request_data, **kwargs)
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self,
|
||||
controller: ActorHandle,
|
||||
@@ -45,15 +94,10 @@ class Client:
|
||||
self._controller_name = controller_name
|
||||
self._detached = detached
|
||||
self._shutdown = False
|
||||
self._http_host, self._http_port = ray.get(
|
||||
controller.get_http_config.remote())
|
||||
self._http_config = ray.get(controller.get_http_config.remote())
|
||||
|
||||
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
|
||||
# mostly grow in size, it will only shrink when user calls the
|
||||
# .remove_endpoint method. This is fine because we expect the number of
|
||||
# endpoints to be fairly small. However, in case this dictionary does
|
||||
# grow very big, we can replace it with a LRU cache instead.
|
||||
self._handle_cache: Dict[str, ActorHandle] = dict()
|
||||
self._sync_proxied_router = None
|
||||
self._async_proxied_router = None
|
||||
|
||||
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
|
||||
# when the interpreter is exiting so we can't rely on __del__ (it
|
||||
@@ -65,6 +109,18 @@ class Client:
|
||||
|
||||
atexit.register(shutdown_serve_client)
|
||||
|
||||
def _get_proxied_router(self, sync: bool):
|
||||
if sync:
|
||||
if self._sync_proxied_router is None:
|
||||
self._sync_proxied_router = ThreadProxiedRouter(
|
||||
self._controller, sync=True)
|
||||
return self._sync_proxied_router
|
||||
else:
|
||||
if self._async_proxied_router is None:
|
||||
self._async_proxied_router = ThreadProxiedRouter(
|
||||
self._controller, sync=False)
|
||||
return self._async_proxied_router
|
||||
|
||||
def __del__(self):
|
||||
if not self._detached:
|
||||
logger.debug("Shutting down Ray Serve because client went out of "
|
||||
@@ -181,8 +237,8 @@ class Client:
|
||||
num_cpus=0, resources={
|
||||
node_id: 0.01
|
||||
}).remote(
|
||||
"http://{}:{}/-/routes".format(self._http_host,
|
||||
self._http_port),
|
||||
"http://{}:{}/-/routes".format(self._http_config.host,
|
||||
self._http_config.port),
|
||||
check_ready=check_ready,
|
||||
timeout=HTTP_PROXY_TIMEOUT)
|
||||
futures.append(future)
|
||||
@@ -198,8 +254,6 @@ class Client:
|
||||
|
||||
Does not delete any associated backends.
|
||||
"""
|
||||
if endpoint in self._handle_cache:
|
||||
del self._handle_cache[endpoint]
|
||||
self._get_result(self._controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
@_ensure_connected
|
||||
@@ -410,10 +464,11 @@ class Client:
|
||||
proportion))
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> RayServeHandle:
|
||||
def get_handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
@@ -433,14 +488,26 @@ class Client:
|
||||
|
||||
if asyncio.get_event_loop().is_running() and sync:
|
||||
logger.warning(
|
||||
"You are retrieving a ServeHandle inside an asyncio loop. "
|
||||
"You are retrieving a sync handle inside an asyncio loop. "
|
||||
"Try getting client.get_handle(.., sync=False) to get better "
|
||||
"performance.")
|
||||
"performance. Learn more at https://docs.ray.io/en/master/"
|
||||
"serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if endpoint_name not in self._handle_cache:
|
||||
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
|
||||
self._handle_cache[endpoint_name] = handle
|
||||
return self._handle_cache[endpoint_name]
|
||||
if not asyncio.get_event_loop().is_running() and not sync:
|
||||
logger.warning(
|
||||
"You are retrieving an async handle outside an asyncio loop. "
|
||||
"You should make sure client.get_handle is called inside a "
|
||||
"running event loop. Or call client.get_handle(.., sync=True) "
|
||||
"to create sync handle. Learn more at https://docs.ray.io/en/"
|
||||
"master/serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if sync:
|
||||
handle = RayServeSyncHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
else:
|
||||
handle = RayServeHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
return handle
|
||||
|
||||
|
||||
def start(detached: bool = False,
|
||||
@@ -492,9 +559,7 @@ def start(detached: bool = False,
|
||||
max_task_retries=-1,
|
||||
).remote(
|
||||
controller_name,
|
||||
http_host,
|
||||
http_port,
|
||||
http_middlewares,
|
||||
HTTPConfig(http_host, http_port, http_middlewares),
|
||||
detached=detached)
|
||||
|
||||
if http_host is not None:
|
||||
|
||||
@@ -186,10 +186,10 @@ class RayServeReplica:
|
||||
"backend_replica_starts",
|
||||
description=("The number of time this replica "
|
||||
"has been restarted due to failure."),
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.restart_counter.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.queuing_latency_tracker = metrics.Histogram(
|
||||
@@ -198,39 +198,39 @@ class RayServeReplica:
|
||||
"The latency for queries waiting in the replica's queue "
|
||||
"waiting to be processed or batched."),
|
||||
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.queuing_latency_tracker.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.processing_latency_tracker = metrics.Histogram(
|
||||
"backend_processing_latency_ms",
|
||||
description="The latency for queries to be processed",
|
||||
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
||||
tag_keys=("backend", "replica_tag", "batch_size"))
|
||||
tag_keys=("backend", "replica", "batch_size"))
|
||||
self.processing_latency_tracker.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.num_queued_items = metrics.Gauge(
|
||||
"replica_queued_queries",
|
||||
description=("Current number of queries queued in the "
|
||||
"the backend replicas"),
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.num_queued_items.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.num_processing_items = metrics.Gauge(
|
||||
"replica_processing_queries",
|
||||
description="Current number of queries being processed",
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.num_processing_items.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.restart_counter.record(1)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from ray.serve.utils import import_class
|
||||
|
||||
|
||||
class ImportedBackend:
    """Factory for a class that will dynamically import a backend class.

    This is intended to be used when the source code for a backend is
    installed in the worker environment but not the driver.

    Intended usage:
        >>> client = serve.connect()
        >>> client.create_backend("b", ImportedBackend("module.Class"), *args)

    This will import module.Class on the worker and proxy all relevant methods
    to it.

    Note that ``ImportedBackend(class_path)`` returns a new *class*, not an
    instance: the import only happens when that class is instantiated.
    """

    def __new__(cls, class_path):
        class ImportedBackend:
            def __init__(self, *args, **kwargs):
                # Import lazily, at construction time, and build the real
                # backend with the caller's arguments.
                self.wrapped = import_class(class_path)(*args, **kwargs)

            def reconfigure(self, *args, **kwargs):
                # NOTE(edoakes): we check that the reconfigure method is
                # present if the user specifies a user_config, so we need to
                # proxy it manually.
                return self.wrapped.reconfigure(*args, **kwargs)

            def __getattr__(self, attr):
                """Proxy all other attributes to the wrapped instance."""
                # __getattr__ only runs when normal lookup fails. If
                # "wrapped" itself is missing (instance created without
                # __init__, e.g. by copy or pickle), reading self.wrapped
                # below would re-enter this method until RecursionError.
                if attr == "wrapped":
                    raise AttributeError(
                        f"{type(self).__name__!r} object has no attribute "
                        f"{attr!r}")
                return getattr(self.wrapped, attr)

        return ImportedBackend
|
||||
@@ -2,8 +2,8 @@ import inspect
|
||||
|
||||
from pydantic import BaseModel, PositiveInt, validator
|
||||
from ray.serve.constants import ASYNC_CONCURRENCY
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
def _callable_accepts_batch(func_or_class):
|
||||
@@ -191,3 +191,10 @@ class ReplicaConfig:
|
||||
raise TypeError(
|
||||
"resources in ray_actor_options must be a dictionary.")
|
||||
self.resource_dict.update(custom_resources)
|
||||
|
||||
|
||||
@dataclass
class HTTPConfig:
    """HTTP settings bundled into one value: host, port and middlewares.

    All three fields are required constructor arguments, in this order.
    """

    # Host for the HTTP proxy. NOTE(review): callers in this change compare
    # it against None, so None appears to be an accepted value - confirm.
    host: str
    # Port used together with ``host``.
    port: int
    # Middleware objects; element type is not constrained here.
    middlewares: List[Any]
|
||||
|
||||
+283
-224
@@ -20,7 +20,7 @@ from ray.serve.kv_store import RayInternalKVStore
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
|
||||
try_schedule_resources_on_nodes, get_all_node_ids)
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, HTTPConfig
|
||||
from ray.serve.long_poll import LongPollHost
|
||||
from ray.actor import ActorHandle
|
||||
|
||||
@@ -80,6 +80,77 @@ class TrafficPolicy:
|
||||
return f"<Traffic {self.traffic_dict}; Shadow {self.shadow_dict}>"
|
||||
|
||||
|
||||
class HTTPState:
    """Keeps one HTTP proxy actor per cluster node.

    The mapping from node id to proxy actor handle is held in
    ``self._proxy_actors``. ``update()`` reconciles it against the node list
    returned by ``get_all_node_ids()``: proxies are started (or looked up by
    name) for nodes that have none, and killed for nodes that are gone.

    Args:
        controller_name: Used when formatting proxy actor names and passed
            to each newly created proxy.
        detached: If true, new proxies are created with
            ``lifetime="detached"``.
        config: HTTP settings; when ``config.host`` is None no proxy is
            ever started.
    """

    def __init__(self, controller_name: str, detached: bool,
                 config: HTTPConfig):
        self._controller_name = controller_name
        self._detached = detached
        self._config = config
        self._proxy_actors: Dict[NodeId, ActorHandle] = dict()

        # Will populate self._proxy_actors with existing actors.
        self._start_proxies_if_needed()

    def get_config(self):
        """Return the ``HTTPConfig`` this state was created with."""
        return self._config

    def get_http_proxy_handles(self) -> Dict[NodeId, ActorHandle]:
        """Return the live node-id -> proxy-handle dict (not a copy)."""
        return self._proxy_actors

    def update(self):
        """Start proxies on new nodes, then drop proxies on removed nodes."""
        self._start_proxies_if_needed()
        self._stop_proxies_if_needed()

    def _start_proxies_if_needed(self) -> None:
        """Start a proxy on every node if it doesn't already exist."""
        if self._config.host is None:
            return

        for node_id, node_resource in get_all_node_ids():
            if node_id in self._proxy_actors:
                continue

            name = format_actor_name(SERVE_PROXY_NAME, self._controller_name,
                                     node_id)
            try:
                # Reuse an actor already registered under this name.
                proxy = ray.get_actor(name)
            except ValueError:
                # Lookup failed: create the proxy, pinned to this node by
                # requesting a sliver of its node-specific resource.
                logger.info("Starting HTTP proxy with name '{}' on node '{}' "
                            "listening on '{}:{}'".format(
                                name, node_id, self._config.host,
                                self._config.port))
                proxy = HTTPProxyActor.options(
                    name=name,
                    lifetime="detached" if self._detached else None,
                    max_concurrency=ASYNC_CONCURRENCY,
                    max_restarts=-1,
                    max_task_retries=-1,
                    resources={
                        node_resource: 0.01
                    },
                ).remote(
                    self._config.host,
                    self._config.port,
                    controller_name=self._controller_name,
                    http_middlewares=self._config.middlewares)

            self._proxy_actors[node_id] = proxy

    def _stop_proxies_if_needed(self) -> None:
        """Removes proxy actors from any nodes that no longer exist."""
        all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
        # Collect first; the dict must not change size while iterated.
        to_stop = []
        for node_id in self._proxy_actors:
            if node_id not in all_node_ids:
                logger.info("Removing HTTP proxy on removed node '{}'.".format(
                    node_id))
                to_stop.append(node_id)

        for node_id in to_stop:
            proxy = self._proxy_actors.pop(node_id)
            ray.kill(proxy, no_restart=True)
|
||||
|
||||
|
||||
class BackendInfo(BaseModel):
|
||||
# TODO(architkulkarni): Add type hint for worker_class after upgrading
|
||||
# cloudpickle and adding types to RayServeWrappedReplica
|
||||
@@ -93,17 +164,15 @@ class BackendInfo(BaseModel):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemState:
|
||||
backends: Dict[BackendTag, BackendInfo] = field(default_factory=dict)
|
||||
traffic_policies: Dict[EndpointTag, TrafficPolicy] = field(
|
||||
default_factory=dict)
|
||||
routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = field(
|
||||
default_factory=dict)
|
||||
class BackendState:
|
||||
def __init__(self, checkpoint: bytes = None):
|
||||
self.backends: Dict[BackendTag, BackendInfo] = dict()
|
||||
|
||||
backend_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
|
||||
traffic_goal_ids: Dict[EndpointTag, GoalId] = field(default_factory=dict)
|
||||
route_goal_ids: Dict[BackendTag, GoalId] = field(default_factory=dict)
|
||||
if checkpoint is not None:
|
||||
self.backends = pickle.loads(checkpoint)
|
||||
|
||||
def checkpoint(self):
|
||||
return pickle.dumps(self.backends)
|
||||
|
||||
def get_backend_configs(self) -> Dict[BackendTag, BackendConfig]:
|
||||
return {
|
||||
@@ -119,7 +188,18 @@ class SystemState:
|
||||
backend_info: BackendInfo,
|
||||
goal_id: GoalId = 0) -> None:
|
||||
self.backends[backend_tag] = backend_info
|
||||
self.backend_goal_ids = goal_id
|
||||
|
||||
|
||||
class EndpointState:
|
||||
def __init__(self, checkpoint: bytes = None):
|
||||
self.routes: Dict[BackendTag, Tuple[EndpointTag, Any]] = dict()
|
||||
self.traffic_policies: Dict[EndpointTag, TrafficPolicy] = dict()
|
||||
|
||||
if checkpoint is not None:
|
||||
self.routes, self.traffic_policies = pickle.loads(checkpoint)
|
||||
|
||||
def checkpoint(self):
|
||||
return pickle.dumps((self.routes, self.traffic_policies))
|
||||
|
||||
def get_endpoints(self) -> Dict[EndpointTag, Dict[str, Any]]:
|
||||
endpoints = {}
|
||||
@@ -146,7 +226,6 @@ class ActorStateReconciler:
|
||||
controller_name: str = field(init=True)
|
||||
detached: bool = field(init=True)
|
||||
|
||||
http_proxy_cache: Dict[NodeId, ActorHandle] = field(default_factory=dict)
|
||||
backend_replicas: Dict[BackendTag, Dict[ReplicaTag, ActorHandle]] = field(
|
||||
default_factory=lambda: defaultdict(dict))
|
||||
backend_replicas_to_start: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
@@ -154,13 +233,27 @@ class ActorStateReconciler:
|
||||
backend_replicas_to_stop: Dict[BackendTag, List[ReplicaTag]] = field(
|
||||
default_factory=lambda: defaultdict(list))
|
||||
backends_to_remove: List[BackendTag] = field(default_factory=list)
|
||||
endpoints_to_remove: List[EndpointTag] = field(default_factory=list)
|
||||
|
||||
# NOTE(ilr): These are not checkpointed, but will be recreated by
|
||||
# `_enqueue_pending_scale_changes_loop`.
|
||||
currently_starting_replicas: Dict[asyncio.Future, Tuple[
|
||||
BackendTag, ReplicaTag, ActorHandle]] = field(default_factory=dict)
|
||||
currently_stopping_replicas: Dict[asyncio.Future, Tuple[
|
||||
BackendTag, ReplicaTag]] = field(default_factory=dict)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
del state["currently_stopping_replicas"]
|
||||
del state["currently_starting_replicas"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
    """Load ``state`` into the instance and reset the in-flight maps.

    Args:
        state: Attribute dict to merge into ``self.__dict__``. Any
            ``currently_*_replicas`` entries it carries are overwritten
            with fresh empty dicts.
    """
    self.__dict__.update(state)
    self.currently_stopping_replicas, self.currently_starting_replicas = (
        {}, {})
|
||||
|
||||
# TODO(edoakes): consider removing this and just using the names.
|
||||
|
||||
def http_proxy_handles(self) -> List[ActorHandle]:
    """Return the values of ``self.http_proxy_cache`` as a new list."""
    return [*self.http_proxy_cache.values()]
|
||||
|
||||
def get_replica_handles(self) -> List[ActorHandle]:
|
||||
return list(
|
||||
chain.from_iterable([
|
||||
@@ -175,43 +268,7 @@ class ActorStateReconciler:
|
||||
for replica_dict in self.backend_replicas.values()
|
||||
]))
|
||||
|
||||
async def _start_pending_backend_replicas(
        self, current_state: SystemState) -> None:
    """Start every replica queued in self.backend_replicas_to_start.

    Each queued replica is started through self._start_backend_replica,
    then this coroutine waits until all of their ``ready`` calls have
    finished. Finished replicas are recorded in self.backend_replicas and
    the queue is cleared once nothing is left to wait for.
    """
    # Start each replica and remember the future for its ``ready`` call.
    awaiting_ready = {}
    for backend_tag, queued_tags in self.backend_replicas_to_start.items():
        for replica_tag in queued_tags:
            handle = await self._start_backend_replica(
                current_state, backend_tag, replica_tag)
            awaiting_ready[handle.ready.remote().as_future()] = (
                backend_tag, replica_tag, handle)

    began_at = time.time()
    last_warned_at = began_at
    while awaiting_ready:
        # Warn again whenever the warning interval has elapsed.
        if time.time() - last_warned_at > REPLICA_STARTUP_TIME_WARNING_S:
            last_warned_at = time.time()
            logger.warning("Waited {:.2f}s for replicas to start up. Make "
                           "sure there are enough resources to create the "
                           "replicas.".format(time.time() - began_at))

        # Wake up at least once per second so the warning check can run.
        finished, _ = await asyncio.wait(list(awaiting_ready), timeout=1)
        for fut in finished:
            backend_tag, replica_tag, handle = awaiting_ready.pop(fut)
            self.backend_replicas[backend_tag][replica_tag] = handle

    self.backend_replicas_to_start.clear()
|
||||
|
||||
async def _start_backend_replica(self, current_state: SystemState,
|
||||
async def _start_backend_replica(self, backend_state: BackendState,
|
||||
backend_tag: BackendTag,
|
||||
replica_tag: ReplicaTag) -> ActorHandle:
|
||||
"""Start a replica and return its actor handle.
|
||||
@@ -229,7 +286,7 @@ class ActorStateReconciler:
|
||||
except ValueError:
|
||||
logger.debug("Starting replica '{}' for backend '{}'.".format(
|
||||
replica_tag, backend_tag))
|
||||
backend_info = current_state.get_backend(backend_tag)
|
||||
backend_info = backend_state.get_backend(backend_tag)
|
||||
|
||||
replica_handle = ray.remote(backend_info.worker_class).options(
|
||||
name=replica_name,
|
||||
@@ -255,6 +312,7 @@ class ActorStateReconciler:
|
||||
intended replicas. This avoids inconsistencies with starting/stopping a
|
||||
replica and then crashing before writing a checkpoint.
|
||||
"""
|
||||
|
||||
logger.debug("Scaling backend '{}' to {} replicas".format(
|
||||
backend_tag, num_replicas))
|
||||
assert (backend_tag in backends
|
||||
@@ -301,97 +359,104 @@ class ActorStateReconciler:
|
||||
|
||||
self.backend_replicas_to_stop[backend_tag].append(replica_tag)
|
||||
|
||||
async def _stop_pending_backend_replicas(self) -> None:
|
||||
"""Stops the pending backend replicas in self.backend_replicas_to_stop.
|
||||
async def _enqueue_pending_scale_changes_loop(self,
|
||||
backend_state: BackendState):
|
||||
for backend_tag, replicas_to_create in self.backend_replicas_to_start.\
|
||||
items():
|
||||
for replica_tag in replicas_to_create:
|
||||
replica_handle = await self._start_backend_replica(
|
||||
backend_state, backend_tag, replica_tag)
|
||||
ready_future = replica_handle.ready.remote().as_future()
|
||||
self.currently_starting_replicas[ready_future] = (
|
||||
backend_tag, replica_tag, replica_handle)
|
||||
|
||||
Removes backend_replicas from the http_proxy, kills them, and clears
|
||||
self.backend_replicas_to_stop.
|
||||
"""
|
||||
for backend_tag, replicas_list in self.backend_replicas_to_stop.items(
|
||||
):
|
||||
for replica_tag in replicas_list:
|
||||
# NOTE(edoakes): the replicas may already be stopped if we
|
||||
# failed after stopping them but before writing a checkpoint.
|
||||
for backend_tag, replicas_to_stop in self.backend_replicas_to_stop.\
|
||||
items():
|
||||
for replica_tag in replicas_to_stop:
|
||||
replica_name = format_actor_name(replica_tag,
|
||||
self.controller_name)
|
||||
try:
|
||||
replica = ray.get_actor(replica_name)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# TODO(edoakes): this logic isn't ideal because there may be
|
||||
# pending tasks still executing on the replica. However, if we
|
||||
# use replica.__ray_terminate__, we may send it while the
|
||||
# replica is being restarted and there's no way to tell if it
|
||||
# successfully killed the worker or not.
|
||||
ray.kill(replica, no_restart=True)
|
||||
async def kill_actor(replica_name_to_use):
|
||||
# NOTE: the replicas may already be stopped if we failed
|
||||
# after stopping them but before writing a checkpoint.
|
||||
try:
|
||||
replica = ray.get_actor(replica_name_to_use)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
self.backend_replicas_to_stop.clear()
|
||||
# TODO(edoakes): this logic isn't ideal because there may
|
||||
# be pending tasks still executing on the replica. However,
|
||||
# if we use replica.__ray_terminate__, we may send it while
|
||||
# the replica is being restarted and there's no way to tell
|
||||
# if it successfully killed the worker or not.
|
||||
ray.kill(replica, no_restart=True)
|
||||
|
||||
def _start_http_proxies_if_needed(self, http_host: str, http_port: str,
|
||||
http_middlewares: List[Any]) -> None:
|
||||
"""Start an HTTP proxy on every node if it doesn't already exist."""
|
||||
if http_host is None:
|
||||
return
|
||||
self.currently_stopping_replicas[asyncio.ensure_future(
|
||||
kill_actor(replica_name))] = (backend_tag, replica_tag)
|
||||
|
||||
for node_id, node_resource in get_all_node_ids():
|
||||
if node_id in self.http_proxy_cache:
|
||||
continue
|
||||
async def _check_currently_starting_replicas(self) -> bool:
|
||||
"""Returns a boolean specifying if there are more replicas to start"""
|
||||
in_flight = list()
|
||||
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
try:
|
||||
proxy = ray.get_actor(name)
|
||||
except ValueError:
|
||||
logger.info("Starting HTTP proxy with name '{}' on node '{}' "
|
||||
"listening on '{}:{}'".format(
|
||||
name, node_id, http_host, http_port))
|
||||
proxy = HTTPProxyActor.options(
|
||||
name=name,
|
||||
lifetime="detached" if self.detached else None,
|
||||
max_concurrency=ASYNC_CONCURRENCY,
|
||||
max_restarts=-1,
|
||||
max_task_retries=-1,
|
||||
resources={
|
||||
node_resource: 0.01
|
||||
},
|
||||
).remote(
|
||||
http_host,
|
||||
http_port,
|
||||
controller_name=self.controller_name,
|
||||
http_middlewares=http_middlewares)
|
||||
if self.currently_starting_replicas:
|
||||
done, in_flight = await asyncio.wait(
|
||||
list(self.currently_starting_replicas.keys()), timeout=0)
|
||||
for fut in done:
|
||||
(backend_tag, replica_tag,
|
||||
replica_handle) = self.currently_starting_replicas.pop(fut)
|
||||
self.backend_replicas[backend_tag][
|
||||
replica_tag] = replica_handle
|
||||
|
||||
self.http_proxy_cache[node_id] = proxy
|
||||
backend = self.backend_replicas_to_start.get(backend_tag)
|
||||
if backend:
|
||||
try:
|
||||
backend.remove(replica_tag)
|
||||
except ValueError:
|
||||
pass
|
||||
if len(backend) == 0:
|
||||
del self.backend_replicas_to_start[backend_tag]
|
||||
return len(in_flight) > 0
|
||||
|
||||
def _stop_http_proxies_if_needed(self) -> bool:
|
||||
"""Removes HTTP proxy actors from any nodes that no longer exist.
|
||||
async def _check_currently_stopping_replicas(self) -> bool:
|
||||
"""Returns a boolean specifying if there are more replicas to stop"""
|
||||
in_flight = list()
|
||||
if self.currently_stopping_replicas:
|
||||
done_stoppping, in_flight = await asyncio.wait(
|
||||
list(self.currently_stopping_replicas.keys()), timeout=0)
|
||||
for fut in done_stoppping:
|
||||
(backend_tag,
|
||||
replica_tag) = self.currently_stopping_replicas.pop(fut)
|
||||
|
||||
Returns whether or not any actors were removed (a checkpoint should
|
||||
be taken).
|
||||
"""
|
||||
actor_stopped = False
|
||||
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
|
||||
to_stop = []
|
||||
for node_id in self.http_proxy_cache:
|
||||
if node_id not in all_node_ids:
|
||||
logger.info("Removing HTTP proxy on removed node '{}'.".format(
|
||||
node_id))
|
||||
to_stop.append(node_id)
|
||||
backend = self.backend_replicas_to_stop.get(backend_tag)
|
||||
|
||||
for node_id in to_stop:
|
||||
proxy = self.http_proxy_cache.pop(node_id)
|
||||
ray.kill(proxy, no_restart=True)
|
||||
actor_stopped = True
|
||||
if backend:
|
||||
try:
|
||||
backend.remove(replica_tag)
|
||||
except ValueError:
|
||||
pass
|
||||
if len(backend) == 0:
|
||||
del self.backend_replicas_to_stop[backend_tag]
|
||||
|
||||
return actor_stopped
|
||||
return len(in_flight) > 0
|
||||
|
||||
async def backend_control_loop(self):
    """Poll until no replica start-up or shutdown is still in flight.

    Each pass awaits self._check_currently_starting_replicas() and, only
    when that reports nothing pending, also awaits
    self._check_currently_stopping_replicas(). While work remains, the
    coroutine sleeps one second between passes and logs a warning whenever
    more than REPLICA_STARTUP_TIME_WARNING_S seconds have gone by since the
    previous warning (or since the loop began).
    """
    start = time.time()
    prev_warning = start
    while True:
        # `or` short-circuits: stopping replicas are polled only once no
        # starting replica is still in flight.
        need_to_continue = (
            await self._check_currently_starting_replicas()
            or await self._check_currently_stopping_replicas())
        if not need_to_continue:
            break

        if time.time() - prev_warning > REPLICA_STARTUP_TIME_WARNING_S:
            prev_warning = time.time()
            logger.warning("Waited {:.2f}s for replicas to start up. Make "
                           "sure there are enough resources to create the "
                           "replicas.".format(time.time() - start))

        # The sleep has to be awaited. A bare `asyncio.sleep(1)` call only
        # builds a coroutine object and discards it, so the loop never
        # paused and Python reported the coroutine as never awaited.
        await asyncio.sleep(1)
|
||||
|
||||
def _recover_actor_handles(self) -> None:
|
||||
# Refresh the RouterCache
|
||||
for node_id in self.http_proxy_cache.keys():
|
||||
name = format_actor_name(SERVE_PROXY_NAME, self.controller_name,
|
||||
node_id)
|
||||
self.http_proxy_cache[node_id] = ray.get_actor(name)
|
||||
|
||||
# Fetch actor handles for all of the backend replicas in the system.
|
||||
# All of these backend_replicas are guaranteed to already exist because
|
||||
# they would not be written to a checkpoint in self.backend_replicas
|
||||
@@ -404,20 +469,20 @@ class ActorStateReconciler:
|
||||
replica_tag] = ray.get_actor(replica_name)
|
||||
|
||||
async def _recover_from_checkpoint(
|
||||
self, current_state: SystemState, controller: "ServeController"
|
||||
self, backend_state: BackendState, controller: "ServeController"
|
||||
) -> Dict[BackendTag, BasicAutoscalingPolicy]:
|
||||
self._recover_actor_handles()
|
||||
autoscaling_policies = dict()
|
||||
|
||||
for backend, info in current_state.backends.items():
|
||||
for backend, info in backend_state.backends.items():
|
||||
metadata = info.backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
autoscaling_policies[backend] = BasicAutoscalingPolicy(
|
||||
backend, metadata.autoscaling_config)
|
||||
|
||||
# Start/stop any pending backend replicas.
|
||||
await self._start_pending_backend_replicas(current_state)
|
||||
await self._stop_pending_backend_replicas()
|
||||
await self._enqueue_pending_scale_changes_loop(backend_state)
|
||||
await self.backend_control_loop()
|
||||
|
||||
return autoscaling_policies
|
||||
|
||||
@@ -430,8 +495,8 @@ class FutureResult:
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
goal_state: SystemState
|
||||
current_state: SystemState
|
||||
endpoint_state_checkpoint: bytes
|
||||
backend_state_checkpoint: bytes
|
||||
reconciler: ActorStateReconciler
|
||||
# TODO(ilr) Rename reconciler to PendingState
|
||||
inflight_reqs: Dict[uuid4, FutureResult]
|
||||
@@ -465,19 +530,10 @@ class ServeController:
|
||||
|
||||
async def __init__(self,
|
||||
controller_name: str,
|
||||
http_host: str,
|
||||
http_port: str,
|
||||
http_middlewares: List[Any],
|
||||
http_config: HTTPConfig,
|
||||
detached: bool = False):
|
||||
# Used to read/write checkpoints.
|
||||
self.kv_store = RayInternalKVStore(namespace=controller_name)
|
||||
# Current State
|
||||
self.current_state = SystemState()
|
||||
# Goal State
|
||||
# TODO(ilr) This is currently *unused* until the refactor of the serve
|
||||
# controller.
|
||||
self.goal_state = SystemState()
|
||||
# ActorStateReconciler
|
||||
self.actor_reconciler = ActorStateReconciler(controller_name, detached)
|
||||
|
||||
# backend -> AutoscalingPolicy
|
||||
@@ -490,24 +546,25 @@ class ServeController:
|
||||
# at any given time.
|
||||
self.write_lock = asyncio.Lock()
|
||||
|
||||
self.http_host = http_host
|
||||
self.http_port = http_port
|
||||
self.http_middlewares = http_middlewares
|
||||
|
||||
# If starting the actor for the first time, starts up the other system
|
||||
# components. If recovering, fetches their actor handles.
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
|
||||
# Map of awaiting results
|
||||
# TODO(ilr): Checkpoint this once this becomes asynchronous
|
||||
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
|
||||
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()
|
||||
|
||||
checkpoint = self.kv_store.get(CHECKPOINT_KEY)
|
||||
if checkpoint is None:
|
||||
# HTTP state doesn't currently require a checkpoint.
|
||||
self.http_state = HTTPState(controller_name, detached, http_config)
|
||||
|
||||
checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
|
||||
if checkpoint_bytes is None:
|
||||
logger.debug("No checkpoint found")
|
||||
self.backend_state = BackendState()
|
||||
self.endpoint_state = EndpointState()
|
||||
else:
|
||||
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
|
||||
self.backend_state = BackendState(
|
||||
checkpoint=checkpoint.backend_state_checkpoint)
|
||||
self.endpoint_state = EndpointState(
|
||||
checkpoint=checkpoint.endpoint_state_checkpoint)
|
||||
await self._recover_from_checkpoint(checkpoint)
|
||||
|
||||
# NOTE(simon): Currently we do all-to-all broadcast. This means
|
||||
@@ -566,17 +623,17 @@ class ServeController:
|
||||
def notify_traffic_policies_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
LongPollKey.TRAFFIC_POLICIES,
|
||||
self.current_state.traffic_policies,
|
||||
self.endpoint_state.traffic_policies,
|
||||
)
|
||||
|
||||
def notify_backend_configs_changed(self):
|
||||
self.long_poll_host.notify_changed(
|
||||
LongPollKey.BACKEND_CONFIGS,
|
||||
self.current_state.get_backend_configs())
|
||||
self.backend_state.get_backend_configs())
|
||||
|
||||
def notify_route_table_changed(self):
|
||||
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
|
||||
self.current_state.routes)
|
||||
self.endpoint_state.routes)
|
||||
|
||||
async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
|
||||
"""Proxy long pull client's listen request.
|
||||
@@ -589,9 +646,9 @@ class ServeController:
|
||||
return await (
|
||||
self.long_poll_host.listen_for_change(keys_to_snapshot_ids))
|
||||
|
||||
def get_http_proxies(self) -> Dict[str, ActorHandle]:
|
||||
def get_http_proxies(self) -> Dict[NodeId, ActorHandle]:
|
||||
"""Returns a dictionary of node ID to http_proxy actor handles."""
|
||||
return self.actor_reconciler.http_proxy_cache
|
||||
return self.http_state.get_http_proxy_handles()
|
||||
|
||||
def _checkpoint(self) -> None:
|
||||
"""Checkpoint internal state and write it to the KV store."""
|
||||
@@ -600,19 +657,19 @@ class ServeController:
|
||||
start = time.time()
|
||||
|
||||
checkpoint = pickle.dumps(
|
||||
Checkpoint(self.goal_state, self.current_state,
|
||||
self.actor_reconciler,
|
||||
Checkpoint(self.endpoint_state.checkpoint(),
|
||||
self.backend_state.checkpoint(), self.actor_reconciler,
|
||||
self._serializable_inflight_results))
|
||||
|
||||
self.kv_store.put(CHECKPOINT_KEY, checkpoint)
|
||||
logger.debug("Wrote checkpoint in {:.2f}".format(time.time() - start))
|
||||
logger.debug("Wrote checkpoint in {:.3f}s".format(time.time() - start))
|
||||
|
||||
if random.random(
|
||||
) < _CRASH_AFTER_CHECKPOINT_PROBABILITY and self.detached:
|
||||
logger.warning("Intentionally crashing after checkpoint")
|
||||
os._exit(0)
|
||||
|
||||
async def _recover_from_checkpoint(self, checkpoint_bytes: bytes) -> None:
|
||||
async def _recover_from_checkpoint(self, checkpoint: Checkpoint) -> None:
|
||||
"""Recover the instance state from the provided checkpoint.
|
||||
|
||||
This should be called in the constructor to ensure that the internal
|
||||
@@ -627,12 +684,9 @@ class ServeController:
|
||||
start = time.time()
|
||||
logger.info("Recovering from checkpoint")
|
||||
|
||||
restored_checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
|
||||
self.current_state = restored_checkpoint.current_state
|
||||
self.actor_reconciler = checkpoint.reconciler
|
||||
|
||||
self.actor_reconciler = restored_checkpoint.reconciler
|
||||
|
||||
self._serializable_inflight_results = restored_checkpoint.inflight_reqs
|
||||
self._serializable_inflight_results = checkpoint.inflight_reqs
|
||||
for uuid, fut_result in self._serializable_inflight_results.items():
|
||||
self._create_event_with_result(fut_result.requested_goal, uuid)
|
||||
|
||||
@@ -652,7 +706,7 @@ class ServeController:
|
||||
async def finish_recover_from_checkpoint():
|
||||
assert self.write_lock.locked()
|
||||
self.autoscaling_policies = await self.actor_reconciler.\
|
||||
_recover_from_checkpoint(self.current_state, self)
|
||||
_recover_from_checkpoint(self.backend_state, self)
|
||||
self.write_lock.release()
|
||||
logger.info(
|
||||
"Recovered from checkpoint in {:.3f}s".format(time.time() -
|
||||
@@ -662,7 +716,7 @@ class ServeController:
|
||||
asyncio.get_event_loop().create_task(finish_recover_from_checkpoint())
|
||||
|
||||
async def do_autoscale(self) -> None:
|
||||
for backend, info in self.current_state.backends.items():
|
||||
for backend, info in self.backend_state.backends.items():
|
||||
if backend not in self.autoscaling_policies:
|
||||
continue
|
||||
|
||||
@@ -672,16 +726,14 @@ class ServeController:
|
||||
await self.update_backend_config(
|
||||
backend, BackendConfig(num_replicas=new_num_replicas))
|
||||
|
||||
async def reconcile_current_and_goal_backends(self):
    """Placeholder that does nothing; awaiting it yields ``None``."""
    return None
|
||||
|
||||
async def run_control_loop(self) -> None:
|
||||
while True:
|
||||
await self.do_autoscale()
|
||||
async with self.write_lock:
|
||||
self.actor_reconciler._start_http_proxies_if_needed(
|
||||
self.http_host, self.http_port, self.http_middlewares)
|
||||
checkpoint_required = self.actor_reconciler.\
|
||||
_stop_http_proxies_if_needed()
|
||||
if checkpoint_required:
|
||||
self._checkpoint()
|
||||
self.http_state.update()
|
||||
|
||||
await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
|
||||
|
||||
@@ -692,15 +744,15 @@ class ServeController:
|
||||
|
||||
def get_all_backends(self) -> Dict[BackendTag, BackendConfig]:
|
||||
"""Returns a dictionary of backend tag to backend config."""
|
||||
return self.current_state.get_backend_configs()
|
||||
return self.backend_state.get_backend_configs()
|
||||
|
||||
def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
|
||||
"""Returns a dictionary of backend tag to backend config."""
|
||||
return self.current_state.get_endpoints()
|
||||
return self.endpoint_state.get_endpoints()
|
||||
|
||||
async def _set_traffic(self, endpoint_name: str,
|
||||
traffic_dict: Dict[str, float]) -> UUID:
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
if endpoint_name not in self.endpoint_state.get_endpoints():
|
||||
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
|
||||
" that is not registered.".format(endpoint_name))
|
||||
|
||||
@@ -708,13 +760,13 @@ class ServeController:
|
||||
dict), "Traffic policy must be a dictionary."
|
||||
|
||||
for backend in traffic_dict:
|
||||
if self.current_state.get_backend(backend) is None:
|
||||
if self.backend_state.get_backend(backend) is None:
|
||||
raise ValueError(
|
||||
"Attempted to assign traffic to a backend '{}' that "
|
||||
"is not registered.".format(backend))
|
||||
|
||||
traffic_policy = TrafficPolicy(traffic_dict)
|
||||
self.current_state.traffic_policies[endpoint_name] = traffic_policy
|
||||
self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
endpoint_name: traffic_policy
|
||||
@@ -737,20 +789,21 @@ class ServeController:
|
||||
proportion: float) -> UUID:
|
||||
"""Shadow traffic from the endpoint to the backend."""
|
||||
async with self.write_lock:
|
||||
if endpoint_name not in self.current_state.get_endpoints():
|
||||
if endpoint_name not in self.endpoint_state.get_endpoints():
|
||||
raise ValueError("Attempted to shadow traffic from an "
|
||||
"endpoint '{}' that is not registered."
|
||||
.format(endpoint_name))
|
||||
|
||||
if self.current_state.get_backend(backend_tag) is None:
|
||||
if self.backend_state.get_backend(backend_tag) is None:
|
||||
raise ValueError(
|
||||
"Attempted to shadow traffic to a backend '{}' that "
|
||||
"is not registered.".format(backend_tag))
|
||||
|
||||
self.current_state.traffic_policies[endpoint_name].set_shadow(
|
||||
self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
|
||||
backend_tag, proportion)
|
||||
|
||||
traffic_policy = self.current_state.traffic_policies[endpoint_name]
|
||||
traffic_policy = self.endpoint_state.traffic_policies[
|
||||
endpoint_name]
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
endpoint_name: traffic_policy
|
||||
@@ -781,10 +834,10 @@ class ServeController:
|
||||
|
||||
# TODO(edoakes): move this to client side.
|
||||
err_prefix = "Cannot create endpoint."
|
||||
if route in self.current_state.routes:
|
||||
if route in self.endpoint_state.routes:
|
||||
|
||||
# Ensures this method is idempotent
|
||||
if self.current_state.routes[route] == (endpoint, methods):
|
||||
if self.endpoint_state.routes[route] == (endpoint, methods):
|
||||
return
|
||||
|
||||
else:
|
||||
@@ -792,7 +845,7 @@ class ServeController:
|
||||
"{} Route '{}' is already registered.".format(
|
||||
err_prefix, route))
|
||||
|
||||
if endpoint in self.current_state.get_endpoints():
|
||||
if endpoint in self.endpoint_state.get_endpoints():
|
||||
raise ValueError(
|
||||
"{} Endpoint '{}' is already registered.".format(
|
||||
err_prefix, endpoint))
|
||||
@@ -801,7 +854,7 @@ class ServeController:
|
||||
"Registering route '{}' to endpoint '{}' with methods '{}'.".
|
||||
format(route, endpoint, methods))
|
||||
|
||||
self.current_state.routes[route] = (endpoint, methods)
|
||||
self.endpoint_state.routes[route] = (endpoint, methods)
|
||||
|
||||
# NOTE(edoakes): checkpoint is written in self._set_traffic.
|
||||
return_uuid = await self._set_traffic(endpoint, traffic_dict)
|
||||
@@ -818,7 +871,7 @@ class ServeController:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified endpoint exists on the client.
|
||||
for route, (route_endpoint,
|
||||
_) in self.current_state.routes.items():
|
||||
_) in self.endpoint_state.routes.items():
|
||||
if route_endpoint == endpoint:
|
||||
route_to_delete = route
|
||||
break
|
||||
@@ -827,13 +880,11 @@ class ServeController:
|
||||
return
|
||||
|
||||
# Remove the routing entry.
|
||||
del self.current_state.routes[route_to_delete]
|
||||
del self.endpoint_state.routes[route_to_delete]
|
||||
|
||||
# Remove the traffic policy entry if it exists.
|
||||
if endpoint in self.current_state.traffic_policies:
|
||||
del self.current_state.traffic_policies[endpoint]
|
||||
|
||||
self.actor_reconciler.endpoints_to_remove.append(endpoint)
|
||||
if endpoint in self.endpoint_state.traffic_policies:
|
||||
del self.endpoint_state.traffic_policies[endpoint]
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
route_to_delete: None,
|
||||
@@ -852,7 +903,7 @@ class ServeController:
|
||||
"""Register a new backend under the specified tag."""
|
||||
async with self.write_lock:
|
||||
# Ensures this method is idempotent.
|
||||
backend_info = self.current_state.get_backend(backend_tag)
|
||||
backend_info = self.backend_state.get_backend(backend_tag)
|
||||
if backend_info is not None:
|
||||
if (backend_info.backend_config == backend_config
|
||||
and backend_info.replica_config == replica_config):
|
||||
@@ -867,7 +918,7 @@ class ServeController:
|
||||
worker_class=backend_replica,
|
||||
backend_config=backend_config,
|
||||
replica_config=replica_config)
|
||||
self.current_state.add_backend(backend_tag, backend_info)
|
||||
self.backend_state.add_backend(backend_tag, backend_info)
|
||||
metadata = backend_config.internal_metadata
|
||||
if metadata.autoscaling_config is not None:
|
||||
self.autoscaling_policies[
|
||||
@@ -875,11 +926,12 @@ class ServeController:
|
||||
backend_tag, metadata.autoscaling_config)
|
||||
|
||||
try:
|
||||
# This call should be to run control loop
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.current_state.backends, backend_tag,
|
||||
self.backend_state.backends, backend_tag,
|
||||
backend_config.num_replicas)
|
||||
except RayServeException as e:
|
||||
del self.current_state.backends[backend_tag]
|
||||
del self.backend_state.backends[backend_tag]
|
||||
raise e
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
@@ -889,8 +941,9 @@ class ServeController:
|
||||
# or pushing the updated config to avoid inconsistent state if we
|
||||
# crash while making the change.
|
||||
self._checkpoint()
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.current_state)
|
||||
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
|
||||
self.backend_state)
|
||||
await self.actor_reconciler.backend_control_loop()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
|
||||
@@ -903,11 +956,11 @@ class ServeController:
|
||||
async with self.write_lock:
|
||||
# This method must be idempotent. We should validate that the
|
||||
# specified backend exists on the client.
|
||||
if self.current_state.get_backend(backend_tag) is None:
|
||||
if self.backend_state.get_backend(backend_tag) is None:
|
||||
return
|
||||
|
||||
# Check that the specified backend isn't used by any endpoints.
|
||||
for endpoint, traffic_policy in self.current_state.\
|
||||
for endpoint, traffic_policy in self.endpoint_state.\
|
||||
traffic_policies.items():
|
||||
if (backend_tag in traffic_policy.traffic_dict
|
||||
or backend_tag in traffic_policy.shadow_dict):
|
||||
@@ -917,13 +970,15 @@ class ServeController:
|
||||
"again.".format(backend_tag, endpoint))
|
||||
|
||||
# Scale its replicas down to 0. This will also remove the backend
|
||||
# from self.current_state.backends and
|
||||
# from self.backend_state.backends and
|
||||
# self.actor_reconciler.backend_replicas.
|
||||
|
||||
# This should be a call to the control loop
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.current_state.backends, backend_tag, 0)
|
||||
self.backend_state.backends, backend_tag, 0)
|
||||
|
||||
# Remove the backend's metadata.
|
||||
del self.current_state.backends[backend_tag]
|
||||
del self.backend_state.backends[backend_tag]
|
||||
if backend_tag in self.autoscaling_policies:
|
||||
del self.autoscaling_policies[backend_tag]
|
||||
|
||||
@@ -935,7 +990,9 @@ class ServeController:
|
||||
# backend from the routers to avoid inconsistent state if we crash
|
||||
# after pushing the update.
|
||||
self._checkpoint()
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
|
||||
self.backend_state)
|
||||
await self.actor_reconciler.backend_control_loop()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
return return_uuid
|
||||
@@ -944,22 +1001,24 @@ class ServeController:
|
||||
config_options: BackendConfig) -> UUID:
|
||||
"""Set the config for the specified backend."""
|
||||
async with self.write_lock:
|
||||
assert (self.current_state.get_backend(backend_tag)
|
||||
assert (self.backend_state.get_backend(backend_tag)
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
assert isinstance(config_options, BackendConfig)
|
||||
|
||||
stored_backend_config = self.current_state.get_backend(
|
||||
stored_backend_config = self.backend_state.get_backend(
|
||||
backend_tag).backend_config
|
||||
backend_config = stored_backend_config.copy(
|
||||
update=config_options.dict(exclude_unset=True))
|
||||
backend_config._validate_complete()
|
||||
self.current_state.get_backend(
|
||||
self.backend_state.get_backend(
|
||||
backend_tag).backend_config = backend_config
|
||||
backend_info = self.current_state.get_backend(backend_tag)
|
||||
backend_info = self.backend_state.get_backend(backend_tag)
|
||||
|
||||
# Scale the replicas with the new configuration.
|
||||
|
||||
# This should be to run the control loop
|
||||
self.actor_reconciler._scale_backend_replicas(
|
||||
self.current_state.backends, backend_tag,
|
||||
self.backend_state.backends, backend_tag,
|
||||
backend_config.num_replicas)
|
||||
|
||||
return_uuid = self._create_event_with_result({
|
||||
@@ -973,9 +1032,9 @@ class ServeController:
|
||||
# Inform the routers about change in configuration
|
||||
# (particularly for setting max_batch_size).
|
||||
|
||||
await self.actor_reconciler._start_pending_backend_replicas(
|
||||
self.current_state)
|
||||
await self.actor_reconciler._stop_pending_backend_replicas()
|
||||
await self.actor_reconciler._enqueue_pending_scale_changes_loop(
|
||||
self.backend_state)
|
||||
await self.actor_reconciler.backend_control_loop()
|
||||
|
||||
self.notify_replica_handles_changed()
|
||||
self.notify_backend_configs_changed()
|
||||
@@ -983,19 +1042,19 @@ class ServeController:
|
||||
|
||||
def get_backend_config(self, backend_tag: BackendTag) -> BackendConfig:
|
||||
"""Get the current config for the specified backend."""
|
||||
assert (self.current_state.get_backend(backend_tag)
|
||||
assert (self.backend_state.get_backend(backend_tag)
|
||||
), "Backend {} is not registered.".format(backend_tag)
|
||||
return self.current_state.get_backend(backend_tag).backend_config
|
||||
return self.backend_state.get_backend(backend_tag).backend_config
|
||||
|
||||
def get_http_config(self):
|
||||
"""Return the HTTP proxy configuration."""
|
||||
return self.http_host, self.http_port
|
||||
return self.http_state.get_config()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shuts down the serve instance completely."""
|
||||
async with self.write_lock:
|
||||
for http_proxy in self.actor_reconciler.http_proxy_handles():
|
||||
ray.kill(http_proxy, no_restart=True)
|
||||
for proxy in self.http_state.get_http_proxy_handles().values():
|
||||
ray.kill(proxy, no_restart=True)
|
||||
for replica in self.actor_reconciler.get_replica_handles():
|
||||
ray.kill(replica, no_restart=True)
|
||||
self.kv_store.delete(CHECKPOINT_KEY)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import requests
|
||||
import ray
|
||||
from ray import serve
|
||||
from ray.serve import CondaEnv
|
||||
import tensorflow as tf
|
||||
|
||||
ray.init()
|
||||
client = serve.start()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import requests
|
||||
|
||||
from ray import serve
|
||||
from ray.serve.backends import ImportedBackend
|
||||
|
||||
client = serve.start()
|
||||
|
||||
backend_class = ImportedBackend("ray.serve.utils.MockImportedBackend")
|
||||
client.create_backend("imported", backend_class, "input_arg")
|
||||
client.create_endpoint("imported", backend="imported", route="/imported")
|
||||
|
||||
print(requests.get("http://127.0.0.1:8000/imported").text)
|
||||
@@ -10,7 +10,7 @@ class Counter:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
self.count += 1
|
||||
return {"current_counter": self.count}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ ray.init(num_cpus=8)
|
||||
client = serve.start()
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return "hello " + flask_request.args.get("name", "serve!")
|
||||
def echo(starlette_request):
|
||||
return "hello " + starlette_request.query_params.get("name", "serve!")
|
||||
|
||||
|
||||
client.create_backend("hello", echo)
|
||||
|
||||
@@ -16,13 +16,13 @@ client = serve.start()
|
||||
|
||||
|
||||
def model_one(request):
|
||||
print("Model 1 called with data ", request.args.get("data"))
|
||||
print("Model 1 called with data ", request.query_params.get("data"))
|
||||
return random()
|
||||
|
||||
|
||||
def model_two(request):
|
||||
print("Model 2 called with data ", request.args.get("data"))
|
||||
return request.args.get("data")
|
||||
print("Model 2 called with data ", request.query_params.get("data"))
|
||||
return request.query_params.get("data")
|
||||
|
||||
|
||||
class ComposedModel:
|
||||
@@ -32,8 +32,8 @@ class ComposedModel:
|
||||
self.model_two = client.get_handle("model_two")
|
||||
|
||||
# This method can be called concurrently!
|
||||
async def __call__(self, flask_request):
|
||||
data = flask_request.data
|
||||
async def __call__(self, starlette_request):
|
||||
data = await starlette_request.body()
|
||||
|
||||
score = await self.model_one.remote(data=data)
|
||||
if score > 0.5:
|
||||
|
||||
@@ -14,8 +14,10 @@ import requests
|
||||
|
||||
# __doc_define_servable_v0_begin__
|
||||
@serve.accept_batch
|
||||
def batch_adder_v0(flask_requests: List):
|
||||
numbers = [int(request.args["number"]) for request in flask_requests]
|
||||
def batch_adder_v0(starlette_requests: List):
|
||||
numbers = [
|
||||
int(request.query_params["number"]) for request in starlette_requests
|
||||
]
|
||||
|
||||
input_array = np.array(numbers)
|
||||
print("Our input array has shape:", input_array.shape)
|
||||
@@ -58,7 +60,7 @@ print("Result returned:", results)
|
||||
# __doc_define_servable_v1_begin__
|
||||
@serve.accept_batch
|
||||
def batch_adder_v1(requests: List):
|
||||
numbers = [int(request.args["number"]) for request in requests]
|
||||
numbers = [int(request.query_params["number"]) for request in requests]
|
||||
input_array = np.array(numbers)
|
||||
print("Our input array has shape:", input_array.shape)
|
||||
# Sleep for 200ms, this could be performing CPU intensive computation
|
||||
|
||||
@@ -48,9 +48,9 @@ class BoostingModel:
|
||||
with open("/tmp/iris_labels.json") as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
@@ -143,9 +143,9 @@ class BoostingModelv2:
|
||||
with open("/tmp/iris_labels_2.json") as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
|
||||
@@ -27,8 +27,8 @@ class ImageModel:
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
def __call__(self, flask_request):
|
||||
image_payload_bytes = flask_request.data
|
||||
async def __call__(self, starlette_request):
|
||||
image_payload_bytes = await starlette_request.body()
|
||||
pil_image = Image.open(BytesIO(image_payload_bytes))
|
||||
print("[1/3] Parsed image data: {}".format(pil_image))
|
||||
|
||||
|
||||
@@ -54,9 +54,9 @@ class BoostingModel:
|
||||
with open(LABEL_PATH) as f:
|
||||
self.label_list = json.load(f)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
payload = flask_request.json
|
||||
print("Worker: received flask request with data", payload)
|
||||
async def __call__(self, starlette_request):
|
||||
payload = await starlette_request.json()
|
||||
print("Worker: received starlette request with data", payload)
|
||||
|
||||
input_vector = [
|
||||
payload["sepal length"],
|
||||
|
||||
@@ -51,10 +51,10 @@ class TFMnistModel:
|
||||
self.model_path = model_path
|
||||
self.model = tf.keras.models.load_model(model_path)
|
||||
|
||||
def __call__(self, flask_request):
|
||||
async def __call__(self, starlette_request):
|
||||
# Step 1: transform HTTP request -> tensorflow input
|
||||
# Here we define the request schema to be a json array.
|
||||
input_array = np.array(flask_request.json["array"])
|
||||
input_array = np.array((await starlette_request.json())["array"])
|
||||
reshaped_array = input_array.reshape((1, 28, 28))
|
||||
|
||||
# Step 2: tensorflow input -> tensorflow output
|
||||
|
||||
@@ -9,8 +9,8 @@ import requests
|
||||
from ray import serve
|
||||
|
||||
|
||||
def echo(flask_request):
|
||||
return ["hello " + flask_request.args.get("name", "serve!")]
|
||||
def echo(starlette_request):
|
||||
return ["hello " + starlette_request.query_params.get("name", "serve!")]
|
||||
|
||||
|
||||
client = serve.start()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
come from either web (parsing Starlette request) or python call.
|
||||
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
@@ -30,9 +30,10 @@ class MagicCounter:
|
||||
def __init__(self, increment):
|
||||
self.increment = increment
|
||||
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
def __call__(self, starlette_request, base_number=None):
|
||||
if serve.context.web:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
base_number = int(
|
||||
starlette_request.query_params.get("base_number", "0"))
|
||||
return base_number + self.increment
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
Example actor that adds an increment to a number. This number can
|
||||
come from either web (parsing Flask request) or python call.
|
||||
come from either web (parsing Starlette request) or python call.
|
||||
The queries incoming to this actor are batched.
|
||||
This actor can be called from HTTP as well as from Python.
|
||||
"""
|
||||
@@ -31,12 +31,13 @@ class MagicCounter:
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request_list, base_number=None):
|
||||
def __call__(self, starlette_request_list, base_number=None):
|
||||
# batch_size = serve.context.batch_size
|
||||
if serve.context.web:
|
||||
result = []
|
||||
for flask_request in flask_request_list:
|
||||
base_number = int(flask_request.args.get("base_number", "0"))
|
||||
for starlette_request in starlette_request_list:
|
||||
base_number = int(
|
||||
starlette_request.query_params.get("base_number", "0"))
|
||||
result.append(base_number)
|
||||
return list(map(lambda x: x + self.increment, result))
|
||||
else:
|
||||
|
||||
@@ -11,7 +11,7 @@ class MagicCounter:
|
||||
self.increment = increment
|
||||
|
||||
@serve.accept_batch
|
||||
def __call__(self, flask_request, base_number=None):
|
||||
def __call__(self, starlette_request, base_number=None):
|
||||
# __call__ fn should preserve the batch size
|
||||
# base_number is a python list
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ client = serve.start()
|
||||
|
||||
# a backend can be a function or class.
|
||||
# it can be made to be invoked from web as well as python.
|
||||
def echo_v1(flask_request):
|
||||
response = flask_request.args.get("response", "web")
|
||||
def echo_v1(starlette_request):
|
||||
response = starlette_request.query_params.get("response", "web")
|
||||
return response
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ print(ray.get(client.get_handle("my_endpoint").remote(response="hello")))
|
||||
|
||||
|
||||
# We can also add a new backend and split the traffic.
|
||||
def echo_v2(flask_request):
|
||||
def echo_v2(starlette_request):
|
||||
# magic, only from web.
|
||||
return "something new"
|
||||
|
||||
|
||||
+80
-115
@@ -1,27 +1,23 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import threading
|
||||
from typing import Any, Coroutine, Dict, Optional, Union
|
||||
|
||||
import ray
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.serve.utils import get_random_letters
|
||||
from ray.serve.exceptions import RayServeException
|
||||
|
||||
global_async_loop = None
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
|
||||
def create_or_get_async_loop_in_thread():
    """Return the module-wide event loop, creating and starting it if needed.

    On the first call a new asyncio event loop is created, stored in the
    module global ``global_async_loop`` and started with ``run_forever`` on
    a daemon thread. Later calls return that same loop.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    # Daemon thread, so the never-ending loop does not keep the process alive.
    loop_thread = threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    )
    loop_thread.start()
    return global_async_loop
|
||||
@dataclass(frozen=True)
class HandleOptions:
    """Per-handle request options for a ServeHandle.

    Instances are frozen: to change an option, build a new HandleOptions
    rather than assigning to a field.
    """

    # The method to invoke on the backend.
    method_name: str = "__call__"
    # A string used to deterministically map a request to a backend when
    # there are multiple backends for the endpoint.
    shard_key: Optional[str] = None
    # The HTTP method to use for the request.
    http_method: str = "GET"
    # Presumably the HTTP headers attached to the request -- confirm with
    # the router. A default_factory gives every instance its own dict.
    http_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Sentinel that stands for "option not supplied". None cannot play this role
# because None is itself a valid new value for an option.
class DEFAULT(Enum):
    """Single-member enum whose ``VALUE`` marks an option left at default."""

    VALUE = 1
|
||||
|
||||
|
||||
class RayServeHandle:
|
||||
@@ -31,75 +27,83 @@ class RayServeHandle:
|
||||
an HTTP endpoint.
|
||||
|
||||
Example:
|
||||
>>> handle = serve.get_handle("my_endpoint")
|
||||
>>> handle = serve_client.get_handle("my_endpoint")
|
||||
>>> handle
|
||||
RayServeHandle(
|
||||
Endpoint="my_endpoint",
|
||||
Traffic=...
|
||||
)
|
||||
>>> handle.remote(my_request_content)
|
||||
RayServeHandle(endpoint="my_endpoint")
|
||||
>>> await handle.remote(my_request_content)
|
||||
ObjectRef(...)
|
||||
>>> ray.get(handle.remote(...))
|
||||
>>> ray.get(await handle.remote(...))
|
||||
# result
|
||||
>>> ray.get(handle.remote(let_it_crash_request))
|
||||
>>> ray.get(await handle.remote(let_it_crash_request))
|
||||
# raises RayTaskError Exception
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
controller_handle,
|
||||
endpoint_name,
|
||||
sync: bool,
|
||||
*,
|
||||
method_name=None,
|
||||
shard_key=None,
|
||||
http_method=None,
|
||||
http_headers=None,
|
||||
):
|
||||
self.controller_handle = controller_handle
|
||||
def __init__(self,
|
||||
router,
|
||||
endpoint_name,
|
||||
handle_options: Optional[HandleOptions] = None):
|
||||
self.router = router
|
||||
self.endpoint_name = endpoint_name
|
||||
self.handle_options = handle_options or HandleOptions()
|
||||
|
||||
self.method_name = method_name
|
||||
self.shard_key = shard_key
|
||||
self.http_method = http_method
|
||||
self.http_headers = http_headers
|
||||
def options(self,
|
||||
*,
|
||||
method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
shard_key: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
http_method: Union[str, DEFAULT] = DEFAULT.VALUE,
|
||||
http_headers: Union[Dict[str, str], DEFAULT] = DEFAULT.VALUE):
|
||||
"""Set options for this handle.
|
||||
|
||||
self.router = Router(self.controller_handle)
|
||||
self.sync = sync
|
||||
# In the synchrounous mode, we create a new event loop in a separate
|
||||
# thread and run the Router.setup in that loop. In the async mode, we
|
||||
# can just use the current loop we are in right now.
|
||||
if self.sync:
|
||||
self.async_loop = create_or_get_async_loop_in_thread()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.router.setup_in_async_loop(),
|
||||
self.async_loop,
|
||||
)
|
||||
else: # async
|
||||
self.async_loop = asyncio.get_event_loop()
|
||||
# create_task is not threadsafe.
|
||||
self.async_loop.create_task(self.router.setup_in_async_loop())
|
||||
Args:
|
||||
method_name(str): The method to invoke on the backend.
|
||||
http_method(str): The HTTP method to use for the request.
|
||||
shard_key(str): A string to use to deterministically map this
|
||||
request to a backend if there are multiple for this endpoint.
|
||||
"""
|
||||
new_options_dict = self.handle_options.__dict__.copy()
|
||||
user_modified_options_dict = {
|
||||
key: value
|
||||
for key, value in
|
||||
zip(["method_name", "shard_key", "http_method", "http_headers"],
|
||||
[method_name, shard_key, http_method, http_headers])
|
||||
if value != DEFAULT.VALUE
|
||||
}
|
||||
new_options_dict.update(user_modified_options_dict)
|
||||
new_options = HandleOptions(**new_options_dict)
|
||||
|
||||
def _remote(self, request_data, kwargs) -> Coroutine:
|
||||
request_metadata = RequestMetadata(
|
||||
get_random_letters(10), # Used for debugging.
|
||||
self.endpoint_name,
|
||||
TaskContext.Python,
|
||||
call_method=self.method_name or "__call__",
|
||||
shard_key=self.shard_key,
|
||||
http_method=self.http_method or "GET",
|
||||
http_headers=self.http_headers or dict(),
|
||||
)
|
||||
coro = self.router.assign_request(request_metadata, request_data,
|
||||
**kwargs)
|
||||
return coro
|
||||
return self.__class__(self.router, self.endpoint_name, new_options)
|
||||
|
||||
async def remote(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
|
||||
Returns:
|
||||
ray.ObjectRef
|
||||
Args:
|
||||
request_data(dict, Any): If it's a dictionary, the data will be
|
||||
available in ``request.json()`` or ``request.form()``.
|
||||
Otherwise, it will be available in ``request.body()``.
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.query_params``.
|
||||
"""
|
||||
return await self.router._remote(
|
||||
self.endpoint_name, self.handle_options, request_data, kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(endpoint='{self.endpoint_name}')"
|
||||
|
||||
|
||||
class RayServeSyncHandle(RayServeHandle):
|
||||
def remote(self, request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs):
|
||||
"""Issue an asynchrounous request to the endpoint.
|
||||
|
||||
Returns a Ray ObjectRef whose results can be waited for or retrieved
|
||||
using ray.wait or ray.get, respectively.
|
||||
using ray.wait or ray.get (or ``await object_ref``), respectively.
|
||||
|
||||
Returns:
|
||||
ray.ObjectRef
|
||||
@@ -110,47 +114,8 @@ class RayServeHandle:
|
||||
``**kwargs``: All keyword arguments will be available in
|
||||
``request.args``.
|
||||
"""
|
||||
if not self.sync:
|
||||
raise RayServeException(
|
||||
"You are trying to call handle.remote() with async handle. "
|
||||
"Please use `await handle.remote_async()` instead.")
|
||||
|
||||
coro = self._remote(request_data, kwargs)
|
||||
coro = self.router._remote(self.endpoint_name, self.handle_options,
|
||||
request_data, kwargs)
|
||||
future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
|
||||
coro, self.async_loop)
|
||||
|
||||
# Block until the result is ready.
|
||||
coro, self.router.async_loop)
|
||||
return future.result()
|
||||
|
||||
async def remote_async(self,
|
||||
request_data: Optional[Union[Dict, Any]] = None,
|
||||
**kwargs) -> ray.ObjectRef:
|
||||
"""Experimental API for enqueue a request in async context."""
|
||||
if not asyncio.get_event_loop().is_running():
|
||||
raise RayServeException(
|
||||
"remote_async must be called from a running event loop.")
|
||||
return await self._remote(request_data, kwargs)
|
||||
|
||||
def options(self,
|
||||
method_name: Optional[str] = None,
|
||||
*,
|
||||
shard_key: Optional[str] = None,
|
||||
http_method: Optional[str] = None,
|
||||
http_headers: Optional[Dict[str, str]] = None):
|
||||
"""Set options for this handle.
|
||||
|
||||
Args:
|
||||
method_name(str): The method to invoke on the backend.
|
||||
http_method(str): The HTTP method to use for the request.
|
||||
shard_key(str): A string to use to deterministically map this
|
||||
request to a backend if there are multiple for this endpoint.
|
||||
"""
|
||||
# Don't override default non-null values.
|
||||
self.method_name = self.method_name or method_name
|
||||
self.shard_key = self.shard_key or shard_key
|
||||
self.http_method = self.http_method or http_method
|
||||
self.http_headers = self.http_headers or http_headers
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return f"RayServeHandle(endpoint='{self.endpoint_name}')"
|
||||
|
||||
@@ -1,76 +1,25 @@
|
||||
import io
|
||||
import json
|
||||
|
||||
import flask
|
||||
import starlette.requests
|
||||
|
||||
|
||||
def build_flask_request(asgi_scope_dict, request_body):
|
||||
"""Build and return a flask request from ASGI payload
|
||||
def build_starlette_request(scope, serialized_body: bytes):
|
||||
"""Build and return a Starlette Request from ASGI payload.
|
||||
|
||||
This function is indented to be used immediately before task invocation
|
||||
happen.
|
||||
This function is intended to be used immediately before task invocation
|
||||
happens.
|
||||
"""
|
||||
wsgi_environ = build_wsgi_environ(asgi_scope_dict, request_body)
|
||||
# We set populate_request=False to prevent self reference, which can lead
|
||||
# to objects tracked by python garbage collector and memory growth. See
|
||||
# https://github.com/ray-project/ray/issues/12395.
|
||||
return flask.Request(wsgi_environ, populate_request=False)
|
||||
|
||||
# Simulates receiving HTTP body from TCP socket. In reality, the body has
|
||||
# already been streamed in chunks and stored in serialized_body.
|
||||
async def mock_receive():
|
||||
return {
|
||||
"body": serialized_body,
|
||||
"type": "http.request",
|
||||
"more_body": False
|
||||
}
|
||||
|
||||
def build_wsgi_environ(scope, body):
    """Build a WSGI environ dict from an ASGI HTTP scope and request body.

    This code snippet is taken from https://github.com/django/asgiref/blob
    /36c3e8dc70bf38fe2db87ac20b514f21aaf5ea9d/asgiref/wsgi.py#L52

    WSGI specification can be found at
    https://www.python.org/dev/peps/pep-0333/

    This function helps translate ASGI scope and body into a flask request.

    Args:
        scope: ASGI HTTP connection scope. ``method``, ``path``,
            ``query_string`` and ``http_version`` are read unconditionally;
            ``root_path``, ``scheme``, ``server``, ``client`` and
            ``headers`` may be absent.
        body: Object stored under ``wsgi.input``.

    Returns:
        A dict of WSGI environ entries.
    """
    environ = {
        "REQUEST_METHOD": scope["method"],
        "SCRIPT_NAME": scope.get("root_path", ""),
        "PATH_INFO": scope["path"],
        "QUERY_STRING": scope["query_string"].decode("ascii"),
        "SERVER_PROTOCOL": "HTTP/{}".format(scope["http_version"]),
        "wsgi.version": (1, 0),
        "wsgi.url_scheme": scope.get("scheme", "http"),
        "wsgi.input": body,
        "wsgi.errors": io.BytesIO(),
        "wsgi.multithread": True,
        "wsgi.multiprocess": True,
        "wsgi.run_once": False,
    }

    # Server name and port are required in WSGI but optional in ASGI
    # ("server" may be missing or None), so fall back to the same defaults
    # asgiref uses instead of raising on a valid scope.
    server = scope.get("server")
    if server:
        environ["SERVER_NAME"] = server[0]
        environ["SERVER_PORT"] = str(server[1])
    else:
        environ["SERVER_NAME"] = "localhost"
        environ["SERVER_PORT"] = "80"

    # "client" is optional in ASGI as well; REMOTE_ADDR is simply left out
    # when the peer address is unknown.
    client = scope.get("client")
    if client:
        environ["REMOTE_ADDR"] = client[0]

    # Transforms headers into environ entries.
    for name, value in scope.get("headers", []):
        # name, values are both bytes, we need to decode them to string
        name = name.decode("latin1")
        value = value.decode("latin1")

        # Handle name correction to conform to WSGI spec
        # https://www.python.org/dev/peps/pep-0333/#environ-variables
        if name == "content-length":
            corrected_name = "CONTENT_LENGTH"
        elif name == "content-type":
            corrected_name = "CONTENT_TYPE"
        else:
            corrected_name = "HTTP_%s" % name.upper().replace("-", "_")

        # If the header value repeated,
        # we will just concatenate it to the field.
        if corrected_name in environ:
            value = environ[corrected_name] + "," + value

        environ[corrected_name] = value
    return environ
|
||||
return starlette.requests.Request(scope, mock_receive)
|
||||
|
||||
|
||||
class Response:
|
||||
|
||||
@@ -19,8 +19,8 @@ from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
def test_e2e(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def function(flask_request):
|
||||
return {"method": flask_request.method}
|
||||
def function(starlette_request):
|
||||
return {"method": starlette_request.method}
|
||||
|
||||
client.create_backend("echo:v1", function)
|
||||
client.create_endpoint(
|
||||
@@ -97,7 +97,7 @@ def test_backend_user_config(serve_instance):
|
||||
def __init__(self):
|
||||
self.count = 10
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return self.count, os.getpid()
|
||||
|
||||
def reconfigure(self, config):
|
||||
@@ -146,7 +146,7 @@ def test_call_method(serve_instance):
|
||||
|
||||
# Test serve handle path.
|
||||
handle = client.get_handle("endpoint")
|
||||
assert ray.get(handle.options("method").remote()) == "hello"
|
||||
assert ray.get(handle.options(method_name="method").remote()) == "hello"
|
||||
|
||||
|
||||
def test_no_route(serve_instance):
|
||||
@@ -820,8 +820,8 @@ def test_serve_metrics(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
@serve.accept_batch
|
||||
def batcher(flask_requests):
|
||||
return ["hello"] * len(flask_requests)
|
||||
def batcher(starlette_requests):
|
||||
return ["hello"] * len(starlette_requests)
|
||||
|
||||
client.create_backend("metrics", batcher)
|
||||
client.create_endpoint("metrics", backend="metrics", route="/metrics")
|
||||
@@ -871,6 +871,26 @@ def test_serve_metrics(serve_instance):
|
||||
verify_metrics()
|
||||
|
||||
|
||||
def test_starlette_request(serve_instance):
    """Check that a large POST body is echoed back unchanged by a backend."""
    client = serve_instance

    async def echo_body(starlette_request):
        # Return the raw request body as received.
        return await starlette_request.body()

    UVICORN_HIGH_WATER_MARK = 65536  # max bytes in one message

    # Ten times the per-message limit, to test serialization of multiple
    # messages.
    long_string = "x" * 10 * UVICORN_HIGH_WATER_MARK

    client.create_backend("echo:v1", echo_body)
    client.create_endpoint(
        "endpoint", backend="echo:v1", route="/api", methods=["GET", "POST"])

    response = requests.post("http://127.0.0.1:8000/api", data=long_string)
    assert response.text == long_string
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
@@ -85,7 +85,7 @@ async def test_runner_wraps_error():
|
||||
async def test_servable_function(serve_instance, router,
|
||||
mock_controller_with_name):
|
||||
def echo(request):
|
||||
return request.args["i"]
|
||||
return request.query_params["i"]
|
||||
|
||||
await add_servable_to_router(echo, router, mock_controller_with_name[0])
|
||||
|
||||
@@ -103,7 +103,7 @@ async def test_servable_class(serve_instance, router,
|
||||
self.increment = inc
|
||||
|
||||
def __call__(self, request):
|
||||
return request.args["i"] + self.increment
|
||||
return request.query_params["i"] + self.increment
|
||||
|
||||
await add_servable_to_router(
|
||||
MyAdder, router, mock_controller_with_name[0], init_args=(3, ))
|
||||
@@ -277,7 +277,7 @@ async def test_user_config_update(serve_instance, router,
|
||||
def __init__(self):
|
||||
self.reval = ""
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return self.retval
|
||||
|
||||
def reconfigure(self, config):
|
||||
|
||||
@@ -8,7 +8,7 @@ def test_handle_in_endpoint(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Endpoint1:
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return "hello"
|
||||
|
||||
class Endpoint2:
|
||||
@@ -40,12 +40,12 @@ def test_handle_http_args(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
class Endpoint:
|
||||
def __call__(self, request):
|
||||
async def __call__(self, request):
|
||||
return {
|
||||
"args": dict(request.args),
|
||||
"args": dict(request.query_params),
|
||||
"headers": dict(request.headers),
|
||||
"method": request.method,
|
||||
"json": request.json
|
||||
"json": await request.json()
|
||||
}
|
||||
|
||||
client.create_backend("backend", Endpoint)
|
||||
@@ -58,7 +58,7 @@ def test_handle_http_args(serve_instance):
|
||||
"arg2": "2"
|
||||
},
|
||||
"headers": {
|
||||
"X-Custom-Header": "value"
|
||||
"x-custom-header": "value"
|
||||
},
|
||||
"method": "POST",
|
||||
"json": {
|
||||
@@ -81,10 +81,10 @@ def test_handle_http_args(serve_instance):
|
||||
for resp in [resp_web, resp_handle]:
|
||||
for field in ["args", "method", "json"]:
|
||||
assert resp[field] == ground_truth[field]
|
||||
resp["headers"]["X-Custom-Header"] == "value"
|
||||
resp["headers"]["x-custom-header"] == "value"
|
||||
|
||||
|
||||
def test_handle_inject_flask_request(serve_instance):
|
||||
def test_handle_inject_starlette_request(serve_instance):
|
||||
client = serve_instance
|
||||
|
||||
def echo_request_type(request):
|
||||
@@ -103,7 +103,37 @@ def test_handle_inject_flask_request(serve_instance):
|
||||
for route in ["/echo", "/wrapper"]:
|
||||
resp = requests.get(f"http://127.0.0.1:8000{route}")
|
||||
request_type = resp.text
|
||||
assert request_type == "<class 'flask.wrappers.Request'>"
|
||||
assert request_type == "<class 'starlette.requests.Request'>"
|
||||
|
||||
|
||||
def test_handle_option_chaining(serve_instance):
    """Check that options() yields independent handles that can be chained.

    See:
    https://github.com/ray-project/ray/issues/12802
    https://github.com/ray-project/ray/issues/12798
    """
    client = serve_instance

    class MultiMethod:
        """Backend whose methods each return their own name."""

        def method_a(self, _):
            return "method_a"

        def method_b(self, _):
            return "method_b"

        def __call__(self, _):
            return "__call__"

    client.create_backend("m", MultiMethod)
    client.create_endpoint("m", backend="m")

    # get_handle should give you a clean handle
    handle1 = client.get_handle("m").options(method_name="method_a")
    handle2 = client.get_handle("m")
    # options().options() override should work
    handle3 = handle1.options(method_name="method_b")

    expectations = (
        (handle1, "method_a"),
        (handle2, "__call__"),
        (handle3, "method_b"),
    )
    for handle, expected in expectations:
        assert ray.get(handle.remote()) == expected
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import ray
|
||||
from ray.serve.backends import ImportedBackend
|
||||
from ray.serve.config import BackendConfig
|
||||
|
||||
|
||||
def test_imported_backend(serve_instance):
    """Exercise a backend that is created from an import path string."""
    client = serve_instance

    imported = ImportedBackend("ray.serve.utils.MockImportedBackend")
    initial_config = BackendConfig(user_config="config")
    client.create_backend(
        "imported", imported, "input_arg", config=initial_config)
    client.create_endpoint("imported", backend="imported")

    # Basic sanity check.
    handle = client.get_handle("imported")
    first_reply = ray.get(handle.remote())
    assert first_reply == {"arg": "input_arg", "config": "config"}

    # Check that updating backend config works.
    updated_config = BackendConfig(user_config="new_config")
    client.update_backend_config("imported", updated_config)
    second_reply = ray.get(handle.remote())
    assert second_reply == {"arg": "input_arg", "config": "new_config"}

    # Check that other call methods work.
    handle = handle.options(method_name="other_method")
    assert ray.get(handle.remote("hello")) == "hello"
|
||||
@@ -12,7 +12,7 @@ ray.init(address="{}")
|
||||
from ray import serve
|
||||
client = serve.connect()
|
||||
|
||||
def driver(flask_request):
|
||||
def driver(starlette_request):
|
||||
return "OK!"
|
||||
|
||||
client.create_backend("driver", driver)
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_np_in_composed_model(serve_instance):
|
||||
# in cloudpickle _from_numpy_buffer
|
||||
|
||||
def sum_model(request):
|
||||
return np.sum(request.args["data"])
|
||||
return np.sum(request.query_params["data"])
|
||||
|
||||
class ComposedModel:
|
||||
def __init__(self):
|
||||
@@ -42,7 +42,7 @@ def test_backend_worker_memory_growth(serve_instance):
|
||||
# https://github.com/ray-project/ray/issues/12395
|
||||
client = serve_instance
|
||||
|
||||
def gc_unreachable_objects(flask_request):
|
||||
def gc_unreachable_objects(starlette_request):
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
gc.collect()
|
||||
return len(gc.garbage)
|
||||
|
||||
@@ -6,9 +6,10 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.serve.utils import (ServeEncoder, chain_future, unpack_future,
|
||||
try_schedule_resources_on_nodes,
|
||||
get_conda_env_dir)
|
||||
get_conda_env_dir, import_class)
|
||||
|
||||
|
||||
def test_bytes_encoder():
|
||||
@@ -125,6 +126,21 @@ def test_get_conda_env_dir(tmp_path):
|
||||
os.environ["CONDA_PREFIX"] = ""
|
||||
|
||||
|
||||
def test_import_class():
    """Check that import_class resolves dotted paths to usable classes."""
    # The package-level and module-level paths resolve to the same class.
    for dotted_path in ("ray.serve.Client", "ray.serve.api.Client"):
        assert import_class(dotted_path) == ray.serve.api.Client

    policy_cls = import_class("ray.serve.controller.TrafficPolicy")
    assert policy_cls == ray.serve.controller.TrafficPolicy

    # The imported class must behave like the real one: this traffic dict
    # is rejected with ValueError, the next one is accepted.
    policy = policy_cls({"endpoint1": 0.5, "endpoint2": 0.5})
    with pytest.raises(ValueError):
        policy.set_traffic_dict({"endpoint1": 0.5, "endpoint2": 0.6})
    policy.set_traffic_dict({"endpoint1": 0.4, "endpoint2": 0.6})

    print(repr(policy))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", "-s", __file__]))
|
||||
|
||||
+57
-17
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from functools import singledispatch
|
||||
import importlib
|
||||
from itertools import groupby
|
||||
import json
|
||||
import logging
|
||||
@@ -7,7 +8,6 @@ import random
|
||||
import string
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import io
|
||||
import os
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from collections import UserDict
|
||||
@@ -15,18 +15,18 @@ from collections import UserDict
|
||||
import requests
|
||||
import numpy as np
|
||||
import pydantic
|
||||
import flask
|
||||
import starlette.requests
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import HTTP_PROXY_TIMEOUT
|
||||
from ray.serve.context import TaskContext
|
||||
from ray.serve.http_util import build_flask_request
|
||||
from ray.serve.http_util import build_starlette_request
|
||||
|
||||
ACTOR_FAILURE_RETRY_TIMEOUT_S = 60
|
||||
|
||||
|
||||
class ServeMultiDict(UserDict):
|
||||
"""Compatible data structure to simulate Flask.Request.args API."""
|
||||
"""Compatible data structure to simulate Starlette Request query_args."""
|
||||
|
||||
def getlist(self, key):
|
||||
"""Return the list of items for a given key."""
|
||||
@@ -34,11 +34,14 @@ class ServeMultiDict(UserDict):
|
||||
|
||||
|
||||
class ServeRequest:
|
||||
"""The request object used in Python context.
|
||||
"""The request object used when passing arguments via ServeHandle.
|
||||
|
||||
ServeRequest is built to have similar API as Flask.Request. You only need
|
||||
to write your model serving code once; it can be queried by both HTTP and
|
||||
Python.
|
||||
ServeRequest partially implements the API of Starlette Request. You only
|
||||
need to write your model serving code once; it can be queried by both HTTP
|
||||
and Python.
|
||||
|
||||
To use the full Starlette Request interface with ServeHandle, you may
|
||||
instead directly pass in a Starlette Request object to the ServeHandle.
|
||||
"""
|
||||
|
||||
def __init__(self, data, kwargs, headers, method):
|
||||
@@ -58,28 +61,25 @@ class ServeRequest:
|
||||
return self._method
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
def query_params(self):
|
||||
"""The keyword arguments from ``handle.remote(**kwargs)``."""
|
||||
return self._kwargs
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
async def json(self):
|
||||
"""The request dictionary, from ``handle.remote(dict)``."""
|
||||
if not isinstance(self._data, dict):
|
||||
raise RayServeException("Request data is not a dictionary. "
|
||||
f"It is {type(self._data)}.")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
async def form(self):
|
||||
"""The request dictionary, from ``handle.remote(dict)``."""
|
||||
if not isinstance(self._data, dict):
|
||||
raise RayServeException("Request data is not a dictionary. "
|
||||
f"It is {type(self._data)}.")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
async def body(self):
|
||||
"""The request data from ``handle.remote(obj)``."""
|
||||
return self._data
|
||||
|
||||
@@ -87,13 +87,13 @@ class ServeRequest:
|
||||
def parse_request_item(request_item):
|
||||
if request_item.metadata.request_context == TaskContext.Web:
|
||||
asgi_scope, body_bytes = request_item.args
|
||||
return build_flask_request(asgi_scope, io.BytesIO(body_bytes))
|
||||
return build_starlette_request(asgi_scope, body_bytes)
|
||||
else:
|
||||
arg = request_item.args[0] if len(request_item.args) == 1 else None
|
||||
|
||||
# If the input data from handle is web request, we don't need to wrap
|
||||
# it in ServeRequest.
|
||||
if isinstance(arg, flask.Request):
|
||||
if isinstance(arg, starlette.requests.Request):
|
||||
return arg
|
||||
|
||||
return ServeRequest(
|
||||
@@ -342,3 +342,43 @@ def get_node_id_for_actor(actor_handle):
|
||||
"""Given an actor handle, return the node id it's placed on."""
|
||||
|
||||
return ray.actors()[actor_handle._actor_id.hex()]["Address"]["NodeID"]
|
||||
|
||||
|
||||
def import_class(full_path: str):
    """Given a full import path to a class name, return the imported class.

    For example, the following are equivalent:
        MyClass = import_class("module.submodule.MyClass")
        from module.submodule import MyClass

    Args:
        full_path: Dotted path of the form ``"module.submodule.ClassName"``.
            Everything before the last period is imported as a module and
            the final component is looked up as an attribute of that module.

    Returns:
        Imported class

    Raises:
        ImportError: If ``full_path`` lacks either the module part or the
            class name (e.g. ``"MyClass"`` or ``"module."``), or if the
            module cannot be imported.
        AttributeError: If the imported module has no attribute named after
            the final path component.
    """
    # rpartition returns ("", "", full_path) when there is no period, so a
    # bare name is rejected below instead of being sliced into a truncated,
    # unrelated module name (e.g. "oss" -> module "os").
    module_name, _, class_name = full_path.rpartition(".")
    if not module_name or not class_name:
        raise ImportError(
            f"{full_path!r} is not a valid import path; expected the form "
            "'module.submodule.ClassName'.")

    module = importlib.import_module(module_name)
    return getattr(module, class_name)
|
||||
|
||||
|
||||
class MockImportedBackend:
    """Used for testing backends.ImportedBackend.

    This is necessary because we need the class to be installed in the worker
    processes. We could instead mock out importlib but doing so is messier and
    reduces confidence in the test (it isn't truly end-to-end).
    """

    def __init__(self, arg):
        # Both values are echoed back by __call__ so a test can check what
        # the backend was constructed and reconfigured with.
        self.arg = arg
        self.config = None

    def reconfigure(self, config):
        """Remember ``config``; it is reported by subsequent calls."""
        self.config = config

    def __call__(self, *args):
        """Return the constructor argument and the last stored config.

        Any positional arguments are accepted and ignored.
        """
        return {"arg": self.arg, "config": self.config}

    def other_method(self, request):
        """Return the ``data`` attribute of ``request``."""
        # NOTE(review): this relies on the incoming request object exposing
        # a ``data`` attribute -- confirm that holds for the request types
        # Serve passes to backend methods.
        return request.data
|
||||
|
||||
@@ -9,6 +9,7 @@ import ray
|
||||
from ray import gcs_utils
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
from ray._private import services
|
||||
from ray._private.client_mode_hook import client_mode_hook
|
||||
from ray.utils import (decode, binary_to_hex, hex_to_binary)
|
||||
|
||||
from ray._raylet import GlobalStateAccessor
|
||||
@@ -851,6 +852,7 @@ def jobs():
|
||||
return state.job_table()
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def nodes():
|
||||
"""Get a list of the nodes in the cluster (for debugging only).
|
||||
|
||||
@@ -964,6 +966,7 @@ def object_transfer_timeline(filename=None):
|
||||
return state.chrome_tracing_object_transfer_dump(filename=filename)
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def cluster_resources():
|
||||
"""Get the current total cluster resources.
|
||||
|
||||
@@ -977,6 +980,7 @@ def cluster_resources():
|
||||
return state.cluster_resources()
|
||||
|
||||
|
||||
@client_mode_hook
|
||||
def available_resources():
|
||||
"""Get the current available cluster resources.
|
||||
|
||||
|
||||
@@ -414,7 +414,7 @@ def init_error_pubsub():
|
||||
return p
|
||||
|
||||
|
||||
def get_error_message(pub_sub, num, error_type=None, timeout=5):
|
||||
def get_error_message(pub_sub, num, error_type=None, timeout=20):
|
||||
"""Get errors through pub/sub."""
|
||||
start_time = time.time()
|
||||
msgs = []
|
||||
@@ -442,4 +442,8 @@ def format_web_url(url):
|
||||
|
||||
|
||||
def new_scheduler_enabled():
|
||||
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER") == "1"
|
||||
return os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1") == "1"
|
||||
|
||||
|
||||
def client_test_enabled() -> bool:
    """Return True iff the ``RAY_CLIENT_MODE`` environment variable is "1".

    The variable is read on every call; an unset variable or any other
    value yields False.
    """
    return os.environ.get("RAY_CLIENT_MODE") == "1"
|
||||
|
||||
+21
-12
@@ -10,6 +10,7 @@ SRCS = [] + select({
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
# "test_dynres.py", # dyn res not implemented
|
||||
"test_async.py",
|
||||
"test_actor.py",
|
||||
"test_actor_advanced.py",
|
||||
@@ -40,16 +41,6 @@ py_test_module_list(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
"test_dynres.py", # dyn res not implemented
|
||||
],
|
||||
size = "medium",
|
||||
extra_srcs = SRCS,
|
||||
tags = ["exclusive", "medium_size_python_tests_a_to_j", "new_scheduler_broken"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
"test_memory_limits.py",
|
||||
@@ -96,6 +87,7 @@ py_test_module_list(
|
||||
"test_debug_tools.py",
|
||||
"test_experimental_client.py",
|
||||
"test_experimental_client_metadata.py",
|
||||
"test_experimental_client_references.py",
|
||||
"test_experimental_client_terminate.py",
|
||||
"test_job.py",
|
||||
"test_memstat.py",
|
||||
@@ -129,11 +121,10 @@ py_test_module_list(
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
"test_placement_group.py", # placement groups not implemented
|
||||
"test_placement_group.py",
|
||||
],
|
||||
size = "large",
|
||||
extra_srcs = SRCS,
|
||||
tags = ["exclusive", "new_scheduler_broken"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -162,3 +153,21 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
|
||||
py_test_module_list(
|
||||
files = [
|
||||
"test_actor.py",
|
||||
"test_advanced.py",
|
||||
"test_basic.py",
|
||||
"test_basic_2.py",
|
||||
],
|
||||
size = "medium",
|
||||
extra_srcs = SRCS,
|
||||
name_suffix = "_client_mode",
|
||||
# TODO(barakmich): py_test will support env in Bazel 4.0.0...
|
||||
# Until then, we can use tags.
|
||||
#env = {"RAY_CLIENT_MODE": "1"},
|
||||
tags = ["exclusive", "client_tests"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
|
||||
|
||||
def create_remote_signal_actor(ray):
    """Define a ``SignalActor`` class decorated with ``ray.remote``.

    The class is built inside this function so that the decorator comes
    from whichever ``ray`` object the caller supplies.

    Args:
        ray: Object whose ``remote`` attribute is applied as a class
            decorator to ``SignalActor``.

    Returns:
        Whatever ``ray.remote`` returns for the ``SignalActor`` class.
    """

    # TODO(barakmich): num_cpus=0
    @ray.remote
    class SignalActor:
        def __init__(self):
            self.ready_event = asyncio.Event()

        def send(self, clear=False):
            """Set the event; with ``clear=True`` reset it right after."""
            self.ready_event.set()
            if clear:
                self.ready_event.clear()

        async def wait(self, should_wait=True):
            """Wait for the event unless ``should_wait`` is false."""
            if should_wait:
                await self.ready_event.wait()

    return SignalActor
|
||||
@@ -23,7 +23,7 @@ def get_default_fixure_system_config():
|
||||
"object_timeout_milliseconds": 200,
|
||||
"num_heartbeats_timeout": 10,
|
||||
"object_store_full_max_retries": 3,
|
||||
"object_store_full_initial_delay_ms": 100,
|
||||
"object_store_full_delay_ms": 100,
|
||||
}
|
||||
return system_config
|
||||
|
||||
@@ -44,6 +44,7 @@ def _ray_start(**kwargs):
|
||||
init_kwargs.update(kwargs)
|
||||
# Start the Ray processes.
|
||||
address_info = ray.init(**init_kwargs)
|
||||
|
||||
yield address_info
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
Binary file not shown.
@@ -12,9 +12,10 @@ import tempfile
|
||||
import datetime
|
||||
import setproctitle
|
||||
|
||||
import ray
|
||||
import ray.test_utils
|
||||
import ray.cluster_utils
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.test_utils import wait_for_pid_to_exit
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
|
||||
|
||||
def test_caching_actors(shutdown_only):
|
||||
@@ -235,6 +236,7 @@ def test_actor_import_counter(ray_start_10_cpus):
|
||||
assert ray.get(g.remote()) == num_remote_functions - 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_actor_method_metadata_cache(ray_start_regular):
|
||||
class Actor(object):
|
||||
pass
|
||||
@@ -254,6 +256,7 @@ def test_actor_method_metadata_cache(ray_start_regular):
|
||||
assert [id(x) for x in list(cache.items())[0]] == cached_data_id
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_actor_class_name(ray_start_regular):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
@@ -615,6 +618,8 @@ def test_random_id_generation(ray_start_regular_shared):
|
||||
assert f1._actor_id != f2._actor_id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="differing inheritence structure")
|
||||
def test_actor_inheritance(ray_start_regular_shared):
|
||||
class NonActorBase:
|
||||
def __init__(self):
|
||||
@@ -627,8 +632,7 @@ def test_actor_inheritance(ray_start_regular_shared):
|
||||
pass
|
||||
|
||||
# Test that you can't instantiate an actor class directly.
|
||||
with pytest.raises(
|
||||
Exception, match="Actors cannot be instantiated directly."):
|
||||
with pytest.raises(Exception, match="cannot be instantiated directly"):
|
||||
ActorBase()
|
||||
|
||||
# Test that you can't inherit from an actor class.
|
||||
@@ -642,6 +646,7 @@ def test_actor_inheritance(ray_start_regular_shared):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="ray.method unimplemented")
|
||||
def test_multiple_return_values(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Foo:
|
||||
@@ -731,13 +736,13 @@ def test_actor_deletion(ray_start_regular_shared):
|
||||
a = Actor.remote()
|
||||
pid = ray.get(a.getpid.remote())
|
||||
a = None
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
actors = [Actor.remote() for _ in range(10)]
|
||||
pids = ray.get([a.getpid.remote() for a in actors])
|
||||
a = None
|
||||
actors = None
|
||||
[ray.test_utils.wait_for_pid_to_exit(pid) for pid in pids]
|
||||
[wait_for_pid_to_exit(pid) for pid in pids]
|
||||
|
||||
|
||||
def test_actor_method_deletion(ray_start_regular_shared):
|
||||
@@ -766,7 +771,8 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
|
||||
ray.get(signal.wait.remote())
|
||||
return ray.get(actor.method.remote())
|
||||
|
||||
signal = ray.test_utils.SignalActor.remote()
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signal = SignalActor.remote()
|
||||
a = Actor.remote()
|
||||
pid = ray.get(a.getpid.remote())
|
||||
# Pass the handle to another task that cannot run yet.
|
||||
@@ -777,7 +783,7 @@ def test_distributed_actor_handle_deletion(ray_start_regular_shared):
|
||||
# Once the task finishes, the actor process should get killed.
|
||||
ray.get(signal.send.remote())
|
||||
assert ray.get(x_id) == 1
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
|
||||
def test_multiple_actors(ray_start_regular_shared):
|
||||
@@ -918,7 +924,7 @@ def test_atexit_handler(ray_start_regular_shared, exit_condition):
|
||||
if exit_condition == "ray.kill":
|
||||
assert not check_file_written()
|
||||
else:
|
||||
ray.test_utils.wait_for_condition(check_file_written)
|
||||
wait_for_condition(check_file_written)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,16 +10,22 @@ import time
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.test_utils import RayTestTimeoutException
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# issue https://github.com/ray-project/ray/issues/7105
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_internal_free(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@@ -60,14 +66,14 @@ def test_multiple_waits_and_gets(shutdown_only):
|
||||
return 1
|
||||
|
||||
@ray.remote
|
||||
def g(l):
|
||||
# The argument l should be a list containing one object ref.
|
||||
ray.wait([l[0]])
|
||||
def g(input_list):
|
||||
# The argument input_list should be a list containing one object ref.
|
||||
ray.wait([input_list[0]])
|
||||
|
||||
@ray.remote
|
||||
def h(l):
|
||||
# The argument l should be a list containing one object ref.
|
||||
ray.get(l[0])
|
||||
def h(input_list):
|
||||
# The argument input_list should be a list containing one object ref.
|
||||
ray.get(input_list[0])
|
||||
|
||||
# Make sure that multiple wait requests involving the same object ref
|
||||
# all return.
|
||||
@@ -80,6 +86,7 @@ def test_multiple_waits_and_gets(shutdown_only):
|
||||
ray.get([h.remote([x]), h.remote([x])])
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_caching_functions_to_run(shutdown_only):
|
||||
# Test that we export functions to run on all workers before the driver
|
||||
# is connected.
|
||||
@@ -125,6 +132,7 @@ def test_caching_functions_to_run(shutdown_only):
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_running_function_on_all_workers(ray_start_regular):
|
||||
def f(worker_info):
|
||||
sys.path.append("fake_directory")
|
||||
@@ -152,6 +160,7 @@ def test_running_function_on_all_workers(ray_start_regular):
|
||||
assert "fake_directory" not in ray.get(get_path2.remote())
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="ray.timeline")
|
||||
def test_profiling_api(ray_start_2_cpus):
|
||||
@ray.remote
|
||||
def f():
|
||||
@@ -345,6 +354,8 @@ def test_illegal_api_calls(ray_start_regular):
|
||||
ray.get(3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="grpc interaction with releasing resources")
|
||||
def test_multithreading(ray_start_2_cpus):
|
||||
# This test requires at least 2 CPUs to finish since the worker does not
|
||||
# release resources when joining the threads.
|
||||
@@ -482,6 +493,7 @@ def test_multithreading(ray_start_2_cpus):
|
||||
ray.get(actor.join.remote()) == "ok"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_wait_makes_object_local(ray_start_cluster):
|
||||
cluster = ray_start_cluster
|
||||
cluster.add_node(num_cpus=0)
|
||||
|
||||
@@ -284,14 +284,14 @@ def test_workers(shutdown_only):
|
||||
|
||||
|
||||
def test_object_ref_properties():
|
||||
id_bytes = b"00112233445566778899"
|
||||
id_bytes = b"0011223344556677889900001111"
|
||||
object_ref = ray.ObjectRef(id_bytes)
|
||||
assert object_ref.binary() == id_bytes
|
||||
object_ref = ray.ObjectRef.nil()
|
||||
assert object_ref.is_nil()
|
||||
with pytest.raises(ValueError, match=r".*needs to have length 20.*"):
|
||||
with pytest.raises(ValueError, match=r".*needs to have length.*"):
|
||||
ray.ObjectRef(id_bytes + b"1234")
|
||||
with pytest.raises(ValueError, match=r".*needs to have length 20.*"):
|
||||
with pytest.raises(ValueError, match=r".*needs to have length.*"):
|
||||
ray.ObjectRef(b"0123456789")
|
||||
object_ref = ray.ObjectRef.from_random()
|
||||
assert not object_ref.is_nil()
|
||||
|
||||
@@ -12,7 +12,6 @@ import sys
|
||||
from jsonschema.exceptions import ValidationError
|
||||
|
||||
import ray
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler._private.util import prepare_config, validate_config
|
||||
from ray.autoscaler._private import commands
|
||||
from ray.autoscaler.sdk import get_docker_host_mount_location
|
||||
@@ -55,8 +54,11 @@ class MockProcessRunner:
|
||||
self.calls = []
|
||||
self.fail_cmds = fail_cmds or []
|
||||
self.call_response = {}
|
||||
self.ready_to_run = threading.Event()
|
||||
self.ready_to_run.set()
|
||||
|
||||
def check_call(self, cmd, *args, **kwargs):
|
||||
self.ready_to_run.wait()
|
||||
for token in self.fail_cmds:
|
||||
if token in str(cmd):
|
||||
raise CalledProcessError(1, token,
|
||||
@@ -166,22 +168,28 @@ class MockProvider(NodeProvider):
|
||||
]
|
||||
|
||||
def is_running(self, node_id):
|
||||
return self.mock_nodes[node_id].state == "running"
|
||||
with self.lock:
|
||||
return self.mock_nodes[node_id].state == "running"
|
||||
|
||||
def is_terminated(self, node_id):
|
||||
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
|
||||
with self.lock:
|
||||
return self.mock_nodes[node_id].state in ["stopped", "terminated"]
|
||||
|
||||
def node_tags(self, node_id):
|
||||
return self.mock_nodes[node_id].tags
|
||||
with self.lock:
|
||||
return self.mock_nodes[node_id].tags
|
||||
|
||||
def internal_ip(self, node_id):
|
||||
return self.mock_nodes[node_id].internal_ip
|
||||
with self.lock:
|
||||
return self.mock_nodes[node_id].internal_ip
|
||||
|
||||
def external_ip(self, node_id):
|
||||
return self.mock_nodes[node_id].external_ip
|
||||
with self.lock:
|
||||
return self.mock_nodes[node_id].external_ip
|
||||
|
||||
def create_node(self, node_config, tags, count):
|
||||
self.ready_to_create.wait()
|
||||
def create_node(self, node_config, tags, count, _skip_wait=False):
|
||||
if not _skip_wait:
|
||||
self.ready_to_create.wait()
|
||||
if self.fail_creates:
|
||||
return
|
||||
with self.lock:
|
||||
@@ -201,7 +209,8 @@ class MockProvider(NodeProvider):
|
||||
self.next_id += 1
|
||||
|
||||
def set_node_tags(self, node_id, tags):
|
||||
self.mock_nodes[node_id].tags.update(tags)
|
||||
with self.lock:
|
||||
self.mock_nodes[node_id].tags.update(tags)
|
||||
|
||||
def terminate_node(self, node_id):
|
||||
with self.lock:
|
||||
@@ -535,7 +544,11 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config["max_workers"] = 5
|
||||
config_path = self.write_config(config)
|
||||
self.provider = MockProvider()
|
||||
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "worker"}, 10)
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: "worker",
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_WORKER
|
||||
}, 10)
|
||||
runner = MockProcessRunner()
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)])
|
||||
autoscaler = StandardAutoscaler(
|
||||
@@ -559,8 +572,14 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner()
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)])
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(12)])
|
||||
lm = LoadMetrics()
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
|
||||
}, 1)
|
||||
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
lm,
|
||||
@@ -569,16 +588,16 @@ class AutoscalingTest(unittest.TestCase):
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
update_interval_s=0)
|
||||
self.waitForNodes(0)
|
||||
self.waitForNodes(0, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
# Update the config to reduce the cluster size
|
||||
new_config = SMALL_CLUSTER.copy()
|
||||
new_config["max_workers"] = 1
|
||||
self.write_config(new_config)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(1)
|
||||
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
# Update the config to raise the minimum number of workers
|
||||
new_config["min_workers"] = 10
|
||||
@@ -587,12 +606,13 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
# Because one worker already started, the scheduler waits for its
|
||||
# resources to be updated before it launches the remaining min_workers.
|
||||
self.waitForNodes(1)
|
||||
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
worker_ip = self.provider.non_terminated_node_ips(
|
||||
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER}, )[0]
|
||||
lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {})
|
||||
autoscaler.update()
|
||||
self.waitForNodes(10)
|
||||
self.waitForNodes(
|
||||
10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
def testInitialWorkers(self):
|
||||
"""initial_workers is deprecated, this tests that it is ignored."""
|
||||
@@ -653,11 +673,15 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
# 1 head node.
|
||||
self.waitForNodes(1)
|
||||
autoscaler.request_resources([{"CPU": 1}])
|
||||
autoscaler.load_metrics.set_resource_requests([{"CPU": 1}])
|
||||
autoscaler.update()
|
||||
# still 1 head node because request_resources fits in the headnode.
|
||||
self.waitForNodes(1)
|
||||
autoscaler.request_resources([{"CPU": 1}] + [{"CPU": 2}] * 9)
|
||||
autoscaler.load_metrics.set_resource_requests([{
|
||||
"CPU": 1
|
||||
}] + [{
|
||||
"CPU": 2
|
||||
}] * 9)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2) # Adds a single worker to get its resources.
|
||||
autoscaler.update()
|
||||
@@ -760,7 +784,11 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config_path = self.write_config(config)
|
||||
|
||||
self.provider = MockProvider()
|
||||
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1)
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: "head",
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE
|
||||
}, 1)
|
||||
head_ip = self.provider.non_terminated_node_ips(
|
||||
tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0]
|
||||
|
||||
@@ -809,7 +837,11 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config_path = self.write_config(config)
|
||||
|
||||
self.provider = MockProvider()
|
||||
self.provider.create_node({}, {TAG_RAY_NODE_KIND: "head"}, 1)
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: "head",
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE
|
||||
}, 1)
|
||||
head_ip = self.provider.non_terminated_node_ips(
|
||||
tag_filters={TAG_RAY_NODE_KIND: "head"}, )[0]
|
||||
|
||||
@@ -964,8 +996,14 @@ class AutoscalingTest(unittest.TestCase):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner()
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(10)])
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(11)])
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
|
||||
}, 1)
|
||||
lm = LoadMetrics()
|
||||
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
lm,
|
||||
@@ -975,7 +1013,7 @@ class AutoscalingTest(unittest.TestCase):
|
||||
max_failures=0,
|
||||
update_interval_s=0)
|
||||
autoscaler.update()
|
||||
self.waitForNodes(2)
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
# Write a corrupted config
|
||||
self.write_config("asdf", call_prepare_config=False)
|
||||
@@ -983,7 +1021,10 @@ class AutoscalingTest(unittest.TestCase):
|
||||
autoscaler.update()
|
||||
time.sleep(0.1)
|
||||
assert autoscaler.pending_launches.value == 0
|
||||
assert len(self.provider.non_terminated_nodes({})) == 2
|
||||
assert len(
|
||||
self.provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
|
||||
})) == 2
|
||||
|
||||
# Write a good config again
|
||||
new_config = SMALL_CLUSTER.copy()
|
||||
@@ -996,7 +1037,8 @@ class AutoscalingTest(unittest.TestCase):
|
||||
# resources to be updated before it launches the remaining min_workers.
|
||||
lm.update(worker_ip, {"CPU": 1}, {"CPU": 1}, {})
|
||||
autoscaler.update()
|
||||
self.waitForNodes(10)
|
||||
self.waitForNodes(
|
||||
10, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
def testMaxFailures(self):
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
@@ -1076,9 +1118,17 @@ class AutoscalingTest(unittest.TestCase):
|
||||
2, tag_filters={TAG_RAY_NODE_STATUS: STATUS_UPDATE_FAILED})
|
||||
except AssertionError:
|
||||
# The failed nodes might have been already terminated by autoscaler
|
||||
assert len(self.provider.non_terminated_nodes({})) == 0
|
||||
assert len(self.provider.non_terminated_nodes({})) < 2
|
||||
|
||||
def testConfiguresOutdatedNodes(self):
|
||||
from ray.autoscaler._private.cli_logger import cli_logger
|
||||
|
||||
def do_nothing(*args, **kwargs):
|
||||
pass
|
||||
|
||||
cli_logger._print = type(cli_logger._print)(do_nothing,
|
||||
type(cli_logger))
|
||||
|
||||
config_path = self.write_config(SMALL_CLUSTER)
|
||||
self.provider = MockProvider()
|
||||
runner = MockProcessRunner()
|
||||
@@ -1113,53 +1163,61 @@ class AutoscalingTest(unittest.TestCase):
|
||||
self.provider = MockProvider()
|
||||
lm = LoadMetrics()
|
||||
runner = MockProcessRunner()
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(5)])
|
||||
runner.respond_to_call("json .Config.Env", ["[]" for i in range(6)])
|
||||
self.provider.create_node({}, {
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_HEAD,
|
||||
TAG_RAY_NODE_STATUS: STATUS_UP_TO_DATE,
|
||||
TAG_RAY_USER_NODE_TYPE: NODE_TYPE_LEGACY_HEAD
|
||||
}, 1)
|
||||
lm.update("172.0.0.0", {"CPU": 1}, {"CPU": 0}, {})
|
||||
autoscaler = StandardAutoscaler(
|
||||
config_path,
|
||||
lm,
|
||||
max_failures=0,
|
||||
process_runner=runner,
|
||||
update_interval_s=0)
|
||||
assert len(self.provider.non_terminated_nodes({})) == 0
|
||||
assert len(
|
||||
self.provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
|
||||
})) == 0
|
||||
autoscaler.update()
|
||||
self.waitForNodes(1)
|
||||
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
autoscaler.update()
|
||||
assert autoscaler.pending_launches.value == 0
|
||||
assert len(self.provider.non_terminated_nodes({})) == 1
|
||||
assert len(
|
||||
self.provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
|
||||
})) == 1
|
||||
|
||||
# Scales up as nodes are reported as used
|
||||
local_ip = services.get_node_ip_address()
|
||||
lm.update(
|
||||
local_ip, {"CPU": 2}, {"CPU": 0}, {},
|
||||
waiting_bundles=2 * [{
|
||||
"CPU": 2
|
||||
}]) # head
|
||||
autoscaler.update()
|
||||
lm.update(
|
||||
"172.0.0.0", {"CPU": 2}, {"CPU": 0}, {},
|
||||
"172.0.0.1", {"CPU": 2}, {"CPU": 0}, {},
|
||||
waiting_bundles=2 * [{
|
||||
"CPU": 2
|
||||
}])
|
||||
autoscaler.update()
|
||||
self.waitForNodes(3)
|
||||
self.waitForNodes(3, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
lm.update(
|
||||
"172.0.0.1", {"CPU": 2}, {"CPU": 0}, {},
|
||||
"172.0.0.2", {"CPU": 2}, {"CPU": 0}, {},
|
||||
waiting_bundles=3 * [{
|
||||
"CPU": 2
|
||||
}])
|
||||
autoscaler.update()
|
||||
self.waitForNodes(5)
|
||||
self.waitForNodes(5, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
# Holds steady when load is removed
|
||||
lm.update("172.0.0.0", {"CPU": 2}, {"CPU": 2}, {})
|
||||
lm.update("172.0.0.1", {"CPU": 2}, {"CPU": 2}, {})
|
||||
lm.update("172.0.0.2", {"CPU": 2}, {"CPU": 2}, {})
|
||||
autoscaler.update()
|
||||
assert autoscaler.pending_launches.value == 0
|
||||
assert len(self.provider.non_terminated_nodes({})) == 5
|
||||
assert len(
|
||||
self.provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
|
||||
})) == 5
|
||||
|
||||
# Scales down as nodes become unused
|
||||
lm.last_used_time_by_ip["172.0.0.0"] = 0
|
||||
lm.last_used_time_by_ip["172.0.0.1"] = 0
|
||||
lm.last_used_time_by_ip["172.0.0.2"] = 0
|
||||
autoscaler.update()
|
||||
|
||||
assert autoscaler.pending_launches.value == 0
|
||||
@@ -1167,18 +1225,21 @@ class AutoscalingTest(unittest.TestCase):
|
||||
# are not connected and hence we rely more on connected nodes for
|
||||
# min_workers. When the "pending" nodes show up as connected,
|
||||
# then we can terminate the ones connected before.
|
||||
assert len(self.provider.non_terminated_nodes({})) == 4
|
||||
lm.last_used_time_by_ip["172.0.0.2"] = 0
|
||||
assert len(
|
||||
self.provider.non_terminated_nodes({
|
||||
TAG_RAY_NODE_KIND: NODE_KIND_WORKER
|
||||
})) == 4
|
||||
lm.last_used_time_by_ip["172.0.0.3"] = 0
|
||||
lm.last_used_time_by_ip["172.0.0.4"] = 0
|
||||
autoscaler.update()
|
||||
assert autoscaler.pending_launches.value == 0
|
||||
# 2 nodes and not 1 because 1 is needed for min_worker and the other 1
|
||||
# is still not connected.
|
||||
self.waitForNodes(2)
|
||||
self.waitForNodes(2, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
# when we connect it, we will see 1 node.
|
||||
lm.last_used_time_by_ip["172.0.0.4"] = 0
|
||||
lm.last_used_time_by_ip["172.0.0.5"] = 0
|
||||
autoscaler.update()
|
||||
self.waitForNodes(1)
|
||||
self.waitForNodes(1, tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_WORKER})
|
||||
|
||||
def testTargetUtilizationFraction(self):
|
||||
config = SMALL_CLUSTER.copy()
|
||||
|
||||
@@ -8,14 +8,20 @@ import time
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
from ray.test_utils import (
|
||||
client_test_enabled,
|
||||
dicts_equal,
|
||||
wait_for_pid_to_exit,
|
||||
)
|
||||
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/6662
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="interferes with grpc")
|
||||
def test_ignore_http_proxy(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
os.environ["http_proxy"] = "http://example.com"
|
||||
@@ -29,6 +35,7 @@ def test_ignore_http_proxy(shutdown_only):
|
||||
|
||||
|
||||
# https://github.com/ray-project/ray/issues/7263
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_grpc_message_size(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@@ -178,7 +185,7 @@ def test_many_fractional_resources(shutdown_only):
|
||||
}
|
||||
if block:
|
||||
ray.get(g.remote())
|
||||
return ray.test_utils.dicts_equal(true_resources, accepted_resources)
|
||||
return dicts_equal(true_resources, accepted_resources)
|
||||
|
||||
# Check that the resource are assigned correctly.
|
||||
result_ids = []
|
||||
@@ -257,7 +264,7 @@ def test_background_tasks_with_max_calls(shutdown_only):
|
||||
pid, g_id = nested.pop(0)
|
||||
ray.get(g_id)
|
||||
del g_id
|
||||
ray.test_utils.wait_for_pid_to_exit(pid)
|
||||
wait_for_pid_to_exit(pid)
|
||||
|
||||
|
||||
def test_fair_queueing(shutdown_only):
|
||||
@@ -327,6 +334,7 @@ def test_wait_timing(shutdown_only):
|
||||
assert len(not_ready) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal _raylet")
|
||||
def test_function_descriptor():
|
||||
python_descriptor = ray._raylet.PythonFunctionDescriptor(
|
||||
"module_name", "function_name", "class_name", "function_hash")
|
||||
@@ -345,6 +353,8 @@ def test_function_descriptor():
|
||||
|
||||
|
||||
def test_ray_options(shutdown_only):
|
||||
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
|
||||
|
||||
@ray.remote(
|
||||
num_cpus=2, num_gpus=3, memory=150 * 2**20, resources={"custom1": 1})
|
||||
def foo():
|
||||
@@ -353,8 +363,6 @@ def test_ray_options(shutdown_only):
|
||||
time.sleep(0.1)
|
||||
return ray.available_resources()
|
||||
|
||||
ray.init(num_cpus=10, num_gpus=10, resources={"custom1": 2})
|
||||
|
||||
without_options = ray.get(foo.remote())
|
||||
with_options = ray.get(
|
||||
foo.options(
|
||||
@@ -371,6 +379,43 @@ def test_ray_options(shutdown_only):
|
||||
assert without_options != with_options
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster_head", [{
|
||||
"num_cpus": 0,
|
||||
"object_store_memory": 75 * 1024 * 1024,
|
||||
}],
|
||||
indirect=True)
|
||||
def test_fetch_local(ray_start_cluster_head):
|
||||
cluster = ray_start_cluster_head
|
||||
cluster.add_node(num_cpus=2, object_store_memory=75 * 1024 * 1024)
|
||||
|
||||
signal_actor = ray.test_utils.SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
def put():
|
||||
ray.wait([signal_actor.wait.remote()])
|
||||
return np.random.rand(5 * 1024 * 1024) # 40 MB data
|
||||
|
||||
local_ref = ray.put(np.random.rand(5 * 1024 * 1024))
|
||||
remote_ref = put.remote()
|
||||
# Data is not ready in any node
|
||||
(ready_ref, remaining_ref) = ray.wait(
|
||||
[remote_ref], timeout=2, fetch_local=False)
|
||||
assert (0, 1) == (len(ready_ref), len(remaining_ref))
|
||||
ray.wait([signal_actor.send.remote()])
|
||||
|
||||
# Data is ready in some node, but not local node.
|
||||
(ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=False)
|
||||
assert (1, 0) == (len(ready_ref), len(remaining_ref))
|
||||
(ready_ref, remaining_ref) = ray.wait(
|
||||
[remote_ref], timeout=2, fetch_local=True)
|
||||
assert (0, 1) == (len(ready_ref), len(remaining_ref))
|
||||
del local_ref
|
||||
(ready_ref, remaining_ref) = ray.wait([remote_ref], fetch_local=True)
|
||||
assert (1, 0) == (len(ready_ref), len(remaining_ref))
|
||||
|
||||
|
||||
def test_nested_functions(ray_start_shared_local_modes):
|
||||
# Make sure that remote functions can use other values that are defined
|
||||
# after the remote function but before the first function invocation.
|
||||
@@ -402,8 +447,11 @@ def test_nested_functions(ray_start_shared_local_modes):
|
||||
assert ray.get(factorial.remote(4)) == 24
|
||||
assert ray.get(factorial.remote(5)) == 120
|
||||
|
||||
# Test remote functions that recursively call each other.
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="mutual recursion is a known issue")
|
||||
def test_mutually_recursive_functions(ray_start_shared_local_modes):
|
||||
# Test remote functions that recursively call each other.
|
||||
@ray.remote
|
||||
def factorial_even(n):
|
||||
assert n % 2 == 0
|
||||
@@ -674,6 +722,7 @@ def test_args_stars_after(ray_start_shared_local_modes):
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_object_id_backward_compatibility(ray_start_shared_local_modes):
|
||||
# We've renamed Python's `ObjectID` to `ObjectRef`, and added a type
|
||||
# alias for backward compatibility.
|
||||
|
||||
@@ -9,10 +9,16 @@ import pytest
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
import ray.test_utils
|
||||
from ray.test_utils import client_test_enabled
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
from ray.exceptions import GetTimeoutError
|
||||
from ray.exceptions import RayTaskError
|
||||
|
||||
if client_test_enabled():
|
||||
from ray.experimental.client import ray
|
||||
else:
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +31,8 @@ logger = logging.getLogger(__name__)
|
||||
}],
|
||||
indirect=True)
|
||||
def test_variable_number_of_args(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
return " ".join(map(str, a))
|
||||
@@ -33,8 +41,6 @@ def test_variable_number_of_args(shutdown_only):
|
||||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
x = varargs_fct1.remote(0, 1, 2)
|
||||
assert ray.get(x) == "0 1 2"
|
||||
x = varargs_fct2.remote(0, 1, 2)
|
||||
@@ -160,7 +166,7 @@ def test_redefining_remote_functions(shutdown_only):
|
||||
def g():
|
||||
return nonexistent()
|
||||
|
||||
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
|
||||
with pytest.raises(RayTaskError, match="nonexistent"):
|
||||
ray.get(g.remote())
|
||||
|
||||
def nonexistent():
|
||||
@@ -187,6 +193,7 @@ def test_redefining_remote_functions(shutdown_only):
|
||||
assert ray.get(ray.get(h.remote(i))) == i
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_call_matrix(shutdown_only):
|
||||
ray.init(object_store_memory=1000 * 1024 * 1024)
|
||||
|
||||
@@ -312,6 +319,7 @@ def test_actor_pass_by_ref_order_optimization(shutdown_only):
|
||||
assert delta < 10, "did not skip slow value"
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster", [{
|
||||
"num_cpus": 1,
|
||||
@@ -332,6 +340,7 @@ def test_call_chain(ray_start_cluster):
|
||||
assert ray.get(x) == 100
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="message size")
|
||||
def test_system_config_when_connecting(ray_start_cluster):
|
||||
config = {"object_pinning_enabled": 0, "object_timeout_milliseconds": 200}
|
||||
cluster = ray.cluster_utils.Cluster()
|
||||
@@ -368,7 +377,8 @@ def test_get_multiple(ray_start_regular_shared):
|
||||
|
||||
|
||||
def test_get_with_timeout(ray_start_regular_shared):
|
||||
signal = ray.test_utils.SignalActor.remote()
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signal = SignalActor.remote()
|
||||
|
||||
# Check that get() returns early if object is ready.
|
||||
start = time.time()
|
||||
@@ -438,6 +448,7 @@ def test_inline_arg_memory_corruption(ray_start_regular_shared):
|
||||
ray.get(a.add.remote(f.remote()))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_skip_plasma(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
@@ -454,6 +465,8 @@ def test_skip_plasma(ray_start_regular_shared):
|
||||
assert ray.get(obj_ref) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
client_test_enabled(), reason="internal api and message size")
|
||||
def test_actor_large_objects(ray_start_regular_shared):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
@@ -626,6 +639,7 @@ def test_duplicate_args(ray_start_regular_shared):
|
||||
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
|
||||
|
||||
|
||||
@pytest.mark.skipif(client_test_enabled(), reason="internal api")
|
||||
def test_get_correct_node_ip():
|
||||
with patch("ray.worker") as worker_mock:
|
||||
node_mock = MagicMock()
|
||||
|
||||
@@ -1,21 +1,10 @@
|
||||
import pytest
|
||||
from contextlib import contextmanager
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray, reset_api
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
|
||||
|
||||
@contextmanager
|
||||
def ray_start_client_server():
|
||||
server = ray_client_server.serve("localhost:50051", test_mode=True)
|
||||
ray.connect("localhost:50051")
|
||||
try:
|
||||
yield ray
|
||||
finally:
|
||||
ray.disconnect()
|
||||
server.stop(0)
|
||||
reset_api()
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
def test_real_ray_fallback(ray_start_regular_shared):
|
||||
@@ -81,11 +70,11 @@ def test_wait(ray_start_regular_shared):
|
||||
with pytest.raises(Exception):
|
||||
# Reference not in the object store.
|
||||
ray.wait([ClientObjectRef("blabla")])
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait("blabla")
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait(ClientObjectRef("blabla"))
|
||||
with pytest.raises(AssertionError):
|
||||
with pytest.raises(TypeError):
|
||||
ray.wait(["blabla"])
|
||||
|
||||
|
||||
@@ -142,7 +131,7 @@ def test_function_calling_function(ray_start_regular_shared):
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
print(f, f._name, g._name, g)
|
||||
print(f, g)
|
||||
return ray.get(g.remote())
|
||||
|
||||
print(f, type(f))
|
||||
@@ -171,8 +160,7 @@ def test_basic_actor(ray_start_regular_shared):
|
||||
|
||||
|
||||
def test_pass_handles(ray_start_regular_shared):
|
||||
"""
|
||||
Test that passing client handles to actors and functions to remote actors
|
||||
"""Test that passing client handles to actors and functions to remote actors
|
||||
in functions (on the server or raylet side) works transparently to the
|
||||
caller.
|
||||
"""
|
||||
@@ -234,6 +222,98 @@ def test_pass_handles(ray_start_regular_shared):
|
||||
4)) == local_fact(4)
|
||||
|
||||
|
||||
def test_basic_log_stream(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
log_msgs = []
|
||||
|
||||
def test_log(level, msg):
|
||||
log_msgs.append(msg)
|
||||
|
||||
ray.worker.log_client.log = test_log
|
||||
ray.worker.log_client.set_logstream_level(logging.DEBUG)
|
||||
# Allow some time to propogate
|
||||
time.sleep(1)
|
||||
x = ray.put("Foo")
|
||||
assert ray.get(x) == "Foo"
|
||||
time.sleep(1)
|
||||
logs_with_id = [msg for msg in log_msgs if msg.find(x.id.hex()) >= 0]
|
||||
assert len(logs_with_id) >= 2
|
||||
assert any((msg.find("get") >= 0 for msg in logs_with_id))
|
||||
assert any((msg.find("put") >= 0 for msg in logs_with_id))
|
||||
|
||||
|
||||
def test_stdout_log_stream(ray_start_regular_shared):
|
||||
with ray_start_client_server() as ray:
|
||||
log_msgs = []
|
||||
|
||||
def test_log(level, msg):
|
||||
log_msgs.append(msg)
|
||||
|
||||
ray.worker.log_client.stdstream = test_log
|
||||
|
||||
@ray.remote
|
||||
def print_on_stderr_and_stdout(s):
|
||||
print(s)
|
||||
print(s, file=sys.stderr)
|
||||
|
||||
time.sleep(1)
|
||||
print_on_stderr_and_stdout.remote("Hello world")
|
||||
time.sleep(1)
|
||||
assert len(log_msgs) == 2
|
||||
assert all((msg.find("Hello world") for msg in log_msgs))
|
||||
|
||||
|
||||
def test_create_remote_before_start(ray_start_regular_shared):
|
||||
"""Creates remote objects (as though in a library) before
|
||||
starting the client.
|
||||
"""
|
||||
from ray.experimental.client import ray
|
||||
|
||||
@ray.remote
|
||||
class Returner:
|
||||
def doit(self):
|
||||
return "foo"
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 20
|
||||
|
||||
# Prints in verbose tests
|
||||
print("Created remote functions")
|
||||
|
||||
with ray_start_client_server() as ray:
|
||||
assert ray.get(f.remote(3)) == 23
|
||||
a = Returner.remote()
|
||||
assert ray.get(a.doit.remote()) == "foo"
|
||||
|
||||
|
||||
def test_basic_named_actor(ray_start_regular_shared):
|
||||
"""Test that ray.get_actor() can create and return a detached actor.
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
self.x = 0
|
||||
|
||||
def inc(self):
|
||||
self.x += 1
|
||||
|
||||
def get(self):
|
||||
return self.x
|
||||
|
||||
# Create the actor
|
||||
actor = Accumulator.options(name="test_acc").remote()
|
||||
|
||||
actor.inc.remote()
|
||||
actor.inc.remote()
|
||||
del actor
|
||||
|
||||
new_actor = ray.get_actor("test_acc")
|
||||
new_actor.inc.remote()
|
||||
assert ray.get(new_actor.get.remote()) == 3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
|
||||
def test_get_ray_metadata(ray_start_regular_shared):
|
||||
"""
|
||||
Test the ClusterInfo client data pathway and API surface
|
||||
"""Test the ClusterInfo client data pathway and API surface
|
||||
"""
|
||||
with ray_start_client_server() as ray:
|
||||
ip_address = ray_start_regular_shared["node_ip_address"]
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
from ray.test_utils import wait_for_condition
|
||||
import ray as real_ray
|
||||
from ray.core.generated.gcs_pb2 import ActorTableData
|
||||
from ray.experimental.client.server.server import _get_current_servicer
|
||||
|
||||
|
||||
def server_object_ref_count(n):
|
||||
server = _get_current_servicer()
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
if len(server.object_refs) == 0:
|
||||
# No open clients
|
||||
return n == 0
|
||||
client_id = list(server.object_refs.keys())[0]
|
||||
return len(server.object_refs[client_id]) == n
|
||||
|
||||
return test_cond
|
||||
|
||||
|
||||
def server_actor_ref_count(n):
|
||||
server = _get_current_servicer()
|
||||
assert server is not None
|
||||
|
||||
def test_cond():
|
||||
if len(server.actor_refs) == 0:
|
||||
# No running actors
|
||||
return n == 0
|
||||
return len(server.actor_refs) == n
|
||||
|
||||
return test_cond
|
||||
|
||||
|
||||
def test_delete_refs_on_disconnect(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 2
|
||||
|
||||
thing1 = f.remote(6) # noqa
|
||||
thing2 = ray.put("Hello World") # noqa
|
||||
|
||||
# One put, one function -- the function result thing1 is
|
||||
# in a different category, according to the raylet.
|
||||
assert len(real_ray.objects()) == 2
|
||||
# But we're maintaining the reference
|
||||
assert server_object_ref_count(3)()
|
||||
# And can get the data
|
||||
assert ray.get(thing1) == 8
|
||||
|
||||
# Close the client
|
||||
ray.close()
|
||||
|
||||
wait_for_condition(server_object_ref_count(0), timeout=5)
|
||||
|
||||
def test_cond():
|
||||
return len(real_ray.objects()) == 0
|
||||
|
||||
wait_for_condition(test_cond, timeout=5)
|
||||
|
||||
|
||||
def test_delete_ref_on_object_deletion(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
vals = {
|
||||
"ref": ray.put("Hello World"),
|
||||
"ref2": ray.put("This value stays"),
|
||||
}
|
||||
|
||||
del vals["ref"]
|
||||
|
||||
wait_for_condition(server_object_ref_count(1), timeout=5)
|
||||
|
||||
|
||||
def test_delete_actor_on_disconnect(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
self.acc = 0
|
||||
|
||||
def inc(self):
|
||||
self.acc += 1
|
||||
|
||||
def get(self):
|
||||
return self.acc
|
||||
|
||||
actor = Accumulator.remote()
|
||||
actor.inc.remote()
|
||||
|
||||
assert server_actor_ref_count(1)()
|
||||
|
||||
assert ray.get(actor.get.remote()) == 1
|
||||
|
||||
ray.close()
|
||||
|
||||
wait_for_condition(server_actor_ref_count(0), timeout=5)
|
||||
|
||||
def test_cond():
|
||||
alive_actors = [
|
||||
v for v in real_ray.actors().values()
|
||||
if v["State"] != ActorTableData.DEAD
|
||||
]
|
||||
return len(alive_actors) == 0
|
||||
|
||||
wait_for_condition(test_cond, timeout=10)
|
||||
|
||||
|
||||
def test_delete_actor(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class Accumulator:
|
||||
def __init__(self):
|
||||
self.acc = 0
|
||||
|
||||
def inc(self):
|
||||
self.acc += 1
|
||||
|
||||
actor = Accumulator.remote()
|
||||
actor.inc.remote()
|
||||
actor2 = Accumulator.remote()
|
||||
actor2.inc.remote()
|
||||
|
||||
assert server_actor_ref_count(2)()
|
||||
|
||||
del actor
|
||||
|
||||
wait_for_condition(server_actor_ref_count(1), timeout=5)
|
||||
|
||||
|
||||
def test_simple_multiple_references(ray_start_regular):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class A:
|
||||
def __init__(self):
|
||||
self.x = ray.put("hi")
|
||||
|
||||
def get(self):
|
||||
return [self.x]
|
||||
|
||||
a = A.remote()
|
||||
ref1 = ray.get(a.get.remote())[0]
|
||||
ref2 = ray.get(a.get.remote())[0]
|
||||
del a
|
||||
assert ray.get(ref1) == "hi"
|
||||
del ref1
|
||||
assert ray.get(ref2) == "hi"
|
||||
del ref2
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from ray.tests.test_experimental_client import ray_start_client_server
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
from ray.tests.client_test_utils import create_remote_signal_actor
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.exceptions import TaskCancelledError
|
||||
from ray.exceptions import RayTaskError
|
||||
@@ -45,21 +45,7 @@ def test_kill_actor_immediately_after_creation(ray_start_regular):
|
||||
@pytest.mark.parametrize("use_force", [True, False])
|
||||
def test_cancel_chain(ray_start_regular, use_force):
|
||||
with ray_start_client_server() as ray:
|
||||
|
||||
@ray.remote
|
||||
class SignalActor:
|
||||
def __init__(self):
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
def send(self, clear=False):
|
||||
self.ready_event.set()
|
||||
if clear:
|
||||
self.ready_event.clear()
|
||||
|
||||
async def wait(self, should_wait=True):
|
||||
if should_wait:
|
||||
await self.ready_event.wait()
|
||||
|
||||
SignalActor = create_remote_signal_actor(ray)
|
||||
signaler = SignalActor.remote()
|
||||
|
||||
@ray.remote
|
||||
|
||||
@@ -16,14 +16,8 @@ import ray.utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.exceptions import RayTaskError
|
||||
from ray.cluster_utils import Cluster
|
||||
from ray.test_utils import (
|
||||
wait_for_condition,
|
||||
SignalActor,
|
||||
init_error_pubsub,
|
||||
get_error_message,
|
||||
Semaphore,
|
||||
new_scheduler_enabled,
|
||||
)
|
||||
from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
|
||||
get_error_message, Semaphore)
|
||||
|
||||
|
||||
def test_failed_task(ray_start_regular, error_pubsub):
|
||||
@@ -638,11 +632,10 @@ def test_export_large_objects(ray_start_regular, error_pubsub):
|
||||
assert errors[0].type == ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO detect resource deadlock")
|
||||
def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
|
||||
p = error_pubsub
|
||||
# Check that we get warning messages for infeasible tasks.
|
||||
ray.init(num_cpus=1)
|
||||
def test_warning_all_tasks_blocked(shutdown_only):
|
||||
ray.init(
|
||||
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
|
||||
p = init_error_pubsub()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class Foo:
|
||||
@@ -652,7 +645,7 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
|
||||
@ray.remote
|
||||
def f():
|
||||
# Creating both actors is not possible.
|
||||
actors = [Foo.remote() for _ in range(2)]
|
||||
actors = [Foo.remote() for _ in range(3)]
|
||||
for a in actors:
|
||||
ray.get(a.f.remote())
|
||||
|
||||
@@ -663,7 +656,46 @@ def test_warning_for_resource_deadlock(error_pubsub, shutdown_only):
|
||||
assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR
|
||||
|
||||
|
||||
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
|
||||
def test_warning_actor_waiting_on_actor(shutdown_only):
|
||||
ray.init(
|
||||
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
|
||||
p = init_error_pubsub()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class Actor:
|
||||
pass
|
||||
|
||||
a = Actor.remote() # noqa
|
||||
b = Actor.remote() # noqa
|
||||
|
||||
errors = get_error_message(p, 1, ray_constants.RESOURCE_DEADLOCK_ERROR)
|
||||
assert len(errors) == 1
|
||||
assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR
|
||||
|
||||
|
||||
def test_warning_task_waiting_on_actor(shutdown_only):
|
||||
ray.init(
|
||||
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
|
||||
p = init_error_pubsub()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class Actor:
|
||||
pass
|
||||
|
||||
a = Actor.remote() # noqa
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def f():
|
||||
print("f running")
|
||||
time.sleep(999)
|
||||
|
||||
ids = [f.remote()] # noqa
|
||||
|
||||
errors = get_error_message(p, 1, ray_constants.RESOURCE_DEADLOCK_ERROR)
|
||||
assert len(errors) == 1
|
||||
assert errors[0].type == ray_constants.RESOURCE_DEADLOCK_ERROR
|
||||
|
||||
|
||||
def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub):
|
||||
p = error_pubsub
|
||||
# Check that we get warning messages for infeasible tasks.
|
||||
@@ -689,7 +721,6 @@ def test_warning_for_infeasible_tasks(ray_start_regular, error_pubsub):
|
||||
assert errors[0].type == ray_constants.INFEASIBLE_TASK_ERROR
|
||||
|
||||
|
||||
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
|
||||
def test_warning_for_infeasible_zero_cpu_actor(shutdown_only):
|
||||
# Check that we cannot place an actor on a 0 CPU machine and that we get an
|
||||
# infeasibility warning (even though the actor creation task itself
|
||||
@@ -956,7 +987,6 @@ def test_raylet_crash_when_get(ray_start_regular):
|
||||
thread.join()
|
||||
|
||||
|
||||
@pytest.mark.skipif(new_scheduler_enabled(), reason="broken")
|
||||
def test_connect_with_disconnected_node(shutdown_only):
|
||||
config = {
|
||||
"num_heartbeats_timeout": 50,
|
||||
|
||||
@@ -6,7 +6,6 @@ from ray.test_utils import (
|
||||
generate_system_config_map,
|
||||
wait_for_condition,
|
||||
wait_for_pid_to_exit,
|
||||
new_scheduler_enabled,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +20,6 @@ def increase(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(new_scheduler_enabled(), reason="notimpl")
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [
|
||||
generate_system_config_map(
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
|
||||
import ray
|
||||
import ray.cluster_utils
|
||||
from ray.test_utils import wait_for_condition, new_scheduler_enabled
|
||||
from ray.test_utils import wait_for_condition
|
||||
from ray.internal.internal_api import global_gc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -166,9 +166,9 @@ def test_global_gc_when_full(shutdown_only):
|
||||
gc.enable()
|
||||
|
||||
|
||||
@pytest.mark.skipif(new_scheduler_enabled(), reason="hangs")
|
||||
def test_global_gc_actors(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
ray.init(
|
||||
num_cpus=1, _system_config={"debug_dump_period_milliseconds": 500})
|
||||
|
||||
try:
|
||||
gc.disable()
|
||||
@@ -179,8 +179,7 @@ def test_global_gc_actors(shutdown_only):
|
||||
return "Ok"
|
||||
|
||||
# Try creating 3 actors. Unless python GC is triggered to break
|
||||
# reference cycles, this won't be possible. Note this test takes 20s
|
||||
# to run due to the 10s delay before checking of infeasible tasks.
|
||||
# reference cycles, this won't be possible.
|
||||
for i in range(3):
|
||||
a = A.remote()
|
||||
cycle = [a]
|
||||
|
||||
@@ -8,7 +8,6 @@ import time
|
||||
import ray
|
||||
import ray.ray_constants
|
||||
import ray.test_utils
|
||||
from ray.test_utils import new_scheduler_enabled
|
||||
|
||||
from ray._raylet import GlobalStateAccessor
|
||||
|
||||
@@ -217,8 +216,6 @@ def test_load_report(shutdown_only, max_shapes):
|
||||
global_state_accessor.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
new_scheduler_enabled(), reason="requires placement groups")
|
||||
def test_placement_group_load_report(ray_start_cluster):
|
||||
cluster = ray_start_cluster
|
||||
# Add a head node that doesn't have gpu resource.
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import joblib
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from sklearn.datasets import load_digits, load_iris
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.ensemble import ExtraTreesClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.kernel_approximation import Nystroem
|
||||
@@ -14,7 +15,6 @@ from sklearn.kernel_approximation import RBFSampler
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.svm import LinearSVC, SVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.utils import check_array
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.model_selection import cross_val_score
|
||||
@@ -112,20 +112,14 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
|
||||
}
|
||||
# Load dataset.
|
||||
print("Loading dataset...")
|
||||
data = fetch_openml("mnist_784")
|
||||
X = check_array(data["data"], dtype=np.float32, order="C")
|
||||
y = data["target"]
|
||||
|
||||
unnormalized_X_train, y_train = pickle.load(
|
||||
open(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "mnist_784_100_samples.pkl"), "rb"))
|
||||
# Normalize features.
|
||||
X = X / 255
|
||||
X_train = unnormalized_X_train / 255
|
||||
|
||||
# Create train-test split.
|
||||
print("Creating train-test split...")
|
||||
n_train = 100
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
register_ray()
|
||||
|
||||
train_time = {}
|
||||
random_seed = 0
|
||||
# Use two workers per classifier.
|
||||
|
||||
@@ -741,10 +741,10 @@ ray.get(main_wait.release.remote())
|
||||
driver1_out_split = driver1_out.split("\n")
|
||||
driver2_out_split = driver2_out.split("\n")
|
||||
|
||||
assert driver1_out_split[0][-1] == "1"
|
||||
assert driver1_out_split[1][-1] == "2"
|
||||
assert driver2_out_split[0][-1] == "3"
|
||||
assert driver2_out_split[1][-1] == "4"
|
||||
assert driver1_out_split[0][-1] == "1", driver1_out_split
|
||||
assert driver1_out_split[1][-1] == "2", driver1_out_split
|
||||
assert driver2_out_split[0][-1] == "3", driver2_out_split
|
||||
assert driver2_out_split[1][-1] == "4", driver2_out_split
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user